Steven Liu (@stevhliu), Hugging Face
Aug 10, 2025
Attention please

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

It's amazing that such a simple idea has had such a tremendous and lasting impact. The description above, from the 2017 paper "Attention Is All You Need", takes only two sentences, yet it powers the large language models millions of people use daily.

This post explains, as simply and plainly as possible, how attention and its variants work.

#self-attention

Self-attention (scaled dot-product attention) computes how much attention each word in a sequence assigns to every other word in that same sequence. It's what makes transformers contextually aware.

Q × Kᵀ produces attention scores, which are scaled by √dk and softmaxed into weights, then multiplied by V.

You need three matrices to compute self-attention. They are created by multiplying the word embeddings by three learned weight matrices (Wq, Wk, Wv).

  • Query (Q) is the information you're looking for. It is compared to every K (including itself) to calculate how much attention to pay to each word.

  • Key (K) is the information a word contains. It is multiplied (dot product) by Q to get the attention scores for each word.

    The scores are scaled by dividing by the square root of the vector dimension. For example, if the dimension is 64, divide the attention scores by 8. Scaling prevents the attention scores from becoming too large or too small.

    Scores are converted into probabilities that add up to 1 with the softmax function. Scaling also smooths out the softmax function by preventing any one value from getting too large and overwhelming the rest.

  • Value (V) is the information each word contributes. The attention weights determine how much of each word's value flows into the output.

Try calculating the self-attention score by hand in the following sequence to really get a feel for how it works.

Word | Q               | K               | V
Fear | [0.2, 0.1, 0.8] | [0.3, 0.5, 0.2] | [0.1, 0.7, 0.4]
is   | [0.5, 0.2, 0.3] | [0.1, 0.4, 0.6] | [0.8, 0.1, 0.2]
  1. Multiply Q for Fear by K for every word in the sequence to get the attention scores.

    [0.2, 0.1, 0.8]*[0.3, 0.5, 0.2] = 0.2*0.3 + 0.1*0.5 + 0.8*0.2 = 0.27

    [0.2, 0.1, 0.8]*[0.1, 0.4, 0.6] = 0.2*0.1 + 0.1*0.4 + 0.8*0.6 = 0.54

  2. Scale the attention scores by dividing by the square root of 3 (the vector dimension).

    [0.27/1.732, 0.54/1.732] = [0.156, 0.312]

  3. Apply the softmax function to convert the attention scores into probabilities that add up to 1.

    [e^0.156 / (e^0.156 + e^0.312), e^0.312 / (e^0.156 + e^0.312)] = [0.46, 0.54]

  4. Weight V by the attention scores.

    0.46*[0.1, 0.7, 0.4] = [0.046, 0.322, 0.184]

    0.54*[0.8, 0.1, 0.2] = [0.432, 0.054, 0.108]

  5. Add the weighted values together.

    [0.046+0.432, 0.322+0.054, 0.184+0.108] = [0.478, 0.376, 0.292]

This is the final self-attention output for Fear.
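The worked example above can be reproduced in a few lines of NumPy. This is a minimal sketch; the Q, K, and V rows match the table above.

```python
import numpy as np

# Q, K, V rows for "Fear" and "is", matching the table above
Q = np.array([[0.2, 0.1, 0.8], [0.5, 0.2, 0.3]])
K = np.array([[0.3, 0.5, 0.2], [0.1, 0.4, 0.6]])
V = np.array([[0.1, 0.7, 0.4], [0.8, 0.1, 0.2]])

def self_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # scaled dot-product scores
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V                              # weighted sum of values

out = self_attention(Q, K, V)
print(out[0])  # ≈ [0.478, 0.376, 0.292] as in the worked example (small rounding differences)
```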

#causal self-attention

Causal self-attention is used in decoder models like GPT. Decoder models predict the next word in a sequence, so it is important to mask the next words to prevent the model from seeing them.

Q, K, V per head. Future positions are masked so tokens only attend to the past.

The mask sets the scores of future words to -∞, which the softmax converts to ~0.

Try calculating the causal self-attention score by hand in the following sequence.

Word | Q               | K               | V
Fear | [0.2, 0.1, 0.8] | [0.3, 0.5, 0.2] | [0.1, 0.7, 0.4]
is   | [0.5, 0.2, 0.3] | [0.1, 0.4, 0.6] | [0.8, 0.1, 0.2]
the  | [0.1, 0.6, 0.3] | [0.4, 0.2, 0.5] | [0.2, 0.9, 0.1]
  1. Multiply Q for is by K for every word in the sequence to get the attention scores.

    [0.5, 0.2, 0.3]*[0.3, 0.5, 0.2] = 0.5*0.3+0.2*0.5+0.3*0.2 = 0.31

    [0.5, 0.2, 0.3]*[0.1, 0.4, 0.6] = 0.5*0.1+0.2*0.4+0.3*0.6 = 0.31

    [0.5, 0.2, 0.3]*[0.4, 0.2, 0.5] = 0.5*0.4+0.2*0.2+0.3*0.5 = 0.39

    The attention score is [0.31, 0.31, 0.39].

  2. Scale the attention scores by dividing by the square root of 3 (the vector dimension).

    [0.31/1.732, 0.31/1.732, 0.39/1.732] = [0.179, 0.179, 0.225]

  3. Set 0.225 to -inf to block attention to the future word (the).

    [0.179, 0.179, -inf]

    The rest of the calculation is the same as self-attention.
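The same idea can be sketched in NumPy, masking future positions before the softmax (the three rows match the table above):

```python
import numpy as np

Q = np.array([[0.2, 0.1, 0.8], [0.5, 0.2, 0.3], [0.1, 0.6, 0.3]])
K = np.array([[0.3, 0.5, 0.2], [0.1, 0.4, 0.6], [0.4, 0.2, 0.5]])
V = np.array([[0.1, 0.7, 0.4], [0.8, 0.1, 0.2], [0.2, 0.9, 0.1]])

def causal_attention(Q, K, V):
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # upper-triangular mask: position i may not attend to j > i
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # exp(-inf) -> 0 weight
    return weights @ V

out = causal_attention(Q, K, V)
print(out[1])  # row for "is": attends only to "Fear" and itself
```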

#multi-head attention

Multi-head attention (MHA) adds more heads to learn from different perspectives.

Each head has its own Q, K, and V projections.

The calculations are the same as self-attention, but the embedding is split across the heads. Each head independently computes scaled dot-product attention on its slice of the data. At the end, the outputs are concatenated and multiplied by a weight matrix to blend all the information together.

Try calculating the multi-head attention score by hand in the following sequence.

Head 1

Word | Q   | K   | V
Fear | 2.0 | 2.0 | 2.0
is   | 3.0 | 3.0 | 3.0

Head 2

Word | Q   | K   | V
Fear | 4.0 | 1.0 | 3.0
is   | 6.0 | 1.5 | 4.5
  1. For both heads, multiply Q for Fear by K for every word in the sequence to get the attention scores.

    Head 1

    [2.0*2.0, 2.0*3.0] = [4.0, 6.0]

    Head 2

    [4.0*1.0, 4.0*1.5] = [4.0, 6.0]

  2. Scale the attention scores of each head by dividing by the square root of 2 (the vector dimension).

    Head 1

    [4.0/1.414, 6.0/1.414] = [2.828, 4.243]

    Head 2

    [4.0/1.414, 6.0/1.414] = [2.828, 4.243]

  3. Apply the softmax function to convert the attention scores into probabilities that add up to 1.

    Head 1

    [e^2.828 / (e^2.828 + e^4.243), e^4.243 / (e^2.828 + e^4.243)] = [0.195, 0.805]

    Head 2

    [e^2.828 / (e^2.828 + e^4.243), e^4.243 / (e^2.828 + e^4.243)] = [0.195, 0.805]

  4. Weight V by the attention scores.

    Head 1

    0.195*2.0+0.805*3.0 = 2.805

    Head 2

    0.195*3.0+0.805*4.5 = 4.208

  5. Concatenate the outputs of each head and multiply by an output weight matrix (Wo) to combine the information.

    [2.805, 4.208]*[1.0, 1.0]ᵀ = 7.013

This is the final multi-head attention output for Fear.
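The general mechanism can be sketched in NumPy: split the embedding across heads, run scaled dot-product attention per head, then concatenate and project. The weights and shapes below are made up for illustration.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads):
    """x: (seq, d_model). Weight matrices: (d_model, d_model)."""
    seq, d_model = x.shape
    d_head = d_model // num_heads
    # project, then split the embedding dimension across heads
    def project(W):
        return (x @ W).reshape(seq, num_heads, d_head).transpose(1, 0, 2)
    Q, K, V = project(Wq), project(Wk), project(Wv)   # (heads, seq, d_head)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    out = softmax(scores) @ V                         # per-head outputs
    # concatenate heads and blend with the output projection
    return out.transpose(1, 0, 2).reshape(seq, d_model) @ Wo

rng = np.random.default_rng(0)  # illustrative random weights
x = rng.normal(size=(4, 8))
Wq, Wk, Wv, Wo = (rng.normal(size=(8, 8)) for _ in range(4))
y = multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads=2)
print(y.shape)  # (4, 8)
```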

#multi-query attention

Multi-query attention (MQA) is the same as MHA except that all query heads share a single K and V.

Each head has its own Q. All heads share a single K and V.

MQA is more memory-efficient and faster at decoding because each head doesn't need to store its own K and V. This makes an especially big difference for long sequences.

Try calculating the multi-query attention score by hand in the following sequence.

Head 1

Word | Q   | K   | V
Fear | 2.0 | 1.0 | 3.0
is   | 3.0 | 1.5 | 4.5

Head 2

Word | Q   | K   | V
Fear | 4.0 | 1.0 | 3.0
is   | 6.0 | 1.5 | 4.5
  1. For both heads, multiply Q for Fear by K for every word in the sequence to get the attention scores.

    Head 1

    [2.0*1.0, 2.0*1.5] = [2.0, 3.0]

    Head 2

    [4.0*1.0, 4.0*1.5] = [4.0, 6.0]

  2. Scale the attention scores of each head by dividing by the square root of 2 (the vector dimension).

    Head 1

    [2.0/1.414, 3.0/1.414] = [1.414, 2.121]

    Head 2

    [4.0/1.414, 6.0/1.414] = [2.828, 4.243]

  3. Apply the softmax function to convert the attention scores into probabilities that add up to 1.

    Head 1

    [e^1.414 / (e^1.414 + e^2.121), e^2.121 / (e^1.414 + e^2.121)] = [0.33, 0.67]

    Head 2

    [e^2.828 / (e^2.828 + e^4.243), e^4.243 / (e^2.828 + e^4.243)] = [0.195, 0.805]

  4. Weight V by the attention scores.

    Head 1

    0.33*3.0+0.67*4.5 = 4.01

    Head 2

    0.195*3.0+0.805*4.5 = 4.21

  5. Concatenate the outputs of each head and multiply by a weight matrix to combine the information.

    [4.01, 4.21]*[[0.5, 1.0], [1.5, 1.0]] = [8.32, 8.22]

This is the final multi-query attention output for Fear.
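A sketch of MQA in NumPy, assuming one Q projection per head and a single shared K and V projection (all names and shapes here are illustrative):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_query_attention(x, Wq_heads, Wk, Wv):
    """Each head has its own Q projection; all heads share one K and V."""
    K = x @ Wk                                  # (seq, d_head), shared
    V = x @ Wv                                  # (seq, d_head), shared
    d_head = K.shape[-1]
    outputs = []
    for Wq in Wq_heads:                         # one Q projection per head
        Q = x @ Wq
        weights = softmax(Q @ K.T / np.sqrt(d_head))
        outputs.append(weights @ V)
    return np.concatenate(outputs, axis=-1)     # concatenate head outputs

rng = np.random.default_rng(0)  # illustrative random weights
x = rng.normal(size=(4, 8))
Wq_heads = [rng.normal(size=(8, 2)) for _ in range(4)]   # 4 heads
Wk = rng.normal(size=(8, 2))
Wv = rng.normal(size=(8, 2))
out = multi_query_attention(x, Wq_heads, Wk, Wv)
print(out.shape)  # (4, 8)
```

Only one K and one V are ever computed and cached, which is where the decoding-time memory savings come from.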

#grouped-query attention

Grouped-query attention (GQA) is similar to MHA and MQA, except that groups of query heads share the same K and V. Each group has its own K and V.

Each head has its own Q. K and V are shared within each group.

GQA is the middle ground between MHA and MQA. It's faster than MHA and more expressive than MQA.

Try calculating the grouped-query attention score by hand in the following sequence.

Group 0, Head 0

Word | Q   | K   | V
Fear | 2.0 | 1.0 | 2.0
is   | 3.0 | 1.5 | 3.0

Group 0, Head 1

Word | Q   | K   | V
Fear | 4.0 | 1.0 | 2.0
is   | 6.0 | 1.5 | 3.0

Group 1, Head 2

Word | Q   | K   | V
Fear | 1.0 | 2.0 | 4.0
is   | 1.5 | 3.0 | 6.0

Group 1, Head 3

Word | Q   | K   | V
Fear | 3.0 | 2.0 | 4.0
is   | 4.5 | 3.0 | 6.0
  1. For each head in each group, multiply Q for Fear by K for every word in the sequence to get the attention scores.

    Group 0

    Head 0 -> [2.0*1.0, 2.0*1.5] = [2.0, 3.0]

    Head 1 -> [4.0*1.0, 4.0*1.5] = [4.0, 6.0]

    Group 1

    Head 2 -> [1.0*2.0, 1.0*3.0] = [2.0, 3.0]

    Head 3 -> [3.0*2.0, 3.0*3.0] = [6.0, 9.0]

  2. Scale the attention scores of each head by dividing by the square root of 2 (the vector dimension).

    Group 0

    Head 0 -> [2.0/1.414, 3.0/1.414] = [1.414, 2.121]

    Head 1 -> [4.0/1.414, 6.0/1.414] = [2.828, 4.243]

    Group 1

    Head 2 -> [2.0/1.414, 3.0/1.414] = [1.414, 2.121]

    Head 3 -> [6.0/1.414, 9.0/1.414] = [4.243, 6.364]

  3. Apply the softmax function to convert the attention scores into probabilities that add up to 1.

    Group 0

    Head 0 -> [e^1.414 / (e^1.414 + e^2.121), e^2.121 / (e^1.414 + e^2.121)] = [0.33, 0.67]

    Head 1 -> [e^2.828 / (e^2.828 + e^4.243), e^4.243 / (e^2.828 + e^4.243)] = [0.195, 0.805]

    Group 1

    Head 2 -> [e^1.414 / (e^1.414 + e^2.121), e^2.121 / (e^1.414 + e^2.121)] = [0.33, 0.67]

    Head 3 -> [e^4.243 / (e^4.243 + e^6.364), e^6.364 / (e^4.243 + e^6.364)] = [0.107, 0.893]

  4. Weight V by the attention scores.

    Group 0

    Head 0 -> 0.33*2.0 + 0.67*3.0 = 2.67

    Head 1 -> 0.195*2.0 + 0.805*3.0 = 2.81

    Group 1

    Head 2 -> 0.33*4.0 + 0.67*6.0 = 5.34

    Head 3 -> 0.107*4.0 + 0.893*6.0 = 5.79

  5. Concatenate the heads of each group and multiply by a weight matrix to combine the information.

    [2.67, 2.81, 5.34, 5.79] * [[1.0, 0.0], [0.5, 0.5], [1.0, 2.0], [0.1, 1.0]] = [9.994, 17.875]

This is the final grouped-query attention output for Fear.
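A sketch of GQA in NumPy, assuming each group carries its own K and V projections shared by the group's heads (names and shapes are illustrative):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(x, groups):
    """groups: list of (Wq_heads, Wk, Wv) — heads in a group share K, V."""
    outputs = []
    for Wq_heads, Wk, Wv in groups:
        K, V = x @ Wk, x @ Wv                   # shared within this group
        d_head = K.shape[-1]
        for Wq in Wq_heads:                     # per-head Q projections
            weights = softmax((x @ Wq) @ K.T / np.sqrt(d_head))
            outputs.append(weights @ V)
    return np.concatenate(outputs, axis=-1)

rng = np.random.default_rng(0)  # illustrative random weights
x = rng.normal(size=(4, 8))
make = lambda shape: rng.normal(size=shape)
# 2 groups × 2 heads, head dim 2 → output dim 8
groups = [([make((8, 2)), make((8, 2))], make((8, 2)), make((8, 2)))
          for _ in range(2)]
out = grouped_query_attention(x, groups)
print(out.shape)  # (4, 8)
```

With one group this reduces to MQA; with one head per group it reduces to MHA.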

#multi-head latent attention

MLA compresses K and V into small latents, saving cache memory compared to storing full per-head K and V.

Multi-head latent attention (MLA) stores a compressed version of K and V (a latent) to save space in the cache. At inference, the query latent and the cached KV latents are multiplied by a precomputed matrix, Wcombined, to compute attention.

Wcombined absorbs WUQ and WUK, so the latents never need to be expanded.

The usual approach expands the latent KV back to its full size with an up-projection matrix WUK before computing attention. This takes up more memory and negates the compression benefit.

MLA skips this step by "absorbing" or pre-multiplying WUQ and WUK into a single matrix, Wcombined, ahead of time. The latents don't need to be expanded during inference.

RtᵀRj sits between WUQ and WUK and changes per token pair.

Rearranging the math like this is incompatible with rotary position embeddings (RoPE), though. RoPE applies a position-dependent rotation to Q and K based on their token positions (t, j), which introduces the term RtᵀRj between WUQ and WUK.

Because RtᵀRj depends on the (t, j) pair, there is no single fixed matrix you can precompute that works for all pairs, so it can't be folded into Wcombined. Using RoPE this way would require either recomputing the full K from the latent during generation (more latency) or caching the full RoPE-applied keys (more memory).

Content and position are decoupled; cKV and kR are cached.

Decoupling the positional information from the content solves this. The content calculation is the same, but the positional information is applied in a separate stream.

kR is cached, and the content and positional scores are summed when computing attention. This way, positional information flows through small decoupled vectors and content information can still flow through the latents.
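The weight absorption can be sketched for the content path (no RoPE) in NumPy. The matrix names mirror the text; the shapes are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d_latent, d_head, seq = 4, 8, 5

W_uq = rng.normal(size=(d_latent, d_head))   # up-projection for Q latents
W_uk = rng.normal(size=(d_latent, d_head))   # up-projection for K latents
c_q  = rng.normal(size=(seq, d_latent))      # query latents
c_kv = rng.normal(size=(seq, d_latent))      # cached KV latents

# naive: expand both latents to full head size, then take dot products
scores_naive = (c_q @ W_uq) @ (c_kv @ W_uk).T

# MLA: absorb the two up-projections into one small matrix ahead of time
W_combined = W_uq @ W_uk.T                   # (d_latent, d_latent)
scores_mla = c_q @ W_combined @ c_kv.T       # latents never expanded

assert np.allclose(scores_naive, scores_mla)
```

The two score matrices are identical because (cQ WUQ)(cKV WUK)ᵀ = cQ (WUQ WUKᵀ) cKVᵀ; the absorption just reassociates the multiplication.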

#flashattention

FlashAttention changes how the attention calculation is executed to minimize memory traffic to and from the GPU's slower high-bandwidth global memory (HBM). The calculation is done on the faster on-chip shared memory (SRAM) and registers, without fully materializing the entire attention score matrix in HBM.

Block 2 uses m₁ and l₁ from block 1 to update the row-wise softmax statistics and rescale before accumulating.
  • Q, K, and V are tiled into blocks that fit in SRAM, so the whole attention computation can be fused into a single CUDA kernel.
  • In the first block, calculate a row-wise max m1 and a sum l1 for the attention scores.
  • After processing the second block, update the row-wise max to m2. Rescale l1 and the partial output before adding the new updates.
  • Repeat for each remaining block.

Attention scores are streamed over blocks to avoid storing the entire attention score matrix in memory. The softmax calculation is exactly the same. Only Q, K, V, the final output, and the softmax normalization factor from the forward pass are stored in HBM.
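The streaming forward pass can be sketched with the online softmax update described above (a simplified single-head NumPy version; real FlashAttention also tiles Q and runs in a fused GPU kernel):

```python
import numpy as np

def streaming_attention(Q, K_blocks, V_blocks):
    """Process K, V in blocks, keeping running row-wise max (m) and sum (l)."""
    n, d_k = Q.shape
    m = np.full((n, 1), -np.inf)   # running row-wise max
    l = np.zeros((n, 1))           # running softmax denominator
    O = np.zeros((n, V_blocks[0].shape[1]))
    for Kb, Vb in zip(K_blocks, V_blocks):
        S = Q @ Kb.T / np.sqrt(d_k)
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        # rescale old stats and partial output to the new max
        scale = np.exp(m - m_new)
        P = np.exp(S - m_new)
        l = l * scale + P.sum(axis=-1, keepdims=True)
        O = O * scale + P @ Vb
        m = m_new
    return O / l

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
out = streaming_attention(Q, np.split(K, 2), np.split(V, 2))

# matches the standard (full-matrix) softmax attention
S = Q @ K.T / np.sqrt(8)
ref = np.exp(S - S.max(-1, keepdims=True))
ref = (ref / ref.sum(-1, keepdims=True)) @ V
assert np.allclose(out, ref)
```

The full score matrix is never materialized; only the running statistics m and l and the partial output are carried between blocks.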

The backward pass uses the same idea to avoid storing the attention probabilities P in memory. P is required for computing gradients with respect to Q, K, and V. FlashAttention recomputes P block-by-block, uses the same softmax statistics from the forward pass (m, l), and accumulates the gradients.

Speedups increase with sequence length. For short sequences, attention is less memory-bound and there's less benefit.

#flashattention-2

FlashAttention-2 improves FlashAttention by reducing non-matmul FLOPs and scheduling more work on the GPU.

  • GPUs use specialized compute units (Tensor Cores on NVIDIA GPUs) for fast matmuls. In FlashAttention, every block rescales the output with the sum l.

    FlashAttention-2 delays rescaling to the end to avoid many scalar divisions. This reduces the number of non-matmul operations.

FlashAttention-2 schedules more query blocks in parallel, which improves SM usage for long sequences.
  • FlashAttention parallelizes over batch size × number of heads × query blocks. For long sequences with small batches, there are fewer blocks to schedule and streaming multiprocessor (SM) occupancy drops.

    FlashAttention-2 increases tiling and scheduling so more query blocks run in parallel. This boosts SM usage for long sequences.

FlashAttention-2 assigns Q blocks to warps instead of K and V blocks, reducing shared memory writes and synchronization.
  • Warps are groups of threads in a thread block. FlashAttention splits the work across K and V, while Q is accessible to all warps. After computing attention, warps write their partial results to shared memory, synchronize, and reduce.

    FlashAttention-2 instead splits the work across Q, while K and V are accessible to all warps. Each warp computes a partial result that doesn't require synchronization until the end. Assigning different Q blocks to different warps minimizes shared memory traffic.

#flashattention-3

FlashAttention-3 optimizes for newer hardware, like H100 GPUs, that supports asynchrony. Specialized compute units run in parallel: Tensor Cores do fast matmuls, the Tensor Memory Accelerator (TMA) loads data from HBM, and CUDA cores perform slower computations like softmax.

FlashAttention-3 overlaps memory transfers and computation using specialized producer/consumer warps.
  • Warps performed the same work in FlashAttention-2, loading and computing.

    FlashAttention-3 specializes warps to different tasks. Producer warps load K, V blocks from HBM to SRAM with TMA. Consumer warps compute attention using data in SRAM with tensor cores. The producer can prefetch the next block while the consumer is working on the current block. It overlaps loading and computation.

WGMMA issues async matmuls, allowing softmax to run in parallel on the previous block's scores.
  • The output O depends on P. P is calculated from the softmax(S), and S depends on the attention score calculation.

    FlashAttention-3 uses Warp Group Matrix Multiply-Accumulate (WGMMA) to perform the computation asynchronously. It pipelines blocks so that while softmax executes on one block of the scores matrix, WGMMA computes the next block. Two pipelines run in parallel: WGMMA executions and non-WGMMA executions.

  • FP8 quantization is sensitive to outlier features in large language models (LLMs). A few dimensions in hidden states can have very large values. When quantized to FP8, the issue manifests as precision loss.

Each block computes its own scale, so precision is preserved locally and every block uses the full FP8 range.

FlashAttention-3 uses a different scale per block. Each block can use the full FP8 range for its values. The outlier block gets a big scale and the normal blocks get a tiny scale. This way, precision is preserved locally.
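Per-block scaling can be simulated in NumPy. The sketch below stands in for FP8 with round-to-integer in a ±448 range (448 is the max value of the FP8 E4M3 format); the data matches the figure's two blocks:

```python
import numpy as np

FP8_MAX = 448.0  # max representable value in FP8 E4M3

def quantize(x, scale):
    """Scale into the FP8 range, round, then dequantize (a simulation)."""
    q = np.clip(np.round(x / scale), -FP8_MAX, FP8_MAX)
    return q * scale

# one block of normal values, one block with outliers
x = np.array([0.5, 0.3, 0.2, 0.4, 47.0, 52.0, 41.0, 48.0])

# one global scale: the outlier block forces a coarse step for everything
global_scale = np.abs(x).max() / FP8_MAX
# per-block scales: each block of 4 uses the full FP8 range on its own
blocks = x.reshape(2, 4)
block_scales = np.abs(blocks).max(axis=1, keepdims=True) / FP8_MAX

err_global = np.abs(quantize(x, global_scale) - x).mean()
err_block = np.abs(quantize(blocks, block_scales) - blocks).mean()
print(err_global, err_block)  # per-block error is smaller
```

With the global scale, the small values in the first block are rounded with the outlier block's coarse step; per-block scaling gives them a much finer step of their own.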

Multiplying by a random orthogonal matrix R (RᵀR = I, which preserves dot products) redistributes the outlier across all dimensions.

Even with per-block scaling, a single outlier inside a block can still dominate it. FlashAttention-3 fixes this by spreading the outlier across all dimensions before quantization: it multiplies Q and K by a random orthogonal matrix to redistribute the outlier, so no single dimension dominates a block.

#summary

Self-attention is a weighted sum of contextual information of all words. The weights reflect each surrounding word's relevance to the current word.

But attention is memory-bound. New attention algorithms optimize how the attention calculation is executed, whether it's sharing the same K, V or streaming the attention function.

#resources