Attention Mechanism¶
✨ Bit: Attention in AI is like attention in humans — you don't read every word equally; you focus on what matters for the task at hand.
★ TL;DR¶
- What: A mechanism that lets each element in a sequence dynamically focus on relevant parts of the entire input
- Why: Solves the bottleneck of fixed-size representations in sequence models. THE key innovation behind Transformers
- Key point: the Query-Key-Value triplet: a query ("What am I looking for?") is matched against keys ("What does each position contain?") to retrieve values ("the actual content")
★ Overview¶
Definition¶
Attention is a mechanism that computes a weighted combination of values (V), where the weights are determined by the compatibility between a query (Q) and keys (K). It allows a model to "focus" on different parts of the input when producing each part of the output.
Scope¶
Covers: Self-attention, cross-attention, multi-head attention, and modern variants (GQA, Flash Attention, MQA). For the full Transformer architecture, see Transformers.
Significance¶
- Before attention: Encoder compressed entire sequence into ONE fixed vector → information bottleneck
- After attention: Every output position can directly access any input position → no bottleneck
Prerequisites¶
- Linear Algebra for AI — matrix multiplication, dot products
- Neural Networks — basic concepts
★ Deep Dive¶
The QKV Intuition¶
Think of attention as a search engine:
You have a QUERY → "What information do I need?"
Matched against KEYS → "What does each position contain?"
Retrieves VALUES → "Here's the actual content"
Example: Parsing "The cat sat because it was tired"
When processing "it":
Query("it") · Key("cat") = HIGH score → "it" refers to "cat"
Query("it") · Key("sat") = LOW score → "it" doesn't refer to "sat"
Result: "it" attends strongly to "cat" and gets its information
Step-by-Step Computation¶
Input: X (sequence of token embeddings, shape: [seq_len, d_model])
Step 1: Project into Q, K, V
Q = X · W_Q (shape: [seq_len, d_k])
K = X · W_K (shape: [seq_len, d_k])
V = X · W_V (shape: [seq_len, d_v])
Step 2: Compute attention scores
scores = Q · K^T (shape: [seq_len, seq_len])
Step 3: Scale
scores = scores / √d_k (keeps dot-product variance near 1 so softmax doesn't saturate)
Step 4: Softmax (normalize to probabilities)
weights = softmax(scores) (each row sums to 1)
Step 5: Weighted sum of values
output = weights · V (shape: [seq_len, d_v])
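The same five steps in code, as a minimal sketch with toy dimensions (random matrices stand in for the learned projections W_Q, W_K, W_V):

```python
import torch

seq_len, d_model, d_k, d_v = 4, 8, 8, 8
X = torch.randn(seq_len, d_model)            # token embeddings
W_Q, W_K, W_V = (torch.randn(d_model, d) for d in (d_k, d_k, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V          # Step 1: project
scores = Q @ K.T                             # Step 2: (seq_len, seq_len) scores
scores = scores / d_k ** 0.5                 # Step 3: scale by sqrt(d_k)
weights = torch.softmax(scores, dim=-1)      # Step 4: each row sums to 1
output = weights @ V                         # Step 5: weighted sum of values
print(output.shape)                          # torch.Size([4, 8])
```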
Visual Example¶
"The" "cat" "sat" "on" "it"
"The" [ 0.6 0.2 0.1 0.05 0.05 ]
"cat" [ 0.1 0.7 0.1 0.05 0.05 ]
"sat" [ 0.05 0.3 0.5 0.1 0.05 ]
"on" [ 0.05 0.1 0.2 0.6 0.05 ]
"it" [ 0.05 0.6 0.1 0.05 0.2 ] ← "it" attends most to "cat"
^ Each row shows WHERE that token is "looking"
^ Values sum to 1.0 (softmax)
Multi-Head Attention¶
Instead of one attention, compute h parallel attentions with different learned projections:
Why multiple heads?
- Head 1 might learn: "who is the subject?"
- Head 2 might learn: "what is the verb?"
- Head 3 might learn: "positional proximity"
- Together: richer representation than any single attention (shape-level sketch below)
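A shape-level sketch of the split-attend-concat pattern (hypothetical sizes; real implementations likewise fuse the per-head projections into single matmuls):

```python
import torch
import torch.nn.functional as F

batch, seq_len, d_model, h = 1, 5, 64, 8
head_dim = d_model // h                        # 8 dims per head

x = torch.randn(batch, seq_len, d_model)
W_qkv = torch.randn(d_model, 3 * d_model)      # fused Q,K,V projection
W_O = torch.randn(d_model, d_model)            # output projection

q, k, v = (x @ W_qkv).chunk(3, dim=-1)         # each (batch, seq_len, d_model)

# (batch, seq_len, d_model) -> (batch, h, seq_len, head_dim)
def split(t):
    return t.view(batch, seq_len, h, head_dim).transpose(1, 2)

q, k, v = split(q), split(k), split(v)
out = F.scaled_dot_product_attention(q, k, v)  # all h heads attend in parallel
out = out.transpose(1, 2).reshape(batch, seq_len, d_model)  # concat heads
print((out @ W_O).shape)                       # torch.Size([1, 5, 64])
```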
Types of Attention¶
| Type | Q from | K,V from | Use Case |
|---|---|---|---|
| Self-Attention | Same sequence | Same sequence | Token-to-token within input |
| Cross-Attention | Decoder | Encoder output | Decoder attending to encoder (translation) |
| Causal/Masked | Same sequence | Same sequence (masked) | Autoregressive generation (can't see future) |
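To make the self- vs cross-attention distinction concrete, here is a minimal single-head sketch of the "Q from one sequence, K/V from another" pattern (untrained random tensors, illustrative shapes only):

```python
import torch
import torch.nn.functional as F

enc_out = torch.randn(1, 1, 10, 32)    # encoder output: 10 source tokens
dec_state = torch.randn(1, 1, 4, 32)   # decoder states: 4 target tokens

# Cross-attention: Q comes from the decoder, K and V from the encoder
out = F.scaled_dot_product_attention(dec_state, enc_out, enc_out)
print(out.shape)  # torch.Size([1, 1, 4, 32]) -- one output per target token
```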
Causal Masking (Critical for LLMs)¶
In generation, token at position i should only attend to positions ≤ i (can't see the future):
Mask (added to the scores):
[  0, -∞, -∞, -∞ ]   "The" can only see "The"
[  0,  0, -∞, -∞ ]   "cat" can see "The", "cat"
[  0,  0,  0, -∞ ]   "sat" can see "The", "cat", "sat"
[  0,  0,  0,  0 ]   "on" can see everything up to and including itself
Applied BEFORE softmax: e^(-∞) = 0, so masked positions get zero weight
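A two-line check of that claim, softmaxing one masked row of scores:

```python
import torch

scores = torch.tensor([[2.0, 1.0, 3.0]])                    # raw scores for one token
mask = torch.tensor([[0.0, float('-inf'), float('-inf')]])  # may only see position 0
print(torch.softmax(scores + mask, dim=-1))                 # tensor([[1., 0., 0.]])
```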
Modern Variants¶
| Variant | What It Does | Why |
|---|---|---|
| MHA (Multi-Head) | Full Q,K,V per head | Original, most expressive |
| MQA (Multi-Query) | Shared K,V across heads, unique Q | Much faster decoding (far smaller KV cache), slight quality drop |
| GQA (Grouped Query) | Groups of heads share K,V | Best of both: fast + quality. Used by LLaMA 2+ |
| Flash Attention | Tiling + recomputation to avoid materializing full attention matrix | 2-4x faster, way less memory |
| RoPE | Rotary Position Embeddings baked into Q,K | Better extrapolation to unseen lengths |
| Sliding Window | Only attend to nearby tokens within a window | Handles very long sequences (Mistral) |
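A sketch of the GQA idea (hypothetical sizes: 8 query heads sharing 2 K/V heads; MQA is the special case of 1 K/V head, full MHA the case of 8):

```python
import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 1, 5, 16
n_q_heads, n_kv_heads = 8, 2                 # 4 query heads per K/V group

q = torch.randn(batch, n_q_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)  # only 2 K/V heads in cache
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Expand each K/V head so it serves its whole group of query heads
k = k.repeat_interleave(n_q_heads // n_kv_heads, dim=1)
v = v.repeat_interleave(n_q_heads // n_kv_heads, dim=1)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 5, 16])
```

The KV cache shrinks by n_q_heads / n_kv_heads (4x here), which is where the decoding speedup comes from.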
◆ Formulas & Equations¶
| Name | Formula | Variables | Use |
|---|---|---|---|
| Scaled Dot-Product | $$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ | Q,K,V matrices ; d_k = key dimension | Core attention |
| Multi-Head | $$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^O$$ | $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ ; h = num heads, W^O = output projection | Parallel attention |
| Complexity | $$O(n^2 \cdot d)$$ | n=sequence length, d=dimension | Time & memory cost |
◆ Strengths vs Limitations¶
| ✅ Strengths | ❌ Limitations |
|---|---|
| Captures long-range dependencies | O(n²) — quadratic with sequence length |
| Fully parallelizable | Large memory footprint for long sequences |
| Interpretable (attention weights show what model "looks at") | Attention maps don't always reflect causal reasoning |
| Works across modalities (text, image, audio) | Still needs positional encoding (no inherent order) |
◆ Quick Reference¶
Attention(Q,K,V) = softmax(QKᵀ/√d_k) · V
Multi-Head: Run h parallel attentions, concat, project
Causal mask: Upper triangle = -∞ (can't see future)
Complexity: O(n²·d) per layer (back-of-envelope below)
Modern defaults:
- GQA (not full MHA) for efficiency
- Flash Attention for memory
- RoPE for positions
- Sliding window for very long contexts
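Back-of-envelope on the O(n²) term, assuming fp16 and an 8K context:

```python
n = 8192                                      # sequence length
bytes_fp16 = 2
mib = n * n * bytes_fp16 / 2**20
print(f"{mib:.0f} MiB per head, per layer")   # 128 MiB -- why Flash Attention matters
```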
○ Interview Angles¶
- Q: Why divide by √d_k in attention?
  A: Without it, for large d_k, dot products become huge → softmax saturates → near-zero gradients. Scaling keeps variance at ~1 (quick empirical check below).
- Q: What's the difference between MHA, MQA, and GQA?
  A: MHA: separate K,V per head (most expressive, slowest). MQA: one shared K,V (fastest, some quality loss). GQA: groups of heads share K,V (good balance). LLaMA 2+ uses GQA.
- Q: How does Flash Attention improve efficiency without changing the math?
  A: It tiles the computation to fit in SRAM (fast cache), avoiding materialization of the full n×n attention matrix in slow HBM (GPU memory). Same result, ~2-4x faster.
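A quick empirical check of the √d_k answer above: for unit-variance q and k, the dot product has variance ≈ d_k, and scaling restores it to ≈ 1.

```python
import torch

d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)
dots = (q * k).sum(dim=-1)
print(dots.var().item())                  # ~512 == d_k
print((dots / d_k ** 0.5).var().item())   # ~1.0 after scaling
```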
★ Code & Implementation¶
Scaled Dot-Product Attention from Scratch¶
# ⚠️ Last tested: 2026-04 | Requires: torch>=2.3
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, causal_mask=False):
    """
    Q: (batch, seq_len, d_k)
    K: (batch, seq_len, d_k)
    V: (batch, seq_len, d_v)
    """
    d_k = Q.size(-1)
    # Steps 2+3: QK^T / sqrt(d_k) (the Q,K,V projections are assumed done by the caller)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if causal_mask:
        # Mask future positions (upper triangle) with -inf
        seq_len = Q.size(-2)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=Q.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
    # Step 4: Softmax
    weights = F.softmax(scores, dim=-1)
    # Step 5: Weighted sum of V
    output = torch.matmul(weights, V)
    return output, weights
# Verify against PyTorch built-in
Q = torch.randn(1, 5, 32) # batch=1, seq_len=5, d_k=32
K = torch.randn(1, 5, 32)
V = torch.randn(1, 5, 32)
custom_out, weights = scaled_dot_product_attention(Q, K, V, causal_mask=True)
builtin_out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(f"Max diff vs built-in: {(custom_out - builtin_out).abs().max().item():.2e}")
# Should be < 1e-5 (numerical precision only)
print(f"Attention weights shape: {weights.shape}") # (1, 5, 5)
Visualize Attention Heads (HuggingFace)¶
# pip install transformers>=4.40 matplotlib>=3.8
# ⚠️ Last tested: 2026-04
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# attn_implementation="eager" so attention weights can be returned (fused sdpa/flash kernels don't expose them)
model = AutoModel.from_pretrained("google/gemma-2-2b", output_attentions=True, attn_implementation="eager")
sentence = "The cat sat on the mat because it was tired."
inputs = tokenizer(sentence, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# outputs.attentions: tuple of (batch, heads, seq_len, seq_len) per layer
attn_layer0 = outputs.attentions[0][0] # Layer 0, batch 0: (heads, seq_len, seq_len)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# Plot head 0 from layer 0
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(attn_layer0[0].numpy(), cmap="Blues", vmin=0, vmax=1)
ax.set_xticks(range(len(tokens))); ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticks(range(len(tokens))); ax.set_yticklabels(tokens)
ax.set_title("Layer 0, Head 0 Attention")
plt.tight_layout()
plt.savefig("attention_head0.png", dpi=150)
print("Saved attention_head0.png")
★ Connections¶
| Relationship | Topics |
|---|---|
| Builds on | Linear Algebra for AI, Neural Networks |
| Leads to | Transformers, Large Language Models (LLMs) |
| Compare with | Recurrence (RNNs), Convolution (CNNs) |
| Cross-domain | Vision Transformers (ViT), Graph Attention Networks |
◆ Production Failure Modes¶
| Failure | Symptoms | Root Cause | Mitigation |
|---|---|---|---|
| Attention sink | First token receives disproportionate attention regardless of content | Model uses first position as default attention target | StreamingLLM sink tokens, attention bias corrections |
| Head redundancy | Many heads learn identical patterns, wasting capacity | No head diversity regularization | Head pruning, auxiliary diversity loss, structured dropout |
| Cross-attention misalignment | Decoder attends to wrong encoder regions | Insufficient data diversity, positional confusion | Alignment training, attention supervision |
◆ Hands-On Exercises¶
Exercise 1: Visualize Attention Patterns¶
Goal: Extract and visualize attention weights to understand model behavior
Time: 30 minutes
Steps:
1. Load a pretrained BERT or GPT-2
2. Run a sentence through with output_attentions=True
3. Plot attention heatmaps for 3 different heads
4. Identify syntactic vs semantic attention patterns
Expected Output: Attention heatmaps showing different heads capture different linguistic patterns
★ Recommended Resources¶
| Type | Resource | Why |
|---|---|---|
| 📄 Paper | Vaswani et al. "Attention Is All You Need" (2017) — Section 3 | Precise mathematical definition of scaled dot-product attention |
| 🎥 Video | 3Blue1Brown — "Attention in Transformers, Visually Explained" | Best visual intuition for Q/K/V and multi-head attention |
| 📄 Paper | Dao et al. "FlashAttention" (2022) | IO-aware attention kernels — essential for understanding modern efficiency |
| 📘 Book | "Build a Large Language Model (From Scratch)" by Sebastian Raschka (2024), Ch 3 | Implement attention from scratch in PyTorch |
★ Sources¶
- Vaswani et al., "Attention Is All You Need" (2017) — https://arxiv.org/abs/1706.03762
- Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate" (2014) — https://arxiv.org/abs/1409.0473 (the original attention paper)
- Jay Alammar, "The Illustrated Transformer" — https://jalammar.github.io/illustrated-transformer/
- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022) — https://arxiv.org/abs/2205.14135