Attention Mechanism Deep Dive
✨ Bit: Attention is the beating heart of every transformer. Understanding its variants (MHA, GQA, MQA) and hardware-efficient implementations such as FlashAttention is the difference between reading research papers and understanding them. This note goes beyond "Q·Kᵀ/√d" to the engineering reality of memory, speed, and scale.
★ TL;DR
- What: A complete treatment of attention variants, their memory/compute tradeoffs, KV-cache mechanics, and hardware-efficient implementations (FlashAttention)
- Why: Attention is the bottleneck for both training and inference cost. Understanding attention variants is essential for model selection, optimization, and architecture design.
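- Key point: Moving from MHA to GQA or MQA shrinks the KV-cache by a factor of H/G, the number of query heads per KV head (8× for LLaMA 3.1 70B's 64 query heads and 8 KV heads; H× for MQA), enabling longer contexts and cheaper inference. FlashAttention cut attention's extra memory from O(n²) to O(n) in sequence length.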
★ Overview
Definition
Attention computes a weighted combination of values (V) based on the compatibility between queries (Q) and keys (K). The mechanism allows the model to focus on relevant parts of the input regardless of distance.
Prerequisites
- Attention Mechanism — foundational concepts
- Transformers — architecture context
- Inference Optimization
★ Deep Dive
Attention Mathematics
SCALED DOT-PRODUCT ATTENTION:
Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V
Where:
Q ∈ ℝ^(n × d_k) Query matrix (n tokens × d_k dimensions)
K ∈ ℝ^(n × d_k) Key matrix
V ∈ ℝ^(n × d_v) Value matrix
d_k Key dimension (typically d_model / n_heads)
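STEP BY STEP:
1. Q·Kᵀ → (n × n) attention scores O(n²·d_k)
2. / √d_k → keep logits at roughly unit variance (for unit-variance q, k entries, q·k has variance d_k), so softmax does not saturate
3. + mask → causal mask for autoregressive models: add -∞ at every future position (j > i)
4. softmax → normalize each row to probabilities O(n²)
5. × V → weighted value aggregation O(n²·d_v)
TOTAL COMPLEXITY: O(n² · d) per head
TOTAL MEMORY: O(n²) per head for the attention-weights matrix, if it is materialized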
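The five steps above can be traced on a toy input. Below is a minimal single-head sketch in pure Python, using nested lists instead of tensors only so it runs without dependencies; the helper names `matmul`, `transpose`, `softmax`, and `causal_attention` are illustrative, not a library API. With a causal mask the first token can attend only to itself, so its output equals its own value vector.

```python
import math


def matmul(a, b):
    """Naive product of nested lists: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    # Subtract the row max for stability; exp(-inf) == 0.0, so masked entries get zero weight.
    mx = max(row)
    exps = [math.exp(x - mx) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q, K, V):
    """Single-head causal attention. Q, K: n x d_k; V: n x d_v."""
    n, d_k = len(Q), len(Q[0])
    scores = matmul(Q, transpose(K))                                # 1. n x n scores
    scores = [[s / math.sqrt(d_k) for s in row] for row in scores]  # 2. scale by sqrt(d_k)
    scores = [[scores[i][j] if j <= i else -math.inf for j in range(n)]
              for i in range(n)]                                    # 3. causal mask
    weights = [softmax(row) for row in scores]                      # 4. row-wise softmax
    return matmul(weights, V), weights                              # 5. weighted sum of V


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, w = causal_attention(Q, K, V)
print(out[0])                          # [1.0, 2.0]: token 0 sees only itself
print([round(sum(r), 6) for r in w])   # [1.0, 1.0, 1.0]: each row is a distribution
```

The materialized `weights` list is exactly the O(n²) matrix that FlashAttention avoids storing.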
Attention Variant Evolution
Multi-Head Attention (MHA): Original (Vaswani et al., 2017)
│ Each of the H heads has its own Q, K, V projections
│ Projection weights: d_model × H × (2·d_k + d_v) (= 3·d_model² when H·d_k = H·d_v = d_model)
│ KV-cache: 2 × H × n × d_k per layer (K and V, assuming d_v = d_k)
│
├──► Multi-Query Attention (MQA): Shazeer (2019), used in PaLM (2022)
│ ALL query heads share ONE K head and ONE V head
│ KV-cache: 2 × 1 × n × d_k per layer (H× smaller than MHA!)
│ Some quality loss, massive memory savings
│
└──► Grouped-Query Attention (GQA): Ainslie et al. (2023), used in LLaMA 2 70B and LLaMA 3
Query heads are split into G groups (1 < G < H); each group shares one K head and one V head
KV-cache: 2 × G × n × d_k per layer (H/G× smaller than MHA)
Interpolates between MHA (G = H) and MQA (G = 1)
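Structurally, the three variants differ only in which K/V head each query head reads. A minimal sketch, assuming each group is a contiguous run of query heads (a common convention, but an assumption here); `kv_head_for_query_head` is a hypothetical helper, not a library function. Setting G = H recovers MHA and G = 1 recovers MQA.

```python
def kv_head_for_query_head(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Index of the K/V head that query head `q_head` attends with.

    Assumes n_q_heads is a multiple of n_kv_heads and that each group is a
    contiguous run of query heads (an illustrative convention).
    """
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # query heads per shared K/V head
    return q_head // group_size


H = 8  # query heads
variants = {"MHA": 8, "GQA": 2, "MQA": 1}  # name -> G (number of K/V heads)
mappings = {
    name: [kv_head_for_query_head(h, H, G) for h in range(H)]
    for name, G in variants.items()
}
for name, G in variants.items():
    # The per-layer cache stores 2 * G * n * d_k values, i.e. G/H of the MHA size
    print(name, mappings[name], f"cache = {G}/{H} of MHA")
# MHA [0, 1, 2, 3, 4, 5, 6, 7] cache = 8/8 of MHA
# GQA [0, 0, 0, 0, 1, 1, 1, 1] cache = 2/8 of MHA
# MQA [0, 0, 0, 0, 0, 0, 0, 0] cache = 1/8 of MHA
```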
KV-Cache Memory Math
KV-CACHE SIZE:
Per token, per layer: 2 (K and V) × n_heads_kv × d_head × precision_bytes
Total: per_token_per_layer × n_layers × sequence_length × batch_size
EXAMPLE: LLaMA 3.1 70B
n_layers = 80
n_heads = 64 (query), n_heads_kv = 8 (GQA, 8 groups)
d_head = 128
precision = 2 bytes (bf16)
Per token per layer = 2 × 8 × 128 × 2 = 4,096 bytes = 4 KB
Per token total = 4 KB × 80 layers = 320 KB
SEQUENCE COSTS (binary units: 1 KB = 1,024 bytes, 1 GB = 2³⁰ bytes):
1K tokens → 320 MB
8K tokens → 2.5 GB
32K tokens → 10 GB
128K tokens → 40 GB ← exceeds many GPUs!
IF MHA (64 KV heads instead of 8):
Per token = 2 × 64 × 128 × 2 × 80 = 2.5 MB per token
128K tokens → 320 GB ← far beyond a single GPU!
GQA SAVINGS: 8× reduction in KV-cache = feasible long context
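The same arithmetic answers a practical question: how many tokens fit in a given memory budget? The sketch assumes a hypothetical 40 GiB left for the cache after weights and activations, and the helper names are illustrative. Because the budget is shared, the result is the total number of cached tokens across every sequence in the batch.

```python
GiB = 2**30  # binary units, matching the figures above


def kv_bytes_per_token(n_layers: int, n_kv_heads: int, d_head: int, dtype_bytes: int = 2) -> int:
    """Bytes of K and V (hence the factor 2) across all layers for one token."""
    return 2 * n_layers * n_kv_heads * d_head * dtype_bytes


def max_cached_tokens(budget_bytes: int, per_token_bytes: int) -> int:
    """Total tokens, summed over all sequences in the batch, whose K/V fit in the budget."""
    return budget_bytes // per_token_bytes


gqa = kv_bytes_per_token(n_layers=80, n_kv_heads=8, d_head=128)   # LLaMA 3.1 70B (GQA)
mha = kv_bytes_per_token(n_layers=80, n_kv_heads=64, d_head=128)  # same model if it used MHA
budget = 40 * GiB  # hypothetical memory reserved for the KV-cache

print(gqa, mha)                        # 327680 2621440  (320 KB vs 2.5 MB per token)
print(max_cached_tokens(budget, gqa))  # 131072: e.g. 1 x 128K or 4 x 32K sequences
print(max_cached_tokens(budget, mha))  # 16384: 8x fewer tokens
```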
FlashAttention: IO-Aware Attention
STANDARD ATTENTION:
1. Compute full n×n attention matrix in HBM → O(n²) memory
2. Apply softmax
3. Multiply by V
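Problem: the n×n matrix is far too large for on-chip SRAM, so it is written to and re-read from HBM between steps; for long sequences this memory traffic, not the FLOPs, dominates runtime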
FLASH ATTENTION:
1. Tile Q, K, V into blocks that fit in SRAM
2. Compute attention block-by-block (never materialize full n×n)
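3. Use online softmax: keep a running row max and row sum, rescaling partial results whenever the max changes
4. Accumulate output incrementally
Result: exact attention (no approximation) with O(n) extra memory instead of O(n²), and 2-4× faster wall-clock time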
WHY IT'S FASTER DESPITE THE SAME (OR, WITH BACKWARD RECOMPUTATION, MORE) FLOPS:
- Standard attention's softmax, masking, and dropout steps are memory-bound, not compute-bound
- FlashAttention minimizes HBM reads/writes
- Keeping data in SRAM (fast) instead of HBM (slow)
FlashAttention-2: better work partitioning and parallelism, roughly 2× faster than FlashAttention-1
FlashAttention-3: Hopper GPU optimizations, FP8 support
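The online-softmax step is what lets FlashAttention skip materializing the n×n matrix. The pure-Python sketch below shows the idea for a single query row: keys and values are visited block by block while only a running max, a running sum, and an unnormalized output are kept, and the result matches the naive computation. It is a simplification, not FlashAttention itself (the real kernel also tiles queries, runs in on-chip SRAM, and has a backward pass), and the function names are illustrative.

```python
import math


def naive_attention_row(q, keys, values):
    """Reference: materialize every score for one query, softmax, then weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[j] * values[j][c] for j in range(len(keys))) / z
            for c in range(len(values[0]))]


def tiled_attention_row(q, keys, values, block_size=2):
    """Same result, visiting K/V one block at a time with O(d_v) running state."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                  # running max of all scores seen so far
    denom = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # running unnormalized sum of exp(score - m) * v
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s_blk))
        corr = math.exp(m - m_new)  # rescales old state to the new max (0.0 on the first block)
        p = [math.exp(s - m_new) for s in s_blk]
        denom = denom * corr + sum(p)
        acc = [a * corr + sum(p[j] * v_blk[j][c] for j in range(len(p)))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]  # normalize once at the end


q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [0.0, 1.0, -1.0], [2.0, 1.0, 0.0], [-1.0, 0.5, 1.5], [0.3, 0.3, 0.3]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
ref = naive_attention_row(q, keys, values)
for bs in (1, 2, 3, 5):
    out = tiled_attention_row(q, keys, values, block_size=bs)
    print(bs, all(math.isclose(a, b, rel_tol=1e-12, abs_tol=1e-12) for a, b in zip(out, ref)))
```

Every block size should agree with the reference; only the order of accumulation changes, not the math.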
★ Code & Implementation
Attention Variants Comparison
```python
# pip install "torch>=2.0"
# ⚠️ Last tested: 2026-04 | Requires: torch>=2.0
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Standard scaled dot-product attention (reference implementation).

    Q, K: (..., n, d_k); V: (..., n, d_v); mask broadcastable to (..., n, n),
    with 0 marking positions that must not be attended to. In production,
    torch.nn.functional.scaled_dot_product_attention can dispatch to fused
    FlashAttention kernels.
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights


# Compare KV-cache sizes across attention variants
def kv_cache_size_bytes(
    n_layers: int, seq_len: int, n_kv_heads: int,
    d_head: int, dtype_bytes: int = 2,
) -> int:
    """KV-cache memory in bytes for one sequence (the factor 2 covers K and V)."""
    return 2 * n_layers * seq_len * n_kv_heads * d_head * dtype_bytes


# LLaMA 3.1 70B comparison
configs = {
    "MHA (64 KV heads)": {"n_kv_heads": 64},  # hypothetical: one KV head per query head
    "GQA (8 KV heads)": {"n_kv_heads": 8},    # grouped-query (actual config)
    "MQA (1 KV head)": {"n_kv_heads": 1},     # hypothetical: multi-query
}
GiB = 2**30  # binary units, matching the worked example above

print("KV-Cache Memory for LLaMA 3.1 70B (80 layers, d_head=128, bf16):")
print(f"{'Variant':<22} {'1K tokens':>12} {'32K tokens':>12} {'128K tokens':>12}")
print("-" * 60)
for name, cfg in configs.items():
    sizes = []
    for seq_len in [1024, 32768, 131072]:
        size = kv_cache_size_bytes(80, seq_len, cfg["n_kv_heads"], 128, 2)
        sizes.append(f"{size / GiB:.2f} GiB")
    print(f"{name:<22} {sizes[0]:>12} {sizes[1]:>12} {sizes[2]:>12}")

# Expected output:
# KV-Cache Memory for LLaMA 3.1 70B (80 layers, d_head=128, bf16):
# Variant                   1K tokens   32K tokens  128K tokens
# ------------------------------------------------------------
# MHA (64 KV heads)          2.50 GiB    80.00 GiB   320.00 GiB
# GQA (8 KV heads)           0.31 GiB    10.00 GiB    40.00 GiB
# MQA (1 KV head)            0.04 GiB     1.25 GiB     5.00 GiB
```
◆ Production Failure Modes
| Failure | Symptoms | Root Cause | Mitigation |
|---|---|---|---|
| KV-cache OOM | Inference crashes on long sequences | KV-cache exceeds GPU memory | Use GQA/MQA models, quantize KV-cache, limit max_seq_len |
| Attention-bound latency | Prefill latency grows faster than linearly with input length | Attention cost is O(n²) in input length | Use FlashAttention; consider chunked prefill to bound per-step latency |
| Lost-in-the-middle | Model ignores information in the middle of long contexts | Positional bias: models tend to use information near the start and end of the context more reliably than the middle | Place important information at the start or end of the input |
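The "quantize KV-cache" mitigation can be illustrated by storing each cached K or V vector as 8-bit integers plus a scale, roughly halving memory versus bf16. The sketch uses simple symmetric absmax quantization per vector; that scheme, the 2-byte scale, and the helper names are assumptions for illustration, and serving engines may use other granularities or formats.

```python
def quantize_int8(vec):
    """Symmetric absmax quantization: int8 codes in [-127, 127] plus one float scale."""
    amax = max(abs(x) for x in vec)
    scale = amax / 127 if amax > 0 else 1.0
    codes = [max(-127, min(127, round(x / scale))) for x in vec]
    return codes, scale


def dequantize_int8(codes, scale):
    return [c * scale for c in codes]


k = [0.12, -0.9, 0.33, 0.0, 0.71, -0.05, 0.48, -0.27]  # toy K vector
codes, scale = quantize_int8(k)
k_hat = dequantize_int8(codes, scale)
max_err = max(abs(a - b) for a, b in zip(k, k_hat))
print(max_err <= scale / 2 + 1e-12)  # True: error is at most half a quantization step

d_head = 128
bf16_bytes = d_head * 2
int8_bytes = d_head * 1 + 2  # assumes the per-vector scale is stored as a 2-byte float
print(bf16_bytes, int8_bytes)  # 256 130
```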
○ Interview Angles
- Q: What is Multi-Query Attention and why does it matter?
- A: In standard Multi-Head Attention, each attention head has its own K and V projections, so the KV-cache scales linearly with the number of heads. Multi-Query Attention shares a single K, V pair across all query heads, which shrinks the KV-cache by a factor equal to the number of query heads (e.g., 64× for a model with 64 query heads like LLaMA 3.1 70B, if it used MQA). This enables much longer context windows and higher batch sizes during inference. Grouped-Query Attention (GQA) is the practical middle ground: LLaMA 3 models use 8 KV heads instead of one per query head (MHA) or 1 (MQA), keeping most of the memory savings with minimal quality loss.
★ Connections
| Relationship | Topics |
|---|---|
| Builds on | Attention Mechanism, Transformers |
| Leads to | Inference Optimization, State Space Models |
| Compare with | SSM (Mamba), linear attention, sparse attention |
| Cross-domain | HPC, GPU programming, hardware-software co-design |
★ Recommended Resources
| Type | Resource | Why |
|---|---|---|
| 📄 Paper | Vaswani et al. "Attention Is All You Need" (2017) | Original attention mechanism |
| 📄 Paper | Dao et al. "FlashAttention" (2022) | IO-aware attention that changed inference |
| 📄 Paper | Ainslie et al. "GQA: Grouped-Query Attention" (2023) | The GQA paper used by LLaMA 2/3 |
| 🎥 Video | 3Blue1Brown — "Attention in Transformers" | Best visual explanation |
◆ Hands-On Exercises
Exercise 1: Implement GQA and Compare KV-Cache Size
Goal: Build GQA attention and measure its memory savings versus standard MHA.
Time: 30 minutes
Steps:
1. Implement MHA with 32 heads.
2. Implement GQA with 32 query heads and 8 KV heads.
3. Compare KV-cache memory at a sequence length of 2048.
4. Verify that output quality is similar on a test batch.
Expected Output: A memory comparison table showing a 4× KV-cache reduction with GQA (32 query heads / 8 KV heads).
★ Sources
- Vaswani et al. "Attention Is All You Need" (2017)
- Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)
- Ainslie et al. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
- Shazeer "Fast Transformer Decoding: One Write-Head is All You Need" (2019)