Deep Learning 101: Lesson 30: Understanding Text with Attention Heatmaps

Muneeb S. Ahmad
9 min read · Sep 4, 2024


This article is part of the “Deep Learning 101” series. Explore the full series for more insights and in-depth learning here.

Multi-head attention extends the attention mechanism used in natural language processing by running several attention computations (“heads”) in parallel. Each head can focus on a different aspect of a sentence, and the model combines these perspectives into a richer representation. The Attention Heatmap visualizes the resulting attention scores, so users can see which words receive the most emphasis in each attention head.

In BERT’s Masked Language Model (MLM) task, these concepts work together to predict missing words from their surrounding context, which reveals how the model captures relationships and meaning in text. Visualizing this process helps explain how BERT handles the nuances of language, and it can help researchers refine training strategies.

Figure 1: Input Text

Consider a simple toy example built from a small number of sentences. Figure 1 shows a collection of six sentence pairs (or sequences) that will be used to train a Masked Language Model (MLM). In MLM, certain words in a text sequence are intentionally hidden (masked), and the model’s goal is to predict these missing words from the surrounding context. This training approach teaches the model the meanings of words and the relationships between them, which leads to better comprehension of the language. We use sentence pairs because they let the model learn how one sentence follows from another.

This example showcases a small but meaningful set of sentence pairs. Each sentence pair provides context that helps the model understand how different elements of a sentence relate to one another. For instance, the sentence pair ‘The cat chased the mouse. It was a fast chase.’ gives the model context about the action and its attribute. Similarly, ‘I need to buy milk. Please go to the store.’ provides an understanding of the necessity and the action taken to fulfill that need.

By training on such sentence pairs, the MLM can learn to predict missing words in similar contexts, enhancing its ability to understand and generate human-like text.

Figure 2 shows the same sentences with certain words replaced by [MASK] tokens. This masking is the core of MLM training: the model must predict each masked word from the context provided by the surrounding words.

Figure 2: Masked input text

The purpose of masking is to force the model to learn to understand the context of words. By predicting the masked words, the model learns the relationships and dependencies between different parts of the sentence, which enhances its language comprehension and generation capabilities. Below is how Masked Input Text is generated.

  • Masking Tokens: Specific words in the input sentences are replaced with the [MASK] token. For instance, in the sentence ‘The cat chased the mouse. It was a fast chase.’, the words ‘cat’ and ‘chase’ are masked as ‘[MASK]’. This masking forces the model to rely on the surrounding context to predict the masked words, thereby strengthening its understanding of how words interact within a sentence.
  • Special Tokens: To further aid the model in understanding the structure of the input, special tokens are added. The [CLS] token is added at the beginning of each sequence to denote the start of the sequence, while [SEP] tokens are used to separate different segments within the sequence. Additionally, [PAD] tokens are added to ensure that all sequences are of the same length, which is necessary for efficient batch processing during training. These special tokens help the model differentiate between different parts of the input and maintain consistency in sequence length.
  • Contextual Learning: The model uses the context provided by the unmasked words to predict the masked ones. For example, in ‘The [MASK] chased the mouse [SEP] It was a fast [MASK]’, the model uses the context to predict that the first [MASK] is ‘cat’ and the second is ‘chase’. This ability to infer missing information from context is useful for tasks such as language translation, text completion, and sentiment analysis. By training repeatedly on masked input text, the model gradually improves its predictions, which leads to better performance on a wide range of natural language processing tasks.
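The three steps above (masking, adding special tokens, padding) can be sketched in a few lines of Python. The sketch makes several assumptions. `build_masked_pair` is a hypothetical helper, not the tool’s code. It uses a naive lowercase tokenizer, and it masks every occurrence of the chosen words. The padding length of 14 is chosen to match the 14-token input described in the Attention Score Calculation section.

```python
import re


def build_masked_pair(sent_a, sent_b, mask_words, max_len):
    """Build one BERT-style input: [CLS] A [SEP] B [SEP] [PAD]...

    Returns (tokens, labels); labels[i] holds the hidden word at masked
    positions and None elsewhere. Hypothetical helper for illustration.
    """
    def tokenize(s):
        # Naive tokenizer: lowercase and keep alphabetic runs (drops punctuation).
        return re.findall(r"[a-z]+", s.lower())

    tokens = ["[CLS]"] + tokenize(sent_a) + ["[SEP]"] + tokenize(sent_b) + ["[SEP]"]
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok in mask_words:
            labels[i] = tok        # the word the model must predict
            tokens[i] = "[MASK]"   # hide it from the model
    if len(tokens) > max_len:
        raise ValueError("sequence longer than max_len")
    # Pad so every sequence in a batch has the same length.
    n_pad = max_len - len(tokens)
    return tokens + ["[PAD]"] * n_pad, labels + [None] * n_pad


tokens, labels = build_masked_pair(
    "The cat chased the mouse.", "It was a fast chase.",
    mask_words={"cat", "chase"}, max_len=14)
print(tokens)
# ['[CLS]', 'the', '[MASK]', 'chased', 'the', 'mouse', '[SEP]',
#  'it', 'was', 'a', 'fast', '[MASK]', '[SEP]', '[PAD]']
```

This sketch masks fixed words so that it reproduces the example. The original BERT pretraining recipe instead selects about 15% of tokens at random. Of those, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged.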

Attention Score Calculation

Figure 3: Attention Score Calculation

The attention score calculation and heatmap shown in this tool are derived from the last attention layer of a simple BERT model trained on the given input sequences. This visualization reflects the final attention distributions that the model uses to predict masked words, demonstrating the learned context and relationships from all preceding layers. For a more detailed end-to-end demonstration, you can refer to another tool on this portal called “BERT Explorer”.

The “Attention Score Calculation” section visually represents how input data (here, “Sample 0”) is processed to compute attention scores within the multi-head attention mechanism.

Input (X):

  • The green box on the left shows the input sequence as a matrix of numerical embeddings. There is one row for each of the 14 tokens in the sequence, including special tokens such as [CLS], [SEP], and [MASK], and each embedding has dimension 7. X therefore has shape [14, 7].
  • Note: Only a subset of each tensor or matrix is displayed to avoid overwhelming the diagram. For complete matrix values, you can refer to the “BERT Explorer” tool.
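One way to see where a [14, 7] input matrix X comes from is a small NumPy sketch of a BERT-style embedding lookup. In this lookup, each token’s vector is the sum of a token embedding, a position embedding, and a segment (sentence A or B) embedding. The token list, the vocabulary, and the random tables are assumptions for illustration. In the real model these tables are learned, and the tool may differ in detail.

```python
import numpy as np

rng = np.random.default_rng(0)

tokens = ["[CLS]", "the", "[MASK]", "chased", "the", "mouse", "[SEP]",
          "it", "was", "a", "fast", "[MASK]", "[SEP]", "[PAD]"]
d_model = 7                         # embedding dimension from the article
vocab = {tok: i for i, tok in enumerate(sorted(set(tokens)))}

# These tables are learned during training; random values stand in here.
token_emb = rng.normal(size=(len(vocab), d_model))
pos_emb = rng.normal(size=(len(tokens), d_model))   # one row per position
seg_emb = rng.normal(size=(2, d_model))             # sentence A vs. sentence B

ids = np.array([vocab[t] for t in tokens])
first_sep = tokens.index("[SEP]")
segments = np.array([0 if i <= first_sep else 1 for i in range(len(tokens))])

# BERT-style input: token + position + segment embeddings, summed row by row.
X = token_emb[ids] + pos_emb[np.arange(len(tokens))] + seg_emb[segments]
print(X.shape)  # (14, 7)
```

The two occurrences of “the” share a token embedding and a segment. Their rows in X therefore differ only by their position embeddings, which is how the model tells repeated words apart.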

Weight Matrices:

  • The black matrices transform the input embedding matrix X into queries (Q), keys (K), and values (V).
  • These weights are adjusted during training to accurately predict the masked words. Specifically, WQ transforms X into Q (queries), WK transforms X into K (keys), and WV transforms X into V (values).

Query (Q), Key (K), and Value (V) Matrices:

  • The input matrix X is multiplied by the weight matrices WQ, WK, and WV to produce the Q, K, and V matrices.
  • Dimensions: WQ, WK, and WV each have shape [7, 20], so multiplying the [14, 7] input X by them gives Q, K, and V of shape [14, 20]. The 20 is the number of columns in each weight matrix. It is chosen to equal the number of heads times the per-head dimension (5 × 4 = 20).
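The projection step can be checked with a short NumPy sketch that uses the sizes given above. Random matrices stand in for the learned weights, and bias terms (which BERT also has) are omitted for brevity. The shapes, not the values, are the point here.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_qkv = 14, 7, 20      # sizes taken from the article

X = rng.normal(size=(seq_len, d_model))  # stand-in for the input embeddings
# Learned projection matrices (random here); biases are omitted for brevity.
W_Q = rng.normal(size=(d_model, d_qkv))
W_K = rng.normal(size=(d_model, d_qkv))
W_V = rng.normal(size=(d_model, d_qkv))

# (14, 7) @ (7, 20) -> (14, 20): each row is one token's query/key/value.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # (14, 20) (14, 20) (14, 20)
```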

Splitting Q, K, V into Multiple Heads:

  • The Q, K, and V matrices are split into multiple heads so that each head can capture different aspects of the input data.
  • In this example, Q, K, and V are split into 5 heads (Q0 to Q4, K0 to K4, V0 to V4), each having dimensions [14, 4], where 4 represents the dimensionality of each head.
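Splitting into heads involves no extra computation; it is just a reshape. Head h takes a contiguous block of 4 columns from the [14, 20] matrix. The NumPy sketch below shows this. `split_heads` is a hypothetical helper name, and the matrix Q is filled with sequential numbers so the column blocks are easy to check.

```python
import numpy as np

seq_len, d_qkv, num_heads = 14, 20, 5
head_dim = d_qkv // num_heads                 # 20 / 5 = 4


def split_heads(M, num_heads):
    """(seq_len, d) -> (num_heads, seq_len, d // num_heads)."""
    seq_len, d = M.shape
    # Reshape to (seq_len, heads, head_dim), then move the head axis first.
    return M.reshape(seq_len, num_heads, d // num_heads).transpose(1, 0, 2)


Q = np.arange(seq_len * d_qkv, dtype=float).reshape(seq_len, d_qkv)
Q_heads = split_heads(Q, num_heads)
print(Q_heads.shape)  # (5, 14, 4)
```

Here Q_heads[0] corresponds to Q0 (columns 0 to 3), Q_heads[1] to Q1 (columns 4 to 7), and so on up to Q4. K and V are split the same way.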

Attention Calculation:

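  • Each attention head i computes its attention scores by taking the dot product between its queries Q_i and keys K_i, scaling by the square root of the per-head dimension d_k (here 4), and normalizing each row with the softmax function. The resulting weights are then applied to the values V_i:

$$\mathrm{head}_i = \mathrm{softmax}\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i$$

The softmax term is the [14, 14] matrix of attention scores that the heatmap visualizes.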
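The per-head computation described here is standard scaled dot-product attention. A NumPy sketch, batched over all 5 heads at once, is below. The random Q, K, and V stand in for the real per-head projections, and the function names are illustrative rather than taken from the tool. weights[h] is the [14, 14] grid that a heatmap for head h would display.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Qh, Kh, Vh):
    """Scaled dot-product attention, batched over heads.

    Qh, Kh, Vh: (num_heads, seq_len, head_dim).
    Returns (weights, output) with shapes
    (num_heads, seq_len, seq_len) and (num_heads, seq_len, head_dim).
    """
    d_k = Qh.shape[-1]
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)   # query-key dot products
    weights = softmax(scores, axis=-1)                   # each row sums to 1
    return weights, weights @ Vh                         # weighted sum of values


rng = np.random.default_rng(0)
Qh, Kh, Vh = (rng.normal(size=(5, 14, 4)) for _ in range(3))
weights, out = attention(Qh, Kh, Vh)
print(weights.shape, out.shape)  # (5, 14, 14) (5, 14, 4)
```

Dividing by √d_k keeps the dot products from growing with the head dimension. Without it, the softmax can saturate, with one cell near 1 and the rest near 0.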

Understanding the Attention Heatmap

The attention heatmap demonstrates the inner workings of multi-head attention by displaying attention scores in a grid format. These scores show how much focus each word in a sentence receives from every other word, which helps the model understand and generate context-aware representations.

Figure 4: Attention Heatmap

When studying how multi-head attention works within a Masked Language Model (MLM), it helps to visualize how attention scores are distributed across the words in a sentence. This visualization is called an attention heatmap; Figure 4 shows an example.

Different input samples can be selected to observe how attention scores vary from sentence to sentence. For example, the sample heatmap is calculated for the input “the cat chased the mouse. it was a fast chase”. Attention patterns also evolve over the course of training. Figure 5 shows epoch 3750 of 5000, an intermediate stage of learning, and how attention is distributed at that point in training.

Attention scores are also computed separately for each of the multiple heads. Each head focuses on different aspects of the input data and offers its own perspective on word relationships. Figures 4 and 5, for example, show the word relationships captured by Head 3.

In the attention heatmap, the rows represent the words (tokens) in the input sequence, and the columns represent the words they attend to. Each cell holds a normalized attention score between 0 and 1 (the scores in each row sum to 1). Higher scores are shown in brighter colors, and darker cells indicate less attention. For instance, in the sentence “the cat chased the mouse. it was a fast chase”, head 3 shows strong attention between “cat” and “chased” and between “fast” and “chased”. This suggests that the head links the action to its subject and to the word that describes it.
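Reading a heatmap row by row amounts to finding the brightest cell in each row. The sketch below does this for one head of a weights array shaped (num_heads, seq_len, seq_len). `strongest_links` is a hypothetical helper, and the 3×3 scores are made up for illustration; they are not values from Figures 4 or 5.

```python
import numpy as np


def strongest_links(weights, tokens, head):
    """For one head, pair each row token with the column token it attends to most."""
    A = weights[head]                      # (seq_len, seq_len) heatmap for this head
    cols = A.argmax(axis=1)                # brightest cell in each row
    return [(tokens[i], tokens[j], float(A[i, j])) for i, j in enumerate(cols)]


# Made-up 3x3 scores for illustration only (not values from the figures).
tokens = ["cat", "chased", "fast"]
weights = np.array([[[0.1, 0.8, 0.1],
                     [0.5, 0.3, 0.2],
                     [0.1, 0.7, 0.2]]])   # shape (1 head, 3, 3); rows sum to 1
for query, key, score in strongest_links(weights, tokens, head=0):
    print(f"{query:>7} -> {key:<7} {score:.2f}")
```

With a real model’s attention weights, passing head=3 would summarize the same head the figures show.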

Although not shown in the diagram, the outputs from each head are concatenated to form a single matrix, which is then used to compute the final output of the attention mechanism. This combined matrix retains diverse information from all heads, capturing various relationships within the input data.
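Concatenating the heads is the inverse of the split: the five [14, 4] head outputs are laid side by side into a [14, 20] matrix. In the standard Transformer this matrix is then multiplied by an output projection W_O. The sketch assumes W_O maps back to the embedding dimension of 7 so the result can be added to the input. The article does not give the tool’s exact shapes, and the random values are placeholders.

```python
import numpy as np


def merge_heads(H):
    """(num_heads, seq_len, head_dim) -> (seq_len, num_heads * head_dim)."""
    num_heads, seq_len, head_dim = H.shape
    # Move the head axis next to head_dim, then flatten those two axes.
    return H.transpose(1, 0, 2).reshape(seq_len, num_heads * head_dim)


rng = np.random.default_rng(0)
head_outputs = rng.normal(size=(5, 14, 4))   # e.g. weights @ V for each head
concat = merge_heads(head_outputs)           # (14, 20), head h in columns 4h..4h+3
W_O = rng.normal(size=(20, 7))               # output projection (assumed shape)
attn_out = concat @ W_O                      # (14, 7), one vector per token
print(concat.shape, attn_out.shape)  # (14, 20) (14, 7)
```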

This mechanism of computing attention scores and generating context-aware word embeddings is a fundamental component of the encoder. An encoder layer in Transformer models such as BERT consists of a multi-head attention sublayer followed by a feed-forward neural network, and each sublayer is wrapped with a residual connection and layer normalization. By stacking multiple encoder layers, the model builds increasingly complex representations of the input, capturing intricate patterns and relationships that lead to more accurate predictions of masked words.
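A compact NumPy sketch of one encoder layer, stacked twice, is below. It combines multi-head attention, residual connections, layer normalization, and a feed-forward network. Several details are assumptions for illustration: two layers, a feed-forward width of 16, ReLU in place of BERT’s GELU, and no biases or learned layer-norm parameters. Random weights replace trained ones. The attention array A returned by the last layer corresponds to what the tool visualizes.

```python
import numpy as np

rng = np.random.default_rng(0)
D_MODEL, NUM_HEADS, HEAD_DIM = 7, 5, 4   # sizes from the article
D_FF, NUM_LAYERS = 16, 2                 # assumptions for this sketch


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def layer_norm(x, eps=1e-6):
    # Normalize each token vector to zero mean, unit variance (no gain/bias).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def init_layer():
    d_qkv = NUM_HEADS * HEAD_DIM
    shapes = {"W_Q": (D_MODEL, d_qkv), "W_K": (D_MODEL, d_qkv),
              "W_V": (D_MODEL, d_qkv), "W_O": (d_qkv, D_MODEL),
              "W_1": (D_MODEL, D_FF), "W_2": (D_FF, D_MODEL)}
    return {name: rng.normal(scale=0.1, size=s) for name, s in shapes.items()}


def split_heads(M):
    return M.reshape(M.shape[0], NUM_HEADS, HEAD_DIM).transpose(1, 0, 2)


def encoder_layer(X, p):
    Q = split_heads(X @ p["W_Q"])
    K = split_heads(X @ p["W_K"])
    V = split_heads(X @ p["W_V"])
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(HEAD_DIM))   # per-head heatmaps
    heads = (A @ V).transpose(1, 0, 2).reshape(X.shape[0], NUM_HEADS * HEAD_DIM)
    X = layer_norm(X + heads @ p["W_O"])                        # attention + residual
    ff = np.maximum(0.0, X @ p["W_1"]) @ p["W_2"]               # feed-forward (ReLU)
    return layer_norm(X + ff), A                                # FFN + residual


X = rng.normal(size=(14, D_MODEL))        # stand-in for the embedded input
for params in [init_layer() for _ in range(NUM_LAYERS)]:
    X, A = encoder_layer(X, params)       # each layer refines the previous output
print(X.shape, A.shape)  # (14, 7) (5, 14, 14)
```

Each layer takes and returns a [14, 7] matrix. This shape-preserving design is what lets layers be stacked freely.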

The output from the final encoder layer, which contains rich, context-aware representations of the input tokens, is then fed into a prediction layer. This layer uses the learned context and relationships from all preceding encoders to accurately predict the masked words. This detailed process of prediction, including the prediction layer, is shown in another tool called “BERT Explorer” for a more comprehensive understanding of the BERT model.
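In its simplest form, the prediction layer is a linear map from each token’s final vector to one score per vocabulary word, followed by a softmax. The sketch reads off the most likely word at the masked positions. The toy vocabulary, the masked positions, and the random weights are assumptions. BERT’s actual MLM head also adds a dense transform with layer normalization and ties its output weights to the token embeddings. Because the weights here are untrained, the printed words are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "the", "cat", "chased",
         "mouse", "it", "was", "a", "fast", "chase"]   # toy vocabulary (assumed)
H = rng.normal(size=(14, 7))                # final encoder output (stand-in)
W_vocab = rng.normal(size=(7, len(vocab)))  # prediction layer, learned in practice
b_vocab = np.zeros(len(vocab))

logits = H @ W_vocab + b_vocab              # (14, vocab_size): a score per word
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)   # softmax over the vocabulary

masked_positions = [2, 11]                  # where [MASK] sits in the pair
predicted = [vocab[i] for i in probs[masked_positions].argmax(axis=1)]
print(predicted)   # untrained weights, so these words are arbitrary
```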

Figure 5: Attention Heatmap

After understanding the attention mechanisms and their role in predicting masked words, it’s time to see the results. The final output prediction (Figure 6) showcases the model’s ability to accurately fill in the blanks.

Figure 6: Prediction of Masked Words

This output displays the sentence after the fully trained model has predicted the masked words. In the input, the masked words were replaced with [MASK] tokens; the model predicts the most likely word for each of these tokens and so reconstructs the sentence. This prediction process highlights the model’s ability to comprehend and generate contextually appropriate words, demonstrating the efficacy of multi-head attention and the overall BERT model architecture. By analyzing these predictions at different stages of training, one can gain insights into the model’s learning progress and its understanding of language.
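The final reconstruction step puts each predicted word back in place of its [MASK] token, from left to right, and drops the special tokens. The sketch below shows this. `fill_masks` is a hypothetical helper, and the predictions are supplied by hand here; in the model they would come from the prediction layer.

```python
def fill_masks(tokens, predictions):
    """Replace each [MASK] (left to right) with a predicted word; return the sentences."""
    if tokens.count("[MASK]") != len(predictions):
        raise ValueError("need exactly one prediction per [MASK]")
    preds = iter(predictions)
    filled = [next(preds) if t == "[MASK]" else t for t in tokens]
    sentences, current = [], []
    for t in filled:
        if t == "[SEP]":              # [SEP] closes a sentence
            sentences.append(" ".join(current))
            current = []
        elif t not in ("[CLS]", "[PAD]"):
            current.append(t)
    return sentences


tokens = ["[CLS]", "the", "[MASK]", "chased", "the", "mouse", "[SEP]",
          "it", "was", "a", "fast", "[MASK]", "[SEP]", "[PAD]"]
print(fill_masks(tokens, ["cat", "chase"]))
# ['the cat chased the mouse', 'it was a fast chase']
```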

Summary

The exploration of the Attention Heatmap tool reveals how multi-head attention mechanisms within models like BERT help in understanding the intricate relationships between words in a text. By visualizing attention scores, the heatmap demonstrates how different words focus on each other, providing insights into the model’s interpretation of the text. The tool shows how attention scores are distributed across multiple heads and epochs, highlighting how the model’s focus shifts and refines over time. This visualization aids in comprehending how BERT captures various linguistic patterns and dependencies, enhancing its ability to understand and generate human language accurately.

4 Ways to Learn

1. Read the article: Attention Heatmaps

2. Play with the visual tool: Attention Heatmaps


3. Watch the video: Attention Heatmaps

4. Practice with the code: Attention Heatmaps

Previous Article: Attention Scores in NLP
Next Article: Exploring BERT



Muneeb S. Ahmad

Muneeb Ahmad is a Senior Microservices Architect and Recognized Educator at IBM. He is pursuing his passion for ABC (AI, Blockchain, and Cloud).