Attention please
@stevhliu | Aug 10, 2025
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
It's amazing that such a simple idea has had such a tremendous and lasting impact. It is described in just two sentences, yet it powers the large language models millions of people use daily.
This post tries to explain, as simply and plainly as possible, how attention and its variants work.
#self-attention
Self-attention (scaled dot-product attention) computes how much attention each word in a sequence assigns to every other word in that same sequence. It's what makes transformers contextually aware.
You need three matrices to compute self-attention: Q, K, and V. They are created by multiplying the word embeddings by three learned weight matrices (Wq, Wk, Wv).
- Query (Q) is the information you're looking for. It is compared to every K (including itself) to calculate how much attention to pay to each word.
- Key (K) is the information a word contains. It is multiplied (dot product) by Q to get the attention scores for each word.

  The scores are scaled by dividing by the square root of the key dimension. For example, if the dimension is 64, divide the attention scores by 8. Scaling prevents the attention scores from becoming too large or too small.

  The scores are converted into probabilities that add up to 1 with the softmax function. Scaling also smooths out the softmax by preventing any one value from getting too large and overwhelming the rest.
- Value (V) is the information each word contributes. Each word's value is weighted by its attention score to determine how much it offers to every other word.
Try calculating the self-attention scores by hand on a short sequence to really get a feel for how it works.
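If you'd rather see it as code, here's a minimal NumPy sketch of self-attention. The sizes, the random inputs, and names like `W_q` are just for illustration:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Scaled dot-product self-attention over a sequence of embeddings X."""
    Q = X @ W_q                          # what each word is looking for
    K = X @ W_k                          # what each word contains
    V = X @ W_v                          # what each word contributes
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # compatibility of every word with every other word
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # weighted sum of values

# toy example: 4 "words" with embedding size 8 and head size 4
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))
W_q, W_k, W_v = (rng.normal(size=(8, 4)) for _ in range(3))
out = self_attention(X, W_q, W_k, W_v)
print(out.shape)  # (4, 4)
```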
#causal self-attention
Causal self-attention is used in decoder models like GPT. Decoder models predict the next word in a sequence, so it is important to mask the next words to prevent the model from seeing them.
The mask sets the scores of future words to -∞, which the softmax converts to ~0.
Try calculating the causal self-attention scores by hand on a short sequence.
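In code, the only change from the self-attention sketch above is a mask applied to the scores before the softmax (again, an illustrative sketch):

```python
import numpy as np

def causal_self_attention(Q, K, V):
    """Scaled dot-product attention with a causal mask."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    seq_len = scores.shape[0]
    # block future words: everything above the diagonal is set to -inf
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                                 # exp(-inf) = 0 for masked words
    weights = weights / weights.sum(axis=-1, keepdims=True)  # rows sum to 1
    return weights @ V
```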
#multi-head attention
Multi-head attention (MHA) adds more heads to learn from different perspectives.
The calculations are the same as self-attention, but the embedding dimension is split across the heads. Each head independently computes scaled dot-product attention on its slice of the data. At the end, the head outputs are concatenated and multiplied by a weight matrix to blend all the information together.
Try calculating the multi-head attention scores by hand on a short sequence.
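Here's a sketch of the head splitting, assuming for illustration that W_q, W_k, W_v are (d_model, d_model) matrices and W_o is the output projection that blends the concatenated heads:

```python
import numpy as np

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """Minimal multi-head attention sketch: split the embedding across heads."""
    seq_len, d_model = X.shape
    d_head = d_model // num_heads

    # project, then reshape to (num_heads, seq_len, d_head)
    Q = (X @ W_q).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    K = (X @ W_k).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    V = (X @ W_v).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    # each head runs scaled dot-product attention on its own slice
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                   # (num_heads, seq_len, d_head)

    # concatenate the heads and blend them with the output weight matrix
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o
```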
#multi-query attention
Multi-query attention (MQA) is the same as MHA except all query heads share a single K and V.
MQA is more memory-efficient and faster at decoding because the KV cache stores only one K and V instead of one per head. This makes an especially big difference for really long sequences.
Try calculating the multi-query attention scores by hand on a short sequence.
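The only change from the MHA sketch is that W_k and W_v now project to a single head (d_head columns instead of d_model), so every query head attends with the same shared K and V:

```python
import numpy as np

def multi_query_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """MQA sketch: every query head shares one K head and one V head."""
    seq_len, d_model = X.shape
    d_head = d_model // num_heads

    Q = (X @ W_q).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)  # per-head queries
    K = X @ W_k   # (seq_len, d_head) -- a single shared key head
    V = X @ W_v   # (seq_len, d_head) -- a single shared value head

    scores = Q @ K.T / np.sqrt(d_head)        # shared K broadcasts over all query heads
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                       # (num_heads, seq_len, d_head)

    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o
```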
#grouped-query attention
Grouped-query attention (GQA) is similar to MHA and MQA except that groups of query heads share the same K and V. Each group has its own K and V.
GQA is the middle ground between MHA and MQA. It's faster than MHA and more expressive than MQA.
Try calculating the grouped-query attention scores by hand on a short sequence.
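A sketch of the grouping, assuming W_k and W_v project to num_groups * d_head columns. Setting num_groups = num_heads recovers MHA and num_groups = 1 recovers MQA:

```python
import numpy as np

def grouped_query_attention(X, W_q, W_k, W_v, W_o, num_heads, num_groups):
    """GQA sketch: query heads are split into groups, each group shares one K, V head."""
    seq_len, d_model = X.shape
    d_head = d_model // num_heads
    heads_per_group = num_heads // num_groups

    Q = (X @ W_q).reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)
    # one K, V head per group
    K = (X @ W_k).reshape(seq_len, num_groups, d_head).transpose(1, 0, 2)
    V = (X @ W_v).reshape(seq_len, num_groups, d_head).transpose(1, 0, 2)

    # repeat each group's K, V so every query head in the group can use it
    K = np.repeat(K, heads_per_group, axis=0)   # (num_heads, seq_len, d_head)
    V = np.repeat(V, heads_per_group, axis=0)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ V

    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_o
```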
#multi-head latent attention
Multi-head latent attention (MLA) stores a compressed (latent) version of K and V to save space in the KV cache. At inference, the query latent and the cached KV latents are multiplied by a combined matrix, Wcombined, to compute attention.
The usual approach expands the latent KV back to its full size with an up-projection matrix WUK before computing attention. This takes up more memory and negates the compression benefit.
MLA skips this step by "absorbing" or pre-multiplying WUQ and WUK into a single matrix, Wcombined, ahead of time. The latents don't need to be expanded during inference.
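A tiny numerical check of the absorption trick. The latent and head sizes below are made up and the real MLA projections are more involved, but the identity is the point: multiplying through the small combined matrix gives the same score as expanding the latents first.

```python
import numpy as np

rng = np.random.default_rng(0)
d_latent, d_head = 16, 64                    # latent is much smaller than the full head dim
W_UQ = rng.normal(size=(d_latent, d_head))   # up-projects query latents
W_UK = rng.normal(size=(d_latent, d_head))   # up-projects key latents

c_q = rng.normal(size=(1, d_latent))         # query latent
c_k = rng.normal(size=(1, d_latent))         # cached KV latent

# naive: expand both latents to full size, then take the dot product
score_expanded = (c_q @ W_UQ) @ (c_k @ W_UK).T

# absorbed: fold W_UQ and W_UK into one small matrix ahead of time
W_combined = W_UQ @ W_UK.T                   # (d_latent, d_latent)
score_absorbed = c_q @ W_combined @ c_k.T

print(np.allclose(score_expanded, score_absorbed))  # True
```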
Rearranging the maths like this makes it incompatible with rotary position embeddings (RoPE), though. RoPE applies a position-dependent rotation to Q and K, so the attention score between tokens at positions t and j contains the term R_t^T R_j.
The term R_t^T R_j depends on the (t, j) pair. There is no single fixed matrix you can precompute that works for all (t, j) pairs, so it can't be folded into Wcombined. Using RoPE this way would require either recomputing the full K from the latent during generation (increases latency) or caching the full RoPE-applied keys (increases memory).
Decoupling the positional information from the content solves this. The content calculation is the same, but the positional information is applied in a separate stream.
A small decoupled RoPE key, kR, is cached, and the content and positional scores are summed when computing attention. This way, positional information flows through small decoupled vectors while content information still flows through the latents.
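Roughly, in DeepSeek-V2's notation (my paraphrase, so treat the exact symbols as an assumption), the attention score between query position t and key position j becomes the sum of a content term and a RoPE term:

```latex
\mathrm{score}_{t,j} \;=\;
\frac{\left(q^{C}_{t}\right)^{\top} k^{C}_{j} \;+\; \left(q^{R}_{t}\right)^{\top} k^{R}_{j}}
     {\sqrt{d_h + d^{R}_h}}
```

Here q^C, k^C are the content vectors that come from the latents (via the absorbed matrices) and q^R, k^R are the small RoPE-rotated vectors, so only the KV latent and kR need to be cached.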
#flashattention
FlashAttention changes how the attention calculation is executed to minimize memory traffic to and from the GPU's slower high-bandwidth global memory (HBM). The calculation is done on the faster on-chip shared memory (SRAM) and registers, without fully materializing the entire attention score matrix in HBM.
- Q, K, and V are tiled into blocks that fit in SRAM. Tiling lets the whole calculation be fused into a single CUDA kernel.
- In the first block, calculate a row-wise max m1 and a sum l1 for the attention scores.
- After processing the second block, update the row-wise max to m2. Rescale l1 and the partial output before adding the new updates.
- Repeat for each remaining block.
Attention scores are streamed over blocks to avoid storing the entire attention score matrix in memory. The softmax calculation is exactly the same. Only Q, K, V, the final output, and the softmax normalization factor from the forward pass are stored in HBM.
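Here's a single-query sketch of that streaming computation in NumPy, following the running max m and sum l updates described above. The block size, shapes, and random inputs are illustrative; the real kernel works on tiles of many rows at once:

```python
import numpy as np

def streaming_attention_row(q, K, V, block_size):
    """Online softmax for one query row: stream over K, V blocks while keeping
    only a running max m, a running sum l, and the partial output."""
    d_k = K.shape[-1]
    m, l = -np.inf, 0.0
    out = np.zeros(V.shape[-1])

    for start in range(0, K.shape[0], block_size):
        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
        scores = q @ K_blk.T / np.sqrt(d_k)

        m_new = max(m, scores.max())              # updated row-wise max
        p = np.exp(scores - m_new)                # this block's unnormalized probabilities
        l_new = np.exp(m - m_new) * l + p.sum()   # rescale the old sum, add the new one
        # rescale the previous partial output, then add this block's contribution
        out = (np.exp(m - m_new) * l * out + p @ V_blk) / l_new
        m, l = m_new, l_new

    return out

# gives exactly the same result as the full softmax for this row
rng = np.random.default_rng(0)
q, K, V = rng.normal(size=(8,)), rng.normal(size=(32, 8)), rng.normal(size=(32, 8))
s = q @ K.T / np.sqrt(8)
full = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
print(np.allclose(streaming_attention_row(q, K, V, block_size=8), full))  # True
```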
The backward pass uses the same idea to avoid storing the attention probabilities P in memory. P is required for computing gradients with respect to Q, K, and V. FlashAttention recomputes P block-by-block, uses the same softmax statistics from the forward pass (m, l), and accumulates the gradients.
Speedups increase with sequence length. For short sequences, attention is less memory-bound, so there's less benefit.
#flashattention-2
FlashAttention-2 improves FlashAttention by reducing non-matmul FLOPs and scheduling more work on the GPU.
- GPUs use specialized compute units, Tensor Cores on Nvidia GPUs, for fast matmuls, so non-matmul operations are comparatively slow. In FlashAttention, every block rescales the partial output with the sum l.

  FlashAttention-2 delays the rescaling to the end to avoid many scalar divisions. This reduces the number of non-matmul operations (see the sketch after this list).
- FlashAttention parallelizes over batch size × number of heads. For long sequences with small batches, there are few thread blocks and streaming multiprocessor (SM) occupancy drops.

  FlashAttention-2 also parallelizes over the sequence length (query blocks), so more blocks run in parallel. This boosts SM usage for long sequences.
- Warps are groups of 32 threads within a thread block. FlashAttention splits the work on K and V across warps, while Q is accessible by all warps. After computing attention, the warps write their partial results to shared memory, synchronize, and reduce.

  FlashAttention-2 instead splits the work on Q across warps, while K and V are accessible by all warps. Each warp computes its own partial result and doesn't need to synchronize until the end. Assigning different Q blocks to different warps minimizes shared memory traffic.
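To make the first improvement concrete, here's the single-row streaming loop from the FlashAttention section rewritten with deferred rescaling: the partial output stays unnormalized and there is one division at the very end instead of one per block (a sketch, not the real kernel):

```python
import numpy as np

def streaming_attention_row_deferred(q, K, V, block_size):
    """Single-row sketch with deferred rescaling: divide by the sum l only once at the end."""
    d_k = K.shape[-1]
    m, l = -np.inf, 0.0
    out = np.zeros(V.shape[-1])          # unnormalized partial output

    for start in range(0, K.shape[0], block_size):
        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
        scores = q @ K_blk.T / np.sqrt(d_k)
        m_new = max(m, scores.max())
        scale = np.exp(m - m_new)        # rescale previous sum and output
        p = np.exp(scores - m_new)
        l = l * scale + p.sum()
        out = out * scale + p @ V_blk    # no division per block
        m = m_new

    return out / l                       # single non-matmul division at the end
```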
#flashattention-3
FlashAttention-3 optimizes for newer hardware, like H100 GPUs, that supports asynchrony. Specialized compute units run in parallel: Tensor Cores do fast matmuls, the Tensor Memory Accelerator (TMA) loads data from HBM, and CUDA cores perform slower computations like softmax.
- In FlashAttention-2, every warp performed the same work: loading data and computing.

  FlashAttention-3 specializes warps for different tasks. Producer warps load K, V blocks from HBM to SRAM with the TMA. Consumer warps compute attention on the data in SRAM with Tensor Cores. The producer can prefetch the next block while the consumer is working on the current one, overlapping loading and computation.
- The output O depends on P, P is calculated from softmax(S), and S is the attention score matrix QK^T.

  FlashAttention-3 uses Warp Group Matrix Multiply-Accumulate (WGMMA) instructions to perform the matmuls asynchronously. It pipelines blocks so that while softmax executes on one block of the score matrix, WGMMA computes the next block. Two pipelines run in parallel: WGMMA (matmul) operations and non-WGMMA (softmax) operations.
- FP8 quantization is sensitive to outlier features in large language models (LLMs). A few dimensions in the hidden states can have very large values, and quantizing them to FP8 causes precision loss.

  FlashAttention-3 uses a different scale per block, so each block can use the full FP8 range for its values. A block containing an outlier gets a large scale and the normal blocks get small scales. This way, precision is preserved locally.

  A single outlier inside an otherwise normal block can still force a large scale and crush the other values. FlashAttention-3 addresses this by spreading the outlier across all dimensions before quantization: it multiplies Q and K by a random orthogonal matrix to redistribute the outlier so that no single dimension dominates a block.
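A toy simulation of why per-block scaling helps. The rounding below is a crude stand-in for real FP8 arithmetic, rows stand in for blocks, and only FP8_MAX = 448 (the e4m3 maximum) is taken from the actual format:

```python
import numpy as np

FP8_MAX = 448.0   # max representable value in the e4m3 FP8 format

def quantize_dequantize(x, scale):
    """Simulate an FP8 round-trip: scale down, clip, round coarsely, scale back up."""
    x_q = np.clip(x / scale, -FP8_MAX, FP8_MAX)
    x_q = np.round(x_q * 8) / 8          # crude stand-in for FP8's limited precision
    return x_q * scale

rng = np.random.default_rng(0)
K = rng.normal(size=(4, 64))
K[0, 3] = 500.0                          # an outlier feature in one "block"

# per-tensor scale: the single outlier forces a huge scale, crushing normal values
scale_tensor = np.abs(K).max() / FP8_MAX
err_tensor = np.abs(K - quantize_dequantize(K, scale_tensor)).mean()

# per-block scale: each row/block picks its own scale, so normal blocks keep precision
K_blockwise = np.stack([
    quantize_dequantize(row, np.abs(row).max() / FP8_MAX) for row in K
])
err_block = np.abs(K - K_blockwise).mean()

print(err_tensor > err_block)  # True: per-block scaling preserves more precision
```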
#summary
Self-attention is a weighted sum of contextual information of all words. The weights reflect each surrounding word's relevance to the current word.
But attention is memory-bound. Newer attention algorithms optimize how the calculation is executed, whether by sharing the same K and V across heads or by streaming the attention computation over blocks.
#resources
- Attention Is All You Need
- Fast Transformer Decoding: One Write-Head is All You Need
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
- FlashAttention, FlashAttention-2, FlashAttention-3
- Youtube videos by 3Blue1Brown and Andrej Karpathy