steven liu

Attention please

Aug 10, 2025

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

It's amazing that an idea so simple has such a tremendous and lasting impact. This idea is described in only two sentences, but it is consumed by millions of people using large language models daily.

This post tries to explain, as simply and plainly as possible, how attention and its variants work.

self-attention

Self-attention (scaled dot-product attention) computes how much attention each word in a sequence assigns to every other word in that same sequence. It's what makes transformers contextually aware.

Q × K^T produces attention scores, which are scaled by √d_k and softmaxed into weights, then multiplied by V.

You need three matrices, Q, K, and V, to compute self-attention. They are created by multiplying the word embeddings by three learned weight matrices (W_q, W_k, W_v).

Try calculating the self-attention score by hand in the following sequence to really get a feel for how it works.
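
If you prefer code to hand calculation, here is a minimal NumPy sketch of the same computation. The sequence length, embedding size, and random weight matrices are made up for illustration; in a real model the weights are learned.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8            # hypothetical sizes

X = rng.normal(size=(seq_len, d_model))    # word embeddings

# projection matrices (random here, learned in a real model)
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_q, X @ W_k, X @ W_v

scores = Q @ K.T / np.sqrt(d_k)            # compatibility of each query with each key
weights = softmax(scores)                  # each row sums to 1
output = weights @ V                       # weighted sum of the values

print(weights.round(2))                    # how much each word attends to every other word
```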

causal self-attention

Causal self-attention is used in decoder models like GPT. Decoder models predict the next word in a sequence, so the future words must be masked to prevent the model from seeing them.

Q, K, V per head. Future positions are masked (gray) so tokens only attend to the past.

The mask sets the scores of blocked positions to -∞, which the softmax converts to ~0.

Try calculating the causal self-attention score by hand in the following sequence.
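
In code, the only change from the self-attention sketch above is an upper-triangular mask of -∞ added to the scores before the softmax. The score matrix here is random, just to show the mask's effect.

```python
import numpy as np

seq_len = 4
scores = np.random.default_rng(0).normal(size=(seq_len, seq_len))

# -inf strictly above the diagonal: position t may only attend to positions <= t
mask = np.triu(np.full((seq_len, seq_len), -np.inf), k=1)
masked = scores + mask

weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights.round(2))  # upper triangle is ~0, each row still sums to 1
```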

multi-head attention

Multi-head attention (MHA) adds more heads to learn from different perspectives.

Each head has its own Q, K, and V projections.

The calculations are the same as self-attention, but the embedding size is split across the heads. Each head independently computes scaled dot-product attention on its slice of the data. At the end, the outputs are concatenated and multiplied by an output weight matrix to blend all the information together.

Try calculating the multi-head attention score by hand in the following sequence.
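
Here is a rough sketch of the reshaping involved: split d_model across the heads, run the same attention per head, then concatenate and apply an output projection. The sizes and the name W_o are assumptions for illustration.

```python
import numpy as np

def attention(Q, K, V):
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 4, 16, 4
d_head = d_model // n_heads                  # each head works on a slice

X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))

def split_heads(M):                          # (seq, d_model) -> (heads, seq, d_head)
    return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

Q, K, V = (split_heads(X @ W) for W in (W_q, W_k, W_v))

per_head = attention(Q, K, V)                # each head attends independently
concat = per_head.transpose(1, 0, 2).reshape(seq_len, d_model)
output = concat @ W_o                        # blend the heads back together
```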

multi-query attention

Multi-query attention (MQA) is the same as MHA except all query heads share a single K and V.

Each head has its own Q. All heads share a single K and V.

MQA is more memory-efficient and faster at decoding because the cache doesn't need to store a separate K and V for every head. This makes an especially big difference for very long sequences.

Try calculating the multi-query attention score by hand in the following sequence.
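
A minimal sketch of the change relative to MHA: K and V have no head dimension, and every query head attends to the same shared K and V. Shapes and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_head, n_heads = 4, 8, 6

# one Q per head, but a single shared K and V
Q = rng.normal(size=(n_heads, seq_len, d_head))
K = rng.normal(size=(seq_len, d_head))        # no head dimension
V = rng.normal(size=(seq_len, d_head))

scores = Q @ K.T / np.sqrt(d_head)            # K broadcasts across all heads
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ V                          # V is shared too

# the KV cache stores one K and V per token instead of one per head
print(K.nbytes + V.nbytes, "bytes cached vs", n_heads * (K.nbytes + V.nbytes), "for MHA")
```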

grouped-query attention

Grouped-query attention (GQA) is similar to MHA and MQA, except that the query heads are divided into groups and each group shares its own K and V.

Each head has its own Q. K and V are shared within each group.

GQA is the middle ground between MHA and MQA. It's faster than MHA and more expressive than MQA.

Try calculating the grouped-query attention score by hand in the following sequence.
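
A sketch assuming 8 query heads in 2 groups of 4: each group's K and V are repeated so they line up with the query heads in that group. The sizes are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_head = 4, 8
n_heads, n_groups = 8, 2
heads_per_group = n_heads // n_groups

Q = rng.normal(size=(n_heads, seq_len, d_head))   # one Q per head
K = rng.normal(size=(n_groups, seq_len, d_head))  # one K per group
V = rng.normal(size=(n_groups, seq_len, d_head))  # one V per group

# repeat each group's K, V so they line up with the heads in that group
K_rep = np.repeat(K, heads_per_group, axis=0)     # (n_heads, seq_len, d_head)
V_rep = np.repeat(V, heads_per_group, axis=0)

scores = Q @ K_rep.swapaxes(-1, -2) / np.sqrt(d_head)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ V_rep                          # (n_heads, seq_len, d_head)
```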

multi-head latent attention

MLA compresses K and V into small latents, saving memory compared to the standard KV cache.

Multi-head latent attention (MLA) stores a compressed version of K and V (a latent) to save space in the cache. At inference, the Q latent and the cached KV latents are multiplied by W_combined to compute attention.

W_combined absorbs W_UQ and W_UK, so the latents stay small at inference instead of being expanded.

The usual approach expands the latent KV back to its full size with an up-projection matrix W_UK before computing attention. This takes up more memory and negates the compression benefit.

MLA skips this step by "absorbing" or pre-multiplying W_UQ and W_UK into a single matrix, W_combined, ahead of time. The latents don't need to be expanded during inference.
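
A small sketch of why the absorption works, with made-up latent and head sizes: because (c_Q W_UQ)(c_KV W_UK)^T = c_Q (W_UQ W_UK^T) c_KV^T, the two up-projections can be pre-multiplied into W_combined and the latents never have to be expanded.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_latent_q, d_latent_kv, d_head = 5, 16, 8, 32   # hypothetical sizes

c_Q  = rng.normal(size=(seq_len, d_latent_q))    # query latents
c_KV = rng.normal(size=(seq_len, d_latent_kv))   # cached KV latents

W_UQ = rng.normal(size=(d_latent_q, d_head))     # query up-projection
W_UK = rng.normal(size=(d_latent_kv, d_head))    # key up-projection

# naive: expand the latents into full-size Q and K, then take dot products
scores_naive = (c_Q @ W_UQ) @ (c_KV @ W_UK).T

# MLA: absorb the two up-projections into one precomputed matrix
W_combined = W_UQ @ W_UK.T                       # (d_latent_q, d_latent_kv)
scores_absorbed = c_Q @ W_combined @ c_KV.T      # the latents are never expanded

assert np.allclose(scores_naive, scores_absorbed)
```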

With RoPE, R_t^T R_j sits between W_UQ and W_UK and changes for every token pair (t, j).

Rearranging the maths like this makes it incompatible with rotary position embeddings (RoPE), though. RoPE applies a position-dependent rotation to Q and K based on their token positions (t, j), so the term R_t^T R_j ends up sitting between W_UQ and W_UK.

The term R_t^T R_j depends on the (t, j) pair. There is no single fixed matrix you can precompute that works for all (t, j) pairs, so it can't be folded into W_combined. Using RoPE this way would require either recomputing the full K from the latent during generation (increases latency) or caching the full RoPE-applied keys (increases memory).

Content and position are decoupled; c_KV and k_R are cached.

Decoupling the positional information from the content solves this. The content calculation is the same, but the positional information is applied in a separate stream.

k_R is cached, and the content and positional scores are summed when computing attention. This way, positional information flows through small decoupled vectors, and content information can still flow through the latents.
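
A rough sketch of the two score streams, with made-up dimensions: the content scores come from the latents via the absorbed matrix, the positional scores from the small RoPE-rotated q_R and k_R vectors, and the two are summed before scaling and softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_latent, d_rope = 4, 8, 4               # hypothetical sizes

c_Q  = rng.normal(size=(seq_len, d_latent))       # query latents
c_KV = rng.normal(size=(seq_len, d_latent))       # cached KV latents
q_R  = rng.normal(size=(seq_len, d_rope))         # RoPE-rotated query parts
k_R  = rng.normal(size=(seq_len, d_rope))         # RoPE-rotated key parts (cached)

W_combined = rng.normal(size=(d_latent, d_latent))  # precomputed W_UQ @ W_UK^T

content_scores  = c_Q @ W_combined @ c_KV.T       # position-independent, via latents
position_scores = q_R @ k_R.T                     # carries the RoPE information

scores = content_scores + position_scores         # summed, then scaled and softmaxed as usual
```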

flashattention

FlashAttention changes how the attention calculation is executed to minimize memory traffic to and from the GPU's slower high-bandwidth global memory (HBM). The calculation is done in the faster on-chip shared memory (SRAM) and registers, without ever fully materializing the entire attention score matrix in HBM.

Block 2 uses m₁ and l₁ from block 1 to update the row-wise softmax statistics and rescale before accumulating.

Attention scores are streamed over blocks to avoid storing the entire attention score matrix in memory. The softmax result is exactly the same; it's just computed incrementally with running row-wise statistics (a max m and a normalizer l). Only Q, K, V, the final output, and the softmax normalization statistics from the forward pass are stored in HBM.
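
Here is a simplified, single-head NumPy sketch of the blocked forward pass with an online softmax. It only models the math (not the GPU memory hierarchy): K and V are processed in blocks, and only the running statistics m, l and the running output are kept.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=2):
    """Blocked attention with an online softmax (single-head sketch)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)

    O = np.zeros((n, d))
    m = np.full(n, -np.inf)   # running row-wise max
    l = np.zeros(n)           # running row-wise sum of exponentials

    for start in range(0, n, block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]

        S = (Q @ K_blk.T) * scale                 # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))      # updated row max
        P = np.exp(S - m_new[:, None])            # unnormalized block probabilities
        correction = np.exp(m - m_new)            # rescale old stats to the new max

        l = correction * l + P.sum(axis=1)
        O = correction[:, None] * O + P @ V_blk
        m = m_new

    return O / l[:, None]

# sanity check against standard attention
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(6, 4)) for _ in range(3))
S = Q @ K.T / np.sqrt(4)
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward(Q, K, V), ref)
```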

The backward pass uses the same idea to avoid storing the attention probabilities P in memory. P is required for computing gradients with respect to Q, K, and V. FlashAttention recomputes P block-by-block, uses the same softmax statistics from the forward pass (m, l), and accumulates the gradients.

Speedups increase with sequence length. For short sequences, attention is less memory-bound, so there's less benefit.

flashattention-2

FlashAttention-2 improves on FlashAttention by reducing non-matmul FLOPs and parallelizing the work across more of the GPU's streaming multiprocessors (SMs) and warps.

FlashAttention-2 schedules more query blocks in parallel, which improves SM usage for long sequences.

FlashAttention-2 assigns Q blocks to warps instead of K and V, reducing shared memory writes and synchronization.

flashattention-3

FlashAttention-3 optimizes for newer hardware, like H100 GPUs, that supports asynchrony. Specialized compute units run in parallel: tensor cores do fast matmuls, the Tensor Memory Accelerator (TMA) loads data from HBM, and CUDA cores perform slower computations like softmax.

FlashAttention-3 overlaps memory transfers and computation using specialized producer/consumer warps.

WGMMA issues async matmuls, allowing the softmax to run in parallel on the previous block's scores.
Each block computes its own scale, so precision is preserved locally and both blocks use the full FP8 range.

When quantizing to FP8, FlashAttention-3 uses a different scale per block, so each block can use the full FP8 range for its values. The outlier block gets a big scale and the normal blocks get a tiny scale. This way, precision is preserved locally.
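
A rough sketch of the per-block scaling idea, reusing the figure's numbers. It uses plain integer rounding in place of real FP8 codes, and assumes the E4M3 format whose largest representable value is 448.

```python
import numpy as np

FP8_MAX = 448.0   # largest value representable in FP8 E4M3

x = np.array([0.5, 0.3, 0.2, 0.4, 47.0, 52.0, 41.0, 48.0])

# one global scale: the outlier sets it, and the small values collapse onto a few codes
global_scale = np.abs(x).max() / FP8_MAX
print(np.round(x / global_scale))          # [4. 3. 2. 3. 405. 448. 353. 414.]

# per-block scales: each block uses the full range, so precision is kept locally
for block in (x[:4], x[4:]):
    scale = np.abs(block).max() / FP8_MAX
    print(np.round(block / scale), f"scale ≈ {scale:.3f}")
```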

Multiplying by a random orthogonal matrix R (R^T R = I, so dot products are preserved) redistributes the outlier across all dimensions.

Per-block scaling doesn't help when an outlier sits inside an otherwise normal block. FlashAttention-3 fixes this by spreading the outlier across all dimensions before quantization: it multiplies Q and K by a random orthogonal matrix, so no single dimension dominates a block, and because the matrix is orthogonal, the dot products are unchanged.
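
A small sketch of that property, using a random orthogonal matrix built with a QR decomposition (the vectors and sizes are made up): the rotation spreads the outlier's energy across dimensions while leaving dot products, and therefore attention scores, untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

# a query vector with one outlier dimension (made-up values)
q = np.array([0.5, 0.3, 0.2, 47.0, 0.1, 0.4, 0.2, 0.3])
k = rng.normal(size=d)

# random orthogonal matrix via QR decomposition: rot.T @ rot = I
rot, _ = np.linalg.qr(rng.normal(size=(d, d)))

q_rot, k_rot = q @ rot, k @ rot

print(np.abs(q).max(), np.abs(q_rot).max())   # the outlier's energy is spread out
assert np.allclose(q @ k, q_rot @ k_rot)      # dot products (scores) are unchanged
```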

summary

Self-attention is a weighted sum of the contextual information from all words, where the weights reflect each surrounding word's relevance to the current word.

But attention is memory-bound. New attention algorithms optimize how the attention calculation is executed, whether by sharing the same K and V across heads or by streaming the calculation over blocks.

resources