by Mike Lawrence
Goal: Efficiently capture the state of an input sequence for downstream modeling.
Examples:
- Language Generation and Translation (Sequence of word tokens)
- Time Series Forecasting (Sequence of temporal observations)
- Computer Vision (Sequence of image patches)
The famous Attention Is All You Need paper from Google opened the door for GPU-efficient, high-performing sequence modeling via the Transformer architecture: https://arxiv.org/pdf/1706.03762.pdf
Previous state-of-the-art models, like RNNs and LSTMs, relied on sequential processing of the input sequence (not parallel like attention), compressed the input state noisily, and suffered from numerical instability. Roughly speaking, attention performs a weighted transformation of the input sequence via matrix multiplication, while RNN-style architectures run a for-loop over the input sequence and update a shared state at each iteration.
Pictured: LSTM architecture (applied iteratively to the sequence)
Pictured: Transformer architecture (encode in bulk, decode iteratively)
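To make that loop-vs-matmul contrast concrete, here is a toy NumPy sketch (my own illustration, not code from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 4
X = rng.normal(size=(T, d))          # a toy input sequence

# RNN-style: a sequential for-loop that updates a shared hidden state each step.
W_h, W_x = rng.normal(size=(d, d)), rng.normal(size=(d, d))
h = np.zeros(d)
for t in range(T):                   # cannot be parallelized across timesteps
    h = np.tanh(h @ W_h + X[t] @ W_x)

# Attention-style: one weighted transformation of the whole sequence at once.
scores = X @ X.T / np.sqrt(d)        # all pairwise interactions in a single matmul
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
context = weights @ X                # (T, d), computed in parallel across timesteps
```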
This change is fundamentally possible due to the attention mechanism.
Convert a sequence of embeddings into a new sequence of the same length where each converted embedding is a "context vector", containing information about the entire sequence.
Pictured: diagram of the attention mechanism
Each $h$ in the stack of attention layers is a "head", hence the name multi-headed attention. Implementing Multi-Head Attention (MHA) therefore only requires building attention once and creating a stack of such layers.
Here's a breakdown of the linear algebra operations involved in building a single attention layer, where $Q$, $K$, and $V$ are Queries, Keys, and Values, named for the information-retrieval context in which the original paper framed them:
Pictured: The sequence of length $T$ and embedding dimension $d$ is transformed into a new sequence of shape $T \times d$.
The entire process can be distilled to the equation:

$$C = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where $d_k$ is the dimension of the keys, the attention matrix is $A = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right)$, and the entire quantity, $C$, contains the context vectors.
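To make the equation concrete, here is a minimal NumPy sketch of scaled dot-product attention (the function and variable names are my own, not from the paper):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: C = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k))  # (T, T) attention matrix
    return A @ V                         # (T, d) context vectors

# Toy check: T = 3 tokens, d = 3 dims (using the raw embeddings as Q, K, and V).
X = np.array([[0.2, 0.4, 0.3],
              [0.8, 0.22, 0.3],
              [0.3, 0.65, 0.11]])
print(attention(X, X, X).shape)  # (3, 3)
```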
First, the inputs are the ordered embeddings for the task at hand, such as:
| vocab | dim1 | dim2 | dim3 |
|---|---|---|---|
| See | 0.2 | 0.4 | 0.3 |
| that | 0.8 | 0.22 | 0.3 |
| cat | 0.3 | 0.65 | 0.11 |
This table is the input to the attention layer: each row is a token's embedding.
Note: The embeddings come from elsewhere. The quality of the data representation is completely external to the attention mechanism, though the two can be trained together (look up the pretrain-finetune paradigm).
Takeaway: We call this matrix of embeddings the input sequence, $X$.

The first step is to turn the input into Queries, Keys, and Values.

Recall we have the input sequence, $X$, of shape $T \times d$, where the sequence length is $T$ and the embedding dimension is $d$.

For each of $Q$, $K$, and $V$ we learn a weight matrix, $W_Q$, $W_K$, and $W_V$, each of shape $d \times d$. These are all compatible for multiplication with the input sequence, $X$.

We then create the final $Q$, $K$, and $V$ matrices:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$

where each row, item $t$ of the sequence, is that token's query, key, or value vector.

Takeaway: The Queries, Keys, and Values are just learned linear transformations of the input sequence.
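As a minimal sketch (assuming square $d \times d$ projections for simplicity; real implementations often project to a smaller per-head dimension), the projections are just three matrix multiplications:

```python
import numpy as np

rng = np.random.default_rng(0)

T, d = 3, 3                       # sequence length and embedding dimension
X = rng.normal(size=(T, d))       # input sequence of embeddings

# One learned weight matrix per projection (randomly initialized here).
W_Q = rng.normal(size=(d, d))
W_K = rng.normal(size=(d, d))
W_V = rng.normal(size=(d, d))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # each is (T, d); row t belongs to token t
```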
Let's take a closer look at the "Attention Matrix", $A$.

Recall that $Q = X W_Q$ and $K = X W_K$, both of shape $T \times d$.

In the attention matrix calculation, $A = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right)$, the core quantity is the product $Q K^\top$.

Each matrix element is a product of the form $(Q K^\top)_{ij} = q_i \cdot k_j$, the dot product of token $i$'s query with token $j$'s key.

Recall the shapes of $Q$ and $K$ are both $T \times d$, so $Q K^\top$ has shape $T \times T$: one score for every pair of positions in the sequence.

Thus intuitively each element, as a dot product, may be thought of as a similarity measure between the query of one token and the key of another.
Takeaway: a token at time $t$ can attend to tokens at other parts of the sequence. Attending before $t$ looks into the past, while after $t$ is the future of the sequence.
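A quick NumPy check of that claim (random matrices, names my own): element $(i, j)$ of the score matrix is exactly the scaled dot product of query $i$ with key $j$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 3, 4
Q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))

S = Q @ K.T / np.sqrt(d)           # (T, T) matrix of scaled similarity scores

i, j = 2, 0
manual = Q[i] @ K[j] / np.sqrt(d)  # token i's query dotted with token j's key
assert np.isclose(S[i, j], manual)
```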
It's common to mask attention values to enforce causality or other constraints in the model. For example, when decoding a new sentence, it is not appropriate to allow a token at time $t$ to attend to tokens at future times!
You can introduce a mask matrix to null out values that should not be attended to -- this makes the attention weights 0.
Thus the attention calculation may look something like:

$$C = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V$$

where $M$ is a mask matrix containing $-\infty$ in positions that should not be attended to and $0$ elsewhere; after the softmax, the masked positions get attention weights of exactly $0$.
Takeaway: You can modify attention with a mask to exclude items in the sequence, such as for preserving causality.
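A minimal sketch of a causal mask in NumPy (the score values are made up for illustration): set masked entries to $-\infty$ before the softmax so their weights become exactly 0.

```python
import numpy as np

T = 3
scores = np.array([[0.9, 0.1, 0.4],
                   [0.2, 0.7, 0.3],
                   [0.5, 0.6, 0.8]])            # made-up Q K^T / sqrt(d) scores

# Causal mask: True strictly above the diagonal = future positions to hide.
mask = np.triu(np.ones((T, T), dtype=bool), k=1)
masked_scores = np.where(mask, -np.inf, scores)

# Softmax each row; the -inf entries become exactly 0 attention weight.
weights = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 2))                     # lower-triangular, rows sum to 1
```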
Consider the sequence ['attention', 'is', 'cool']. Each token has an associated embedding in the input sequence, which is turned into a query and a key in the matrix multiplication step, giving Q = [query0, query1, query2] and K = [key0, key1, key2].
The attention matrix, $A$, then holds a similarity score between every query and every key. In the decoder case, similarities with future values in the sequence aren't known and thus should be masked, making a lower triangular matrix. The $K$ and $Q$ for each token are thus compared only to the time steps at or before that token.
The final attention matrix is then often visualized for feature importance:
Pictured: masked self-attention matrix (unnormalized)
We stepped through how the raw attention scores, $Q K^\top / \sqrt{d_k}$, are built. Applying the softmax row-wise turns each row into a set of positive weights that sum to 1.

Taking the earlier masked matrix as an example, the softmax of each row yields weights like:

$$A = \begin{bmatrix} 1.0 & 0 & 0 \\ 0.6 & 0.4 & 0 \\ 0.2 & 0.5 & 0.3 \end{bmatrix}$$

Considering again the sequence this represents, ['attention', 'is', 'cool'], we interpret these values row-by-row as what each token "attends" to:
- Timestep 0: attention -> attention (1.0)
- Timestep 1: is -> attention (0.6), is (0.4)
- Timestep 2: cool -> attention (0.2), is (0.5), cool (0.3)
Takeaway: for each time step we learn to weigh previous tokens, including their position information. These weights can be thought of as how much each token in the sequence contributes to the representation at that time step.
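The same interpretation can be printed directly from the weights above (a small sketch using the numbers from this example):

```python
import numpy as np

tokens = ['attention', 'is', 'cool']
# The masked, softmax-normalized attention weights from the example above.
A = np.array([[1.0, 0.0, 0.0],
              [0.6, 0.4, 0.0],
              [0.2, 0.5, 0.3]])

for t, row in enumerate(A):
    attended = ", ".join(f"{tok} ({w:.1f})" for tok, w in zip(tokens, row) if w > 0)
    print(f"Timestep {t}: {tokens[t]} -> {attended}")
```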
We built up the pieces to form the context vectors, $C = A V$.

We know the final matrix should be of shape $T \times d$. Thus we multiply $A$ (shape $T \times T$) with $V$ (shape $T \times d$).

The matrix elements can be interpreted like:

$$C_{t,k} = \sum_{j=1}^{T} A_{t,j} \, V_{j,k}$$

where $V_{:,k}$ is the $k$-th column of the Values across time, and $A_{t,:}$ holds the attention weights for time step $t$.
Using the old example, take the masked attention weights, $A$, from above and the corresponding Values matrix, $V$, for ['attention', 'is', 'cool'].
The first matrix entry of the context vectors is the standard matrix multiplication of these two - the first row of the weights dotted with the first column of the Values.
Note that in this masked case, the first row of the context vectors is just the first row of the Values, unchanged - the weighted average for the first timestep has no historical values to factor in!
Takeaway: each new embedding is a weighted average of itself and all other embeddings, where the weights are the attention values!
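Here's a small numerical check of that takeaway, using the attention weights from the example and hypothetical Value vectors (I reuse the embedding-table numbers purely for illustration):

```python
import numpy as np

# Masked attention weights from the running example (rows sum to 1).
A = np.array([[1.0, 0.0, 0.0],
              [0.6, 0.4, 0.0],
              [0.2, 0.5, 0.3]])

# Hypothetical Value vectors for ['attention', 'is', 'cool'].
V = np.array([[0.2, 0.4, 0.3],
              [0.8, 0.22, 0.3],
              [0.3, 0.65, 0.11]])

C = A @ V
print(C[0])  # equals V[0]: the first timestep has no history to average in
print(C[1])  # 0.6 * V[0] + 0.4 * V[1]: a weighted average of two embeddings
```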
Everything we worked through was an example of self-attention -- the queries, keys, and values were all derived from the same input sequence.
Another important type of attention is cross-attention, where we use two input sequences instead of one. What does this look like?
Imagine we are doing a translation task. In this instance there is an encoded input sentence, inputs 2, which is used to generate the keys and values, while inputs 1 contains the decoder's sequence and generates the queries. The resulting attention matrix then calculates attention values between the items in the decoder, or target language, and the encoder, or source language!
Perhaps most miraculous of all is that the sequence representations in both inputs can be from different modalities.
Pictured: Cross-attention weights between image patches and language tokens. The heatmaps for some words, like furry or bike, show significant weight on particular image patches!
Takeaway: Self-attention transforms data in the context of itself. Cross-attention transforms data in the context of outside information.
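A hedged NumPy sketch of cross-attention (random data, my own variable names): queries come from one sequence, keys and values from the other.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d = 8
T_dec, T_enc = 4, 6                            # decoder (target) and encoder (source) lengths

decoder_states = rng.normal(size=(T_dec, d))   # "inputs 1": the decoder's sequence
encoder_states = rng.normal(size=(T_enc, d))   # "inputs 2": the encoded source sequence

W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

# Cross-attention: queries from the decoder, keys and values from the encoder.
Q = decoder_states @ W_Q
K = encoder_states @ W_K
V = encoder_states @ W_V

A = softmax(Q @ K.T / np.sqrt(d))   # (T_dec, T_enc): target attends to source
C = A @ V                           # (T_dec, d) context vectors for the decoder
print(A.shape, C.shape)
```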
Run attention on the same input data $N$ times in parallel. Each of the $N$ runs uses a separate attention instance with its own weights. Concatenate the results.
That's it!
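A minimal sketch of that recipe (the per-head dimension and weight shapes are my simplifying assumptions; the paper also adds a final output projection):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return A @ V

rng = np.random.default_rng(0)
T, d, n_heads = 3, 12, 4
d_head = d // n_heads                  # each head works in a smaller subspace

X = rng.normal(size=(T, d))

# One independent set of projection weights per head.
heads = []
for _ in range(n_heads):
    W_Q, W_K, W_V = (rng.normal(size=(d, d_head)) for _ in range(3))
    heads.append(attention(X, W_Q, W_K, W_V))   # each head outputs (T, d_head)

multi_head = np.concatenate(heads, axis=-1)     # concat back to (T, d)
print(multi_head.shape)
```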
The entire transformer can be decomposed into a series of components that are easily built with knowledge we have now:
The left "tower" in this famous image is an encoder, while the right is a decoder.
The encoder adds positional information to each token's embedding, then runs multi-headed self-attention on the sequence without masking to generate a rich representation of the encoded sequence. Residual information via skip connections and an MLP with layer normalization perform some basic processing of the transformed embeddings and offer some additional modeling power.
Takeaway: Forms a rich representation of context data via self-attention.
The decoder's role is to generate an output, such as the next word in a sequence. It does this auto-regressively -- each generated output is appended to the input sequence before running again. The input sequence has positional information added to it, causal self-attention is used to enrich the input sequence, and then context information from the encoder is attended to in a cross-attention step. Vanilla feed-forward and layer-norm steps, followed by a softmax, give a probability distribution over the vocabulary.
Takeaway: Uses encoded context data (if available) to generate new outputs via cross-attention.
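A hedged sketch of the auto-regressive loop described above; decoder_step is a hypothetical placeholder standing in for the full decoder stack (embedding, positional encoding, masked self-attention, cross-attention, feed-forward, softmax):

```python
import numpy as np

VOCAB_SIZE = 10
END_TOKEN = 0

def decoder_step(tokens, encoder_context):
    # Placeholder for the real decoder: returns a probability over the vocabulary.
    rng = np.random.default_rng(len(tokens))
    logits = rng.normal(size=VOCAB_SIZE)
    return np.exp(logits) / np.exp(logits).sum()

tokens = [1]                          # start-of-sequence token
encoder_context = np.zeros((5, 8))    # stand-in for the encoder's output

for _ in range(20):                   # cap the number of generation steps
    probs = decoder_step(tokens, encoder_context)
    next_token = int(np.argmax(probs))   # greedy decoding
    tokens.append(next_token)            # append and run the decoder again
    if next_token == END_TOKEN:
        break
print(tokens)
```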
TODO: Add sources for all borrowed content (images are not mine!)