0x534 Attention
- 1. Cross Attention
- 2. Self-Attention
- 3. Low Rank/Kernel Transformer
- 4. Patterns-based Transformer
- 5. Other Variants
1. Cross Attention
The problem with a standard sequence-based encoder/decoder is that we cannot cram all of the input's information into a single hidden vector. Instead of a single vector, we can store more information in a variable-size set of vectors.
The traditional cross-attention mechanism allows the output sequence to attend to the input sequence.
Model (Cross attention)
- Use query vectors (from the decoder states) and key vectors (from the encoder states).
- For each (query, key) pair, we calculate a weight; the weights are normalized with a softmax.
- The normalized weights are multiplied with the value vectors and summed to obtain the target hidden vector.
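Putting the three steps together (the symbols \(q_i, k_j, v_j, \alpha_{ij}, c_i\) are introduced here just for illustration): for decoder query \(q_i\),
\[
\alpha_{ij} = \frac{\exp a(q_i, k_j)}{\sum_{j'} \exp a(q_i, k_{j'})}, \qquad c_i = \sum_j \alpha_{ij} v_j ,
\]
where \(c_i\) is the attended hidden vector for decoder position \(i\).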
The weight \(a(q, k)\) can be computed with different attention functions, for example:
- multilayer perceptron (Bahdanau 2015): \(a(q, k) = w_2^T \tanh(W[q;k])\)
- bilinear (Luong 2015): \(a(q, k) = q^TWk\)
- dot product (Luong 2015): \(a(q,k) = q^Tk\)
- scaled dot product (Vaswani 2017, attention paper): \(a(q, k) = \frac{q^Tk}{\sqrt{d}}\)
As mentioned in the attention paper, this scaling ensures the score has variance 1 under the assumption that \(q=(q_1, ..., q_d)\) and \(k=(k_1, ..., k_d)\) have independent components with mean 0 and variance 1: then \(\mathrm{Var}(q^Tk) = \sum_{i=1}^{d}\mathrm{Var}(q_ik_i) = d\), so dividing by \(\sqrt{d}\) brings the variance back to 1. Recall the variance of a product of independent variables from the probability note.
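A minimal NumPy sketch of this mechanism with the scaled dot-product score (function and variable names are illustrative, not from any particular library):

```python
# Scaled dot-product cross-attention: decoder queries attend to encoder states.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, keys, values):
    """queries: (T_dec, d); keys/values: (T_enc, d) -> context: (T_dec, d)."""
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)   # a(q, k) = q^T k / sqrt(d), shape (T_dec, T_enc)
    weights = softmax(scores, axis=-1)       # normalize over encoder positions
    return weights @ values                  # weighted sum of value vectors

rng = np.random.default_rng(0)
enc = rng.normal(size=(7, 16))   # 7 encoder states, hidden size 16
dec = rng.normal(size=(3, 16))   # 3 decoder states (queries)
print(cross_attention(dec, enc, enc).shape)  # (3, 16)
```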
2. Self-Attention
Self-attention allows each element in the sequence to attend to the other elements, producing a more context-sensitive encoding. For example, in the translation task, "the" in English might be translated to "le", "la", or "les" in French depending on the target noun; an attention link from "the" to the target noun makes its encoding more context sensitive.
Model (self-attention) Self-attention transforms an embedding of shape \((T,H)\) into another embedding of shape \((T,H)\); the shape is invariant under the self-attention transformation.
- Input: \((T,H)\)
- Output: \((T,H)\)
- Params: \(W_q, W_k, W_v \in R^{(H,H)}\)
where \(T\) is the sequence length, \(H\) is the hidden size, and \(W_q, W_k, W_v\) are weight matrices.
We first create three matrices \(Q = XW_q\), \(K = XW_k\), \(V = XW_v \in R^{(T,H)}\) from the input \(X \in R^{(T,H)}\), meaning query, key, and value.
Next, each token's query is compared against all keys in the sentence to compute the similarity matrix \(S = \mathrm{softmax}(QK^T/\sqrt{H}) \in R^{(T,T)}\), scaled by \(\sqrt{H}\) (scaled dot product).
The \((i,j)\) element of \(S\) is the similarity between token \(i\) and token \(j\); each row is normalized by the softmax.
Finally, the similarity matrix is multiplied by the value matrix to obtain the new embedding \(SV \in R^{(T,H)}\).
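A minimal NumPy sketch of the computation above (names and the random test data are illustrative):

```python
# Single-head self-attention over an embedding X of shape (T, H).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """X: (T, H); W_q, W_k, W_v: (H, H). Returns a new (T, H) embedding."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v          # each (T, H)
    S = softmax(Q @ K.T / np.sqrt(X.shape[-1]))  # (T, T), rows sum to 1
    return S @ V                                 # (T, H)

T, H = 5, 8
rng = np.random.default_rng(0)
X = rng.normal(size=(T, H))
W_q, W_k, W_v = (rng.normal(size=(H, H)) for _ in range(3))
print(self_attention(X, W_q, W_k, W_v).shape)    # (5, 8): shape is preserved
```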
Model (multi-head self-attention) Multi-head attention runs several self-attention heads in parallel, each on a lower-dimensional projection of the input, and concatenates their outputs.
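A sketch of the head-splitting bookkeeping, assuming the usual convention \(d_h = H / n_{heads}\) and a final output projection \(W_o\) (all names here are illustrative):

```python
# Multi-head self-attention: project once, split H into n_heads sub-spaces,
# attend within each head, then concatenate and project.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads):
    T, H = X.shape
    d_h = H // n_heads

    def split_heads(M):  # (T, H) -> (n_heads, T, d_h)
        return M.reshape(T, n_heads, d_h).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_q), split_heads(X @ W_k), split_heads(X @ W_v)
    S = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_h))  # (n_heads, T, T)
    heads = S @ V                                          # (n_heads, T, d_h)
    concat = heads.transpose(1, 0, 2).reshape(T, H)        # concatenate heads
    return concat @ W_o                                    # output projection

T, H, n_heads = 6, 16, 4
rng = np.random.default_rng(0)
X = rng.normal(size=(T, H))
W_q, W_k, W_v, W_o = (rng.normal(size=(H, H)) for _ in range(4))
print(multi_head_self_attention(X, W_q, W_k, W_v, W_o, n_heads).shape)  # (6, 16)
```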
Incremental decoding with multi-head attention has low arithmetic intensity, so memory bandwidth becomes the bottleneck (see section 2.4.1 of this work).
One approach to speeding up inference is Multi-Query Attention (MQA), where the different heads share a single set of keys and values (but keep separate queries).
Grouped-Query Attention (GQA) is another variant where the query heads are partitioned into groups and each group shares one set of keys and values.
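A shape-level sketch of why MQA/GQA help: the per-layer KV cache that must be read at every decoding step shrinks with the number of KV heads. The head counts and sizes below are made-up examples:

```python
# Element count of the per-layer KV cache for MHA vs. GQA vs. MQA.
# n_kv_heads = n_heads (MHA), n_groups (GQA), or 1 (MQA); fewer KV heads means
# less memory traffic per step during incremental decoding.
n_heads, d_h, T = 32, 128, 2048  # illustrative sizes

def kv_cache_elements(n_kv_heads):
    return 2 * n_kv_heads * T * d_h  # keys + values, per layer, per sequence

print("MHA:", kv_cache_elements(n_heads))  # 32 KV heads
print("GQA:", kv_cache_elements(8))        # e.g. 8 groups of 4 query heads
print("MQA:", kv_cache_elements(1))        # a single shared KV head
```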
3. Low Rank/Kernel Transformer
4. Patterns-based Transformer
Check this survey
The standard Transformer cannot process long sequences due to the quadratic \(O(T^2)\) time and memory complexity of self-attention. For example, BERT can only handle 512 tokens. Instead of standard global attention, the following models try to circumvent this issue by using restricted attention.
Model (Sparse Transformer) It introduces two sparse self-attention patterns:
- strided pattern: captures periodic structure
- fixed pattern: specific cells summarize previous locations and propagate that information to all future cells
Model (Longformer) Longformer's attention mechanism is a combination of the following:
- windowed local-context self-attention (optionally with dilated sliding windows)
- an end-task-motivated global attention that encodes an inductive bias about the task
Model (BigBird) uses global tokens (e.g. CLS) together with windowed and random sparse attention.
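A sketch of the restricted attention patterns above as boolean masks; the window size, stride, and choice of global tokens are arbitrary here, and causal masking is omitted for brevity:

```python
# Boolean (T, T) attention masks for a few restricted-attention patterns.
# True means "query i may attend to key j". Parameters are illustrative.
import numpy as np

T, window, stride = 16, 2, 4
i = np.arange(T)[:, None]   # query positions
j = np.arange(T)[None, :]   # key positions

# Sliding window (local attention): attend within +/- window positions.
sliding = np.abs(i - j) <= window

# Strided pattern (Sparse Transformer flavor): local window plus every stride-th step.
strided = sliding | ((i - j) % stride == 0)

# Window + global tokens (Longformer/BigBird flavor): designated tokens
# (e.g. CLS at position 0) attend everywhere and are attended to by everyone.
global_tokens = np.array([0])
global_mask = np.isin(i, global_tokens) | np.isin(j, global_tokens)
windowed_global = sliding | global_mask

for name, mask in [("sliding", sliding), ("strided", strided), ("window+global", windowed_global)]:
    print(f"{name}: {mask.sum()} / {T * T} allowed pairs")  # far sparser than full O(T^2) attention
```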
5. Other Variants
Model (Primer) Primer has a smaller training cost than the original Transformer for autoregressive language modeling; its main modifications are squaring the ReLU activations in the feed-forward block and adding depthwise convolutions after the Q, K, V projections.