steven liu

Attention please

Aug 10, 2025

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

It's amazing that an idea this simple has had such a tremendous and lasting impact. It is described in only two sentences, yet it powers the large language models that millions of people use every day.

This post tries to explain, as simply and plainly as possible, how attention and its variants work.

self-attention

Self-attention (scaled dot-product attention) computes how much attention each word in a sequence pays to every word in that same sequence, including itself. It's what makes transformers contextually aware.

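$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Figure: Q × Kᵀ produces attention scores. The scores are divided by √d_k (d_k is the key dimension), softmaxed into weights, and then multiplied by V.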

You need three matrices to compute self-attention: queries (Q), keys (K), and values (V). They are created by multiplying the word embeddings by three learned weight matrices (Wq, Wk, Wv).
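Here is a minimal pure-Python sketch of the formula above. It uses lists of rows instead of NumPy so it stays self-contained. The helper names (`matmul`, `softmax`, `self_attention`) are illustrative, not from any library. The weights are random rather than trained, so the numbers only show the mechanics.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    # Subtracting the row max avoids overflow and does not change the result.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(x, w_q, w_k, w_v):
    """Scaled dot-product self-attention for one sequence.

    x: (seq_len, d_model) word embeddings.
    w_q, w_k, w_v: (d_model, d_k) learned projection matrices.
    Returns the (seq_len, d_k) output and the (seq_len, seq_len) attention weights.
    """
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(k[0])
    scores = matmul(q, transpose(k))  # Q x K^T: one score per (query word, key word) pair
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights  # each output row is a weighted sum of V's rows

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

# Toy example: 3 words, 4-dimensional embeddings, random (untrained) weights.
rng = random.Random(0)
x = random_matrix(3, 4, rng)
w_q, w_k, w_v = (random_matrix(4, 4, rng) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
for row in weights:
    print([round(w, 3) for w in row])  # each row of weights sums to 1
```

Each row of `weights` says how much one word attends to every word. Each row of `out` is the matching weighted sum of the value vectors.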

causal self-attention

Causal self-attention is used in decoder models like GPT. Decoder models predict the next word in a sequence, so each position must be masked from seeing the words that come after it. Otherwise the model could simply read the word it is supposed to predict.

Figure: a single head with Q, K, V and a causal mask. Future positions are masked so tokens only attend to past positions.

The mask sets the scores of positions that should be blocked to -∞. The softmax turns these into weights of exactly 0, because exp(-∞) = 0.
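A small sketch of the masking step, assuming the raw scores have already been computed as in self-attention. The scores are made-up toy values and `causal_attention_weights` is a hypothetical helper name. Position i may attend to positions 0..i, and every later position is set to -inf before the softmax.

```python
import math

def causal_attention_weights(scores):
    """Apply a causal mask to a (seq_len, seq_len) score matrix, then softmax each row."""
    n = len(scores)
    weights = []
    for i in range(n):
        # Keep scores for positions 0..i; block future positions with -inf.
        masked = [scores[i][j] if j <= i else -math.inf for j in range(n)]
        m = max(masked)                             # finite, since position i always sees itself
        exps = [math.exp(s - m) for s in masked]    # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[1.0, 2.0, 3.0],
          [0.5, 0.5, 9.0],
          [1.0, 1.0, 1.0]]
for row in causal_attention_weights(scores):
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

Notice that the large future score 9.0 in row 1 has no effect at all. After masking, its weight is 0.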

multi-head attention

Multi-head attention (MHA) runs several attention heads in parallel so that each head can learn to attend to different kinds of relationships.

Figure: three heads, each with its own projections (Q₁, K₁, V₁; Q₂, K₂, V₂; Q₃, K₃, V₃). Each head has its own Q, K, and V projections.

The calculation in each head is the same as in self-attention, but the embedding dimension is divided evenly among the heads, so each head works with d_model / h dimensions. Each head independently computes scaled dot-product attention on its own slice. At the end, the head outputs are concatenated and multiplied by an output weight matrix (Wo) to blend the information together.

BERT uses MHA.
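A minimal pure-Python sketch of the split, attend, concatenate, and project steps described above. It projects once with full-width matrices and then slices columns per head. The slice of Wq (Wk, Wv) that belongs to head h acts as that head's own projection, which is one common way to organize the computation. All names here are illustrative helpers, and the weights are random.

```python
import math
import random

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Scaled dot-product attention for one head."""
    d_k = len(q[0])
    scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k] for qi in q]
    return matmul([softmax(row) for row in scores], v)

def head_slice(m, h, d_head):
    """Columns [h * d_head, (h + 1) * d_head) of m: head h's share of the features."""
    return [row[h * d_head:(h + 1) * d_head] for row in m]

def multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads):
    d_model = len(w_q[0])
    assert d_model % n_heads == 0, "d_model must split evenly across heads"
    d_head = d_model // n_heads
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    heads = [attention(head_slice(q, h, d_head), head_slice(k, h, d_head), head_slice(v, h, d_head))
             for h in range(n_heads)]
    # Concatenate head outputs position by position, back to d_model features.
    concat = [[val for head in heads for val in head[i]] for i in range(len(x))]
    return matmul(concat, w_o)  # W_o blends information across heads

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

x = rand_matrix(4, 8)  # 4 tokens, d_model = 8
w_q, w_k, w_v, w_o = (rand_matrix(8, 8) for _ in range(4))
out = multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads=2)  # 2 heads of 4 dims each
print(len(out), len(out[0]))  # 4 8
```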

multi-query attention

Multi-query attention (MQA) is the same as MHA, except that all the query heads share a single K and V.

Figure: six heads, each with its own Q (Q₁ to Q₆), all sharing a single K and V. Each head has its own Q. All heads share a single K and V.

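MQA is more memory-efficient and faster at decoding because only one K and V have to be stored in the KV cache, instead of one per head. During decoding, the cached K and V are read from memory for every new token, so a smaller cache means less memory traffic. This makes an especially big difference for long sequences.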

Gemma 2B uses MQA.
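A back-of-the-envelope sketch of the KV cache savings. The model shape is hypothetical and is not Gemma's or any other real model's configuration. The sketch assumes K and V are cached for every layer and token in 2-byte FP16/BF16.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """KV cache size: K and V (factor 2) for every layer, K/V head, and cached token.

    bytes_per_value=2 assumes FP16/BF16 storage.
    """
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# Hypothetical model: 32 layers, 32 query heads of size 128, 8192 cached tokens.
n_layers, n_heads, head_dim, seq_len = 32, 32, 128, 8192
mha = kv_cache_bytes(n_layers, n_heads, head_dim, seq_len)  # one K, V per head
mqa = kv_cache_bytes(n_layers, 1, head_dim, seq_len)        # one shared K, V
print(f"MHA: {mha / 2**30} GiB")  # MHA: 4.0 GiB
print(f"MQA: {mqa / 2**30} GiB")  # MQA: 0.125 GiB
```

The cache shrinks by exactly the number of query heads (32x here). The attention math for each query head does not change.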

grouped-query attention

Grouped-query attention (GQA) sits between MHA and MQA: the query heads are split into groups, and all the heads in a group share the same K, V. Each group has its own K, V.

Figure: eight heads in two groups. Group 1 (heads 1 to 4, with Q₁ to Q₄) shares K₁, V₁, and group 2 (heads 5 to 8, with Q₅ to Q₈) shares K₂, V₂. Each head has its own Q. K and V are shared within each group.

GQA is the middle ground between MHA and MQA. Its KV cache is smaller than MHA's, which makes decoding faster. Because it keeps more than one K, V, it is more expressive than MQA.

gpt-oss-20b uses GQA.
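A minimal sketch of how query heads map to shared K, V heads. The same function covers MHA (as many K, V heads as query heads) and MQA (one K, V head). `kv_head_for` and `grouped_query_attention` are hypothetical names, and the tensors are random toy values, not gpt-oss-20b's real shapes.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are lists of row vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        w = softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k])
        out.append([sum(wt * vj[c] for wt, vj in zip(w, v)) for c in range(len(v[0]))])
    return out

def kv_head_for(h, n_heads, n_kv_heads):
    """0-based index of the K, V head that query head h shares."""
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    return h // (n_heads // n_kv_heads)

def grouped_query_attention(q_heads, k_heads, v_heads):
    """q_heads: n_heads matrices; k_heads, v_heads: n_kv_heads matrices each.

    n_kv_heads == n_heads gives MHA; n_kv_heads == 1 gives MQA.
    """
    n_heads, n_kv_heads = len(q_heads), len(k_heads)
    outputs = []
    for h, q in enumerate(q_heads):
        g = kv_head_for(h, n_heads, n_kv_heads)
        outputs.append(attention(q, k_heads[g], v_heads[g]))
    return outputs

print([kv_head_for(h, 8, 2) for h in range(8)])  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_head_for(h, 8, 8) for h in range(8)])  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print([kv_head_for(h, 8, 1) for h in range(8)])  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]

# Toy run: 8 query heads, 2 K/V heads, 3 tokens, head size 4 (random values).
rng = random.Random(0)

def rand(rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

q_heads = [rand(3, 4) for _ in range(8)]
k_heads = [rand(3, 4) for _ in range(2)]
v_heads = [rand(3, 4) for _ in range(2)]
outputs = grouped_query_attention(q_heads, k_heads, v_heads)
```

The first mapping matches the figure: heads 1 to 4 use K₁, V₁, and heads 5 to 8 use K₂, V₂.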

flashattention

FlashAttention changes how the attention calculation is executed to minimize memory traffic to and from the GPU's high-bandwidth memory (HBM), which is large but much slower than on-chip memory. The calculation is done in fast on-chip shared memory (SRAM) and registers. The full N × N attention score matrix (for N tokens) is never materialized in HBM.

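Figure: block 1 computes Q × K₁ᵀ and its local statistics m₁, l₁, then multiplies by V₁ to get the partial output O₁. Block 2 computes Q × K₂ᵀ, uses m₁, l₁ from block 1 to compute updated statistics m₂, l₂, and rescales l₁ before accumulating. It then multiplies by V₂ to get O₂.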

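Attention scores are streamed over blocks so the full attention score matrix is never stored in memory. The result is exact, not an approximation. An online softmax keeps a running row maximum m and a running sum l of exponentials, and rescales earlier partial results whenever m grows. Only Q, K, V, the final output, and the softmax normalization statistics from the forward pass are stored in HBM.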
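A pure-Python sketch of the online softmax for a single query row. It keeps the running max `m`, the running sum `l`, and an unnormalized output `acc`, and rescales all three whenever a new block raises the max. It then checks the streamed result against an ordinary softmax that materializes every score. This shows only the math. The real kernel does the same thing tile by tile on the GPU, and the function names here are illustrative.

```python
import math
import random

def scores_for(q, keys):
    d = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]

def reference_attention_row(q, keys, values):
    """Materializes every score, then applies a standard softmax."""
    s = scores_for(q, keys)
    m = max(s)
    p = [math.exp(x - m) for x in s]
    l = sum(p)
    return [sum(pi * v[j] for pi, v in zip(p, values)) / l for j in range(len(values[0]))]

def streaming_attention_row(q, keys, values, block_size):
    """Same result, but only one K/V block of scores exists at a time."""
    m = -math.inf                 # running max score
    l = 0.0                       # running sum of exp(score - m)
    acc = [0.0] * len(values[0])  # running unnormalized output
    for start in range(0, len(keys), block_size):
        k_blk, v_blk = keys[start:start + block_size], values[start:start + block_size]
        s = scores_for(q, k_blk)
        m_new = max(m, max(s))
        # Rescale old stats to the new max; exp(-inf) == 0.0 on the first block.
        correction = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, v_blk)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc], m, l

rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(4)]
keys = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
values = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(10)]
ref = reference_attention_row(q, keys, values)
out, m, l = streaming_attention_row(q, keys, values, block_size=4)
print(max(abs(a - b) for a, b in zip(out, ref)))  # only floating-point rounding differences
```

The final `m` and `l` are the statistics the backward pass needs. They are often saved together as a single log-sum-exp, m + log(l).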

The backward pass uses the same idea to avoid storing the attention probabilities P in memory, even though P is needed to compute the gradients with respect to Q, K, and V. Instead, FlashAttention recomputes P block by block from Q, K, and the softmax statistics saved in the forward pass (m, l), and accumulates the gradients.
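A small sketch of the recomputation step. With the row statistics m and l saved from the forward pass, any block of probabilities can be rebuilt from q and that block of K alone, without the rest of the row. The gradient math itself is left out. `row_stats` stands in for the statistics the streaming forward pass would produce, and all names are hypothetical.

```python
import math
import random

def scores_for(q, keys):
    d = len(q)
    return [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]

def row_stats(q, keys):
    """Softmax statistics for one query row: max score m and l = sum(exp(s - m))."""
    s = scores_for(q, keys)
    m = max(s)
    return m, sum(math.exp(x - m) for x in s)

def recompute_p_block(q, k_block, m, l):
    """Rebuild one block of probabilities from q, one K block, and the saved (m, l)."""
    return [math.exp(x - m) / l for x in scores_for(q, k_block)]

rng = random.Random(1)
q = [rng.gauss(0, 1) for _ in range(4)]
keys = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(6)]

m, l = row_stats(q, keys)  # saved during the forward pass
p = []
for start in range(0, len(keys), 2):  # backward: one K block at a time
    p.extend(recompute_p_block(q, keys[start:start + 2], m, l))
print(round(sum(p), 12))  # 1.0
```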

Speedups grow with sequence length. For short sequences, attention is less memory-bound, so there is less benefit.

flashattention-2

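FlashAttention-2 improves on FlashAttention in two ways. It reduces non-matmul FLOPs, and it partitions work better across the GPU: it parallelizes over sequence blocks as well as batch and heads, and it divides work between warps so they rarely need to communicate through shared memory.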

Figure: FlashAttention parallelizes over batch × heads. With a batch of 2 and 4 heads, only 8 of 16 SMs are used and the rest sit idle. FlashAttention-2 parallelizes over batch × heads × sequence blocks, and splitting each head's sequence into 2 blocks keeps all 16 of 16 SMs busy. FlashAttention-2 splits work across sequence blocks, using more SMs when the batch size and number of heads are small.
Figure: FlashAttention shares Q across warps and splits K, V among them (warp i computes Q × Kᵢ, Vᵢ). The warps therefore write partial results to shared memory and must sync and reduce. FlashAttention-2 shares K, V and splits Q among the warps (warp i computes Qᵢ × K, V), so each warp produces its own output block Oᵢ with no sync and no shared-memory writes. FlashAttention-2 assigns Q blocks to warps instead of K, V, reducing shared memory writes and synchronization.
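A toy occupancy calculation that reproduces the first figure's numbers. It assumes a deliberately simplified model: each thread block occupies one SM, and everything runs in a single wave. Real GPUs have far more SMs, and real occupancy also depends on registers and shared memory per block.

```python
import math

def sm_utilization(n_thread_blocks, n_sms):
    """Fraction of SMs busy, assuming one thread block per SM and a single wave."""
    return min(n_thread_blocks, n_sms) / n_sms

# Toy numbers matching the figure: batch 2, 4 heads, 2 sequence blocks, 16 SMs.
batch, n_heads, seq_len, block_rows, n_sms = 2, 4, 256, 128, 16

fa1_blocks = batch * n_heads                                    # one per (batch, head)
fa2_blocks = batch * n_heads * math.ceil(seq_len / block_rows)  # also per sequence block
print(fa1_blocks, sm_utilization(fa1_blocks, n_sms))  # 8 0.5
print(fa2_blocks, sm_utilization(fa2_blocks, n_sms))  # 16 1.0
```

Long sequences with small batches are exactly the case where batch × heads alone leaves SMs idle.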

flashattention-3

FlashAttention-3 optimizes for newer hardware, like H100 GPUs, that supports asynchrony: specialized units can run in parallel. Tensor Cores do fast matmuls, the Tensor Memory Accelerator (TMA) loads data from HBM, and CUDA cores handle slower computations like the softmax.

Figure: in FlashAttention-2, all warps do the same work in sequence (load K₁V₁, compute, load K₂V₂, and so on), so loading and computing never overlap. In FlashAttention-3, producer warps use the TMA to load K₁V₁, K₂V₂, K₃V₃, while consumer warps wait for the first block and then compute attention on Tensor Cores, so loads overlap with computation. FlashAttention-3 overlaps memory transfers and computation using specialized producer/consumer warps.
Figure: four steps. WGMMA issues S_next = Q @ Kⱼ asynchronously (no blocking), then O += P @ Vⱼ₋₁ as an async accumulate, with wgmma_wait as the sync point. Meanwhile the softmax overlaps with these matmuls: m = max(m, rowmax(S)), P = exp(S - m). WGMMA issues async matmuls, allowing softmax to run in parallel on the previous block's scores.
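A toy timing model of the producer/consumer overlap in the figures. The units are arbitrary, and the model assumes the producer always has a free buffer to load into. It shows why overlapping helps, not how long a real kernel takes.

```python
def sequential_time(n_blocks, load, compute):
    """Every warp loads, then computes, then loads again: nothing overlaps."""
    return n_blocks * (load + compute)

def pipelined_time(n_blocks, load, compute):
    """Producer loads block j+1 while the consumer computes block j.

    The first load and the last compute can't be hidden; in between, each step
    takes as long as the slower of the two stages.
    """
    return load + (n_blocks - 1) * max(load, compute) + compute

print(sequential_time(3, load=2, compute=3))  # 15
print(pipelined_time(3, load=2, compute=3))   # 11
```

As the number of blocks grows, the pipelined time approaches n_blocks × max(load, compute) instead of n_blocks × (load + compute).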
Figure: with a single global scale over the values 0.5, 0.3, 0.2, 0.4, 52.0, 47.0, the scale is ≈ 0.116, set by the largest value. With per-block scales, block 1 [0.5, 0.3, 0.2, 0.4] gets scale₁ ≈ 0.001 and block 2 [47, 52, 41, 48] gets scale₂ ≈ 0.116. Each block computes its own scale so precision is preserved locally. Both use the full FP8 range.

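FlashAttention-3 also supports FP8, a low-precision format with a narrow range (the E4M3 variant tops out at 448). Rather than using one scale for the whole tensor, FlashAttention-3 uses a different scale per block, so each block can use the full FP8 range for its values. The outlier block gets a big scale and the normal blocks get a tiny scale. This way, precision is preserved locally.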
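A sketch of how per-block scales are chosen, using the figure's numbers. It assumes the E4M3 maximum of 448 and picks each scale so that the block's largest magnitude lands exactly at 448. Only the scaling step is shown. A real FP8 cast would also round each scaled value to FP8's 3 mantissa bits, which this sketch does not model.

```python
FP8_E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format

def block_scale(values):
    """Scale that maps the block's largest magnitude to the top of the FP8 range."""
    return max(abs(v) for v in values) / FP8_E4M3_MAX

def scale_down(values, scale):
    # Scaling step only; rounded to 1 decimal for display, not to real FP8 values.
    return [round(v / scale, 1) for v in values]

block1 = [0.5, 0.3, 0.2, 0.4]
block2 = [47.0, 52.0, 41.0, 48.0]

global_scale = block_scale(block1 + block2)     # set by the outlier block (52.0)
print(scale_down(block1, global_scale))         # [4.3, 2.6, 1.7, 3.4]
print(scale_down(block1, block_scale(block1)))  # [448.0, 268.8, 179.2, 358.4]
print(scale_down(block2, block_scale(block2)))  # [404.9, 448.0, 353.2, 413.5]
```

Under the global scale, block 1 is squeezed into a tiny corner of the range. With its own scale, it spreads across the whole range, just like block 2.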

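Figure: an input vector [0.5, 0.3, 0.2, 47.0, 0.1] with one outlier is multiplied by a random orthogonal matrix R (RᵀR = I, so dot products are preserved). In the output, all entries have similar magnitudes. Multiplying by a random orthogonal matrix R redistributes the outlier across all dimensions while preserving dot products.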

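Per-block scales don't help when an outlier sits inside an otherwise normal block, because the outlier still sets that block's scale. FlashAttention-3 fixes this by spreading outliers across all dimensions before quantization: it multiplies both Q and K by the same random orthogonal matrix R. Because RRᵀ = I, (QR)(KR)ᵀ = QRRᵀKᵀ = QKᵀ, so the attention scores are unchanged and no single dimension dominates a block.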
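A sketch of the rotation trick. As its random orthogonal matrix, it uses a Hadamard matrix with random row signs, scaled by 1/√n. This is one convenient orthogonal matrix, but treat the choice as an assumption of the sketch rather than a description of FlashAttention-3's kernel. The Sylvester construction needs a power-of-two size, so the figure's 5-dimensional example is extended to 8 dimensions with made-up values.

```python
import math
import random

def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix; n must be a power of 2."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of 2"
    h = [[1.0]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    return h

def random_orthogonal(n, rng):
    """R = D H / sqrt(n): a Hadamard matrix with random row signs, scaled to be orthogonal."""
    signs = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    h = hadamard(n)
    return [[signs[i] * h[i][j] / math.sqrt(n) for j in range(n)] for i in range(n)]

def rotate(x, r):
    """Row vector times matrix: returns x R."""
    return [sum(x[i] * r[i][j] for i in range(len(x))) for j in range(len(r[0]))]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

rng = random.Random(0)
r = random_orthogonal(8, rng)
q = [0.5, 0.3, 0.2, 47.0, 0.1, 0.4, 0.2, 0.3]  # one outlier dimension
k = [0.1, -0.4, 0.3, 1.2, -0.2, 0.5, 0.0, 0.6]
q_rot, k_rot = rotate(q, r), rotate(k, r)      # the same R for Q and K

print(dot(q, k), dot(q_rot, k_rot))            # equal up to rounding
print(max(map(abs, q)), max(map(abs, q_rot)))  # the outlier is spread out
```

Each rotated entry is at most (sum of |qᵢ|) / √8 ≈ 17.3 in magnitude, down from 47.0, while the dot product with the rotated key is unchanged.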

summary

Self-attention computes, for each word, a weighted sum of contextual information from all the words in the sequence. The weights reflect how relevant each surrounding word is to the current word.

But attention is often memory-bound. Newer attention algorithms optimize how the calculation is executed. Some share the same K, V across query heads (MQA, GQA). Others stream the computation in blocks through fast on-chip memory (FlashAttention).

resources