Understanding Attention: A Code-First Journey Through Transformers
Transformers underpin most modern language models. This post builds them from scratch, starting simple and adding complexity step by step. The goal is to understand not just what attention does, but how it works at the tensor level.
The code examples are meant to be run interactively. The Appendix (Virtual Environment Setup and REPL Setup) will get you a working environment quickly. I recommend reading the code carefully, pasting it into your Python shell, and playing with it to get a feel for attention and multi-head attention. Note that the example outputs shown are illustrative; your values may differ depending on your random seed and PyTorch version.
Let’s begin.
The Simplest Attention Mechanism
Tokens and Embeddings
Let’s start with the absolute minimum — three tokens (words) in a sequence (sentence), each token represented by a 4-dimensional vector (embedding):
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Create a simple sequence: 3 tokens, 4 dimensions each
seq_len = 3
embed_dim = 4
# Random embeddings for our tokens
x = torch.randn(seq_len, embed_dim)
print(f"Input shape: {x.shape}")
print(f"Input:\n{x}")

Input shape: torch.Size([3, 4])
Input:
tensor([[ 0.3367, 0.1288, 0.2345, 0.2303],
[-1.1229, -0.1863, 2.2082, -0.6380],
[ 0.4617, 0.2674, 0.5349, 0.8094]])

torch.Size([3, 4]) means 3 tokens ("words") with each token represented by 4 features (a 4-dimensional embedding).
Queries, Keys, and Values
In attention, each token decides what to pay attention to. This is done through three projections:
| Projection | Math | Question |
|---|---|---|
| Query | Q = x @ W_q.weight.T | What am I looking for? |
| Key | K = x @ W_k.weight.T | What information do I contain? |
| Value | V = x @ W_v.weight.T | What information will I actually send? |
Representing in code:
# Simple linear projections (no bias for clarity)
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)
# Create Q, K, V
Q = W_q(x) # (3, 4)
K = W_k(x) # (3, 4)
V = W_v(x) # (3, 4)
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

Q shape: torch.Size([3, 4])
K shape: torch.Size([3, 4])
V shape: torch.Size([3, 4])

Each nn.Linear is a matrix multiplication, or written another way:

Q = x @ W_q.weight.T # (3,4) @ (4,4) = (3,4)

Thus, each token's embedding gets transformed by these learned weights. The weight matrices are square, matching the embedding dimension (4 features), so the query, key, and value matrices have the same shape as the input. What changes is the meaning, through these learned weights.
Q, K, and V are produced using linear layers because the model needs simple, learnable projections of the same token embedding but in different subspaces — one for querying, one for matching, and one for carrying information forward.
A linear transformation is ideal because attention relies on dot-product similarity, which assumes a linear geometric structure; extra nonlinearities would distort that relationship.
Bias terms are omitted because:
- bias doesn't help with similarity scoring; it shifts all vectors uniformly
- bias adds parameters without improving the dot-product matching that drives attention
- omitting bias keeps multi-head splitting simple and symmetric, and hence more parallelizable, which matters when you have billions of parameters to learn
Computing Attention Scores
Next, we compute how much attention each token should pay to every other token:
# Attention scores: Q @ K^T
# This gives us: for each query position, how much does it match each key?
scores = Q @ K.T # (3, 4) @ (4, 3) = (3, 3)
print(f"Scores shape: {scores.shape}")
print(f"Scores:\n{scores}")

Scores shape: torch.Size([3, 3])
Scores:
tensor([[ 0.5423, -0.2819, 0.3764],
[-0.2819, 0.8934, -0.1234],
[ 0.3764, -0.1234, 0.7856]])

Matrix multiplication needs the column count of the first matrix to equal the row count of the second. Since the embedding dimension is fixed (the 4-dimensional "embedding" for each token), we align Q and K.T along this dimension. So 3 tokens go in and we get a 3x3 matrix of scores, a sort of lookup table. Higher values mean "pay more attention to this other token", lower values mean less.
Let’s visualize what this means:
print("\nAttention score interpretation:")
print(" Key Position")
print(" Token0 Token1 Token2")
print(f"Query Token0 {scores[0,0]:6.2f} {scores[0,1]:6.2f} {scores[0,2]:6.2f}")
print(f"Query Token1 {scores[1,0]:6.2f} {scores[1,1]:6.2f} {scores[1,2]:6.2f}")
print(f"Query Token2 {scores[2,0]:6.2f} {scores[2,1]:6.2f} {scores[2,2]:6.2f}")

Attention score interpretation:
Key Position
Token0 Token1 Token2
Query Token0 0.54 -0.28 0.38
Query Token1 -0.28 0.89 -0.12
Query Token2 0.38 -0.12 0.79
Scaling for Stability
Raw scores can get very large, which causes problems in softmax. We scale by the square root of the embedding dimension:
# Scale by sqrt(d_k) for numerical stability
d_k = K.shape[-1] # Last dimension = feature dimension = 4
scaled_scores = scores / math.sqrt(d_k)
print(f"d_k = {d_k}")
print(f"Scale factor = 1/sqrt({d_k}) = {1/math.sqrt(d_k):.4f}")
print(f"Scaled scores:\n{scaled_scores}")

d_k = 4
Scale factor = 1/sqrt(4) = 0.5000
Scaled scores:
tensor([[ 0.2712, -0.1409, 0.1882],
[-0.1409, 0.4467, -0.0617],
[ 0.1882, -0.0617, 0.3928]])

Why scale? As dimensions grow, dot products grow in magnitude. Without scaling, softmax would produce very peaked distributions (almost one-hot), making gradients vanish. Scaling by 1/sqrt(d_k) keeps the variance constant.
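This claim is easy to check numerically. A quick sketch (the sample count of 10,000 is arbitrary): the variance of a raw dot product of standard-normal vectors grows roughly linearly with d_k, while the scaled version stays near 1.

```python
import math
import torch

torch.manual_seed(0)

for d_k in (4, 64, 1024):
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)        # 10,000 raw dot products
    scaled = raw / math.sqrt(d_k)    # scaled as in attention
    print(f"d_k={d_k:5d}  raw var={raw.var().item():8.2f}  scaled var={scaled.var().item():.2f}")
```

The raw variance tracks d_k almost exactly, which is why the softmax saturates for large models without the 1/sqrt(d_k) factor.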
Applying Softmax
Now we convert scores to probabilities — each row must sum to 1, or, the attention scores for each query should sum to 1 across all keys.
# Apply softmax: converts scores to probability distribution
attention_weights = F.softmax(scaled_scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}")
print(f"Attention weights:\n{attention_weights}")
print(f"\nRow sums (should be 1.0):")
print(attention_weights.sum(dim=-1))

Attention weights shape: torch.Size([3, 3])
Attention weights:
tensor([[0.3505, 0.2858, 0.3637],
[0.2858, 0.4183, 0.2959],
[0.3276, 0.2458, 0.4266]])
Row sums (should be 1.0):
tensor([1.0000, 1.0000, 1.0000])

dim=-1 means the "last dimension" (here the columns, i.e. the keys), so F.softmax(..., dim=-1) normalizes each row independently across the keys.
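To see why the dimension argument matters, compare dim=-1 with dim=0 on the same score matrix (a small sketch with random values, not the scores above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s = torch.randn(3, 3)

rows = F.softmax(s, dim=-1)  # normalizes across keys: each ROW sums to 1 (what attention wants)
cols = F.softmax(s, dim=0)   # would normalize each COLUMN instead -- wrong for attention

print(rows.sum(dim=-1))  # each row sums to 1
print(cols.sum(dim=-1))  # rows no longer sum to 1; the normalization went down the columns
```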
Computing the Output
Finally, we use these attention weights to create a weighted average of the values:
# Apply attention to values: weighted sum
output = attention_weights @ V # (3, 3) @ (3, 4) = (3, 4)
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")

Output shape: torch.Size([3, 4])
Output:
tensor([[-0.0523, 0.0234, -0.1245, 0.0789],
[-0.0312, 0.0445, -0.0923, 0.0534],
[-0.0678, 0.0156, -0.1478, 0.0912]])

What just happened?

# For token 0:
output[0] = 0.35 * V[0] + 0.29 * V[1] + 0.36 * V[2]

Each output token is a mixture of all value vectors, weighted by attention. If token 0 strongly attends to token 1 (high attention weight), then output[0] will be heavily influenced by V[1].
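You can verify this claim directly by recomputing one output row by hand (a standalone sketch with fresh random tensors, not the exact values above):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)

weights = F.softmax((Q @ K.T) / math.sqrt(4), dim=-1)
output = weights @ V

# Row 0 of the output is literally the attention-weighted sum of the value rows
manual = weights[0, 0] * V[0] + weights[0, 1] * V[1] + weights[0, 2] * V[2]
print(torch.allclose(output[0], manual))  # True
```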
Key takeaway: Attention is just four matrix operations:
- Project to Q, K, V
- Compute Q @ K.T / sqrt(d_k)
- Apply softmax
- Multiply by V
That’s it. Next, we’ll see how to handle batches and split attention across multiple heads.
Adding the Batch Dimension
Real models process multiple sequences at once for efficiency. Let’s extend our simple attention mechanism to handle batches beginning with a very small batch of 2 sentences each containing 3 words:
# Now we have a batch of sequences
batch_size = 2
seq_len = 3
embed_dim = 4
# Shape: (batch_size, seq_len, embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
print(f"Input shape: {x.shape}")
print(f"Interpretation: {batch_size} sequences, {seq_len} tokens each, {embed_dim} features per token")

Input shape: torch.Size([2, 3, 4])
Interpretation: 2 sequences, 3 tokens each, 4 features per token

Understanding 3D tensors:
- Dimension 0: Which sequence in the batch
- Dimension 1: Which token in the sequence
- Dimension 2: Which feature of the token
Think of it as a stack of matrices, one per sequence.
Batched Attention
# Create new projection layers that handle batches
W_q = nn.Linear(embed_dim, embed_dim, bias=False)
W_k = nn.Linear(embed_dim, embed_dim, bias=False)
W_v = nn.Linear(embed_dim, embed_dim, bias=False)
# Project (linear layers handle batches automatically!)
Q = W_q(x) # (2, 3, 4)
K = W_k(x) # (2, 3, 4)
V = W_v(x) # (2, 3, 4)
print(f"Q shape: {Q.shape}")

Q shape: torch.Size([2, 3, 4])

How does nn.Linear handle batches?

# nn.Linear applies the same weights to each sequence independently
# It's equivalent to:
# for i in range(batch_size):
#     Q[i] = x[i] @ W_q.weight.T

PyTorch broadcasting handles this efficiently in parallel. Each sequence in the batch is completely independent: no information flows between batch elements. This makes attention embarrassingly parallel: a GPU can process all sequences simultaneously. This is why batch size has such a large impact on training speed: batch_size=1 might use 5% of GPU capacity, while batch_size=32 can reach 80%+ utilization.
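The loop in that comment can be confirmed in a couple of lines; the batched call and the per-sequence loop give the same result (a minimal sketch with random data):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 4)           # (batch, seq, embed)
W_q = nn.Linear(4, 4, bias=False)

Q_batched = W_q(x)                 # one call handles the whole batch

# Same thing, one sequence at a time
Q_loop = torch.stack([x[i] @ W_q.weight.T for i in range(x.shape[0])])
print(torch.allclose(Q_batched, Q_loop))  # True
```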
Batched Matrix Multiplication
Now we need to be careful with matrix multiplication:
# We want: for each batch, compute Q @ K.T
# K.shape = (2, 3, 4)
# We need to transpose only the last two dimensions!
K_transposed = K.transpose(-2, -1) # (2, 3, 4) -> (2, 4, 3)
print(f"K shape: {K.shape}")
print(f"K transposed shape: {K_transposed.shape}")
# Now compute scores
d_k = K.shape[-1]
scores = (Q @ K_transposed) / math.sqrt(d_k) # (2, 3, 4) @ (2, 4, 3) = (2, 3, 3)
print(f"Scores shape: {scores.shape}")

K shape: torch.Size([2, 3, 4])
K transposed shape: torch.Size([2, 4, 3])
Scores shape: torch.Size([2, 3, 3])

Understanding transpose(-2, -1):
- -1 refers to the last dimension (size 4)
- -2 refers to the second-to-last dimension (size 3)
- This swaps only these two, leaving the batch dimension alone
- Result: (batch, seq, features) → (batch, features, seq)

Why not just K.T?

# K.T would transpose ALL dimensions: (2, 3, 4) -> (4, 3, 2)
# We only want to transpose within each batch: (2, 3, 4) -> (2, 4, 3)

Key insight: Each sequence in the batch gets its own attention pattern. The model processes them in parallel, but they don't interact.
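The shape difference is quick to inspect. (Tensor.T on a 3-D tensor is deprecated in recent PyTorch, so the full reversal is shown with permute instead:)

```python
import torch

K = torch.randn(2, 3, 4)

print(K.permute(2, 1, 0).shape)   # torch.Size([4, 3, 2]) -- reverses ALL dims, like K.T would
print(K.transpose(-2, -1).shape)  # torch.Size([2, 4, 3]) -- swaps only the last two dims
```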
Causal Masking (For Autoregressive Models)
In language models, tokens can’t “see” future tokens. We enforce this with a causal mask:
# Create a causal mask: lower triangular matrix
seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))
print("Causal mask (1=allowed, 0=forbidden):")
print(mask)

Causal mask (1=allowed, 0=forbidden):
tensor([[1., 0., 0., 0.],
[1., 1., 0., 0.],
[1., 1., 1., 0.],
[1., 1., 1., 1.]])

Understanding torch.tril:
- "tri" = triangular, "l" = lower
- Creates a lower triangular matrix (1s below and on the diagonal, 0s above)
- Position [i, j] is 1 if i >= j (token i can see token j), 0 otherwise
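A brute-force check of that rule:

```python
import torch

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len))

# mask[i, j] is 1 exactly when i >= j (query i may attend to key j)
ok = all((mask[i, j] == 1) == (i >= j) for i in range(seq_len) for j in range(seq_len))
print(ok)  # True
```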
Applying the Mask
# Create sample attention scores
batch_size = 2
scores = torch.randn(batch_size, seq_len, seq_len)
print("Original scores (batch 0):")
print(scores[0])
# Apply mask: set forbidden positions to -infinity
# Before softmax, -inf becomes 0 after softmax
scores_masked = scores.masked_fill(mask == 0, float('-inf'))
print("\nMasked scores (batch 0):")
print(scores_masked[0])
# Apply softmax
attention_weights = F.softmax(scores_masked, dim=-1)
print("\nAttention weights after softmax (batch 0):")
print(attention_weights[0])

Original scores (batch 0):
tensor([[ 0.7234, -0.2341, 0.4567, -0.1234],
[-0.5678, 0.8912, 0.3456, -0.2345],
[ 0.1234, -0.6789, 0.9012, 0.5678],
[-0.3456, 0.2345, -0.7890, 0.4567]])
Masked scores (batch 0):
tensor([[ 0.7234, -inf, -inf, -inf],
[-0.5678, 0.8912, -inf, -inf],
[ 0.1234, -0.6789, 0.9012, -inf],
[-0.3456, 0.2345, -0.7890, 0.4567]])
Attention weights after softmax (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
[0.2689, 0.7311, 0.0000, 0.0000],
[0.2012, 0.0893, 0.7095, 0.0000],
[0.1456, 0.2534, 0.0912, 0.5098]])

Understanding masked_fill:
# masked_fill(condition, value)
# Where condition is True, replace with value
# Where mask == 0 (forbidden), replace with -inf

Why -infinity?
# softmax(x) = exp(x) / sum(exp(x))
# exp(-inf) = 0
# So -inf becomes 0 in the final attention weights

Observe the pattern:
- Token 0 only attends to itself (1.0 weight)
- Token 1 attends to tokens 0 and 1
- Token 2 attends to tokens 0, 1, and 2
- Token 3 attends to all tokens
This is autoregressive: each position only sees the past.
Multi-Head Attention
Instead of one attention operation, we can run several in parallel. Each “head” can learn to attend to different things.
The Concept
Imagine you’re reading a sentence:
- Head 1 might focus on grammatical relationships (subject-verb agreement)
- Head 2 might track long-range dependencies (pronoun references)
- Head 3 might capture local context (adjacent words)
Each head learns different patterns simultaneously.
The Challenge: Reshaping
The tricky part is managing dimensions. Let’s work through it step by step.
# Configuration
batch_size = 2
seq_len = 4
embed_dim = 8 # Total embedding dimension
n_heads = 2 # Number of attention heads
head_dim = embed_dim // n_heads # Dimension per head = 4
print(f"Total embedding dimension: {embed_dim}")
print(f"Number of heads: {n_heads}")
print(f"Dimension per head: {head_dim}")

Total embedding dimension: 8
Number of heads: 2
Dimension per head: 4

Why head_dim is typically 64 or 128: Most production transformers use a head dimension of 64 or 128 regardless of model size (GPT-2: 768÷12=64; GPT-3: 12288÷96=128). These sizes fit well into GPU warp sizes (groups of 32 threads) and allow efficient vectorized operations. Going much smaller loses expressiveness; going larger gives diminishing returns.
Step 1: Combined Q, K, V Projection
Instead of separate projections, we create one big projection and split it:
# Create input
x = torch.randn(batch_size, seq_len, embed_dim)
print(f"Input shape: {x.shape}")
# Single projection that creates Q, K, V for all heads
# Output dimension: 3 * embed_dim (for Q, K, and V)
W_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
qkv = W_qkv(x) # (batch, seq, 3*embed_dim)
print(f"Combined QKV shape: {qkv.shape}")
print(f"This contains: Q (8 dims) + K (8 dims) + V (8 dims) = 24 dims")

Input shape: torch.Size([2, 4, 8])
Combined QKV shape: torch.Size([2, 4, 24])
This contains: Q (8 dims) + K (8 dims) + V (8 dims) = 24 dims

Why combine projections?
- More efficient (one matrix multiply instead of three)
- Better GPU utilization
- Common pattern in production code
Performance note: Three separate linear layers means three CUDA kernel launches and three memory round-trips. A single Linear(embed_dim, 3*embed_dim) reduces this to one of each — roughly 20-30% faster in practice, which adds up across a 48-layer model.
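A small sketch showing that the fused projection has exactly the same parameter count as three separate ones; the speedup comes purely from doing one matmul instead of three:

```python
import torch.nn as nn

embed_dim = 8

# Three separate projections vs one fused projection
separate = [nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3)]
fused = nn.Linear(embed_dim, 3 * embed_dim, bias=False)

n_separate = sum(p.numel() for layer in separate for p in layer.parameters())
n_fused = sum(p.numel() for p in fused.parameters())
print(n_separate, n_fused)  # 192 192
```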
Step 2: Split into Q, K, V
# Split the concatenated QKV
Q, K, V = qkv.split(embed_dim, dim=2)
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

Q shape: torch.Size([2, 4, 8])
K shape: torch.Size([2, 4, 8])
V shape: torch.Size([2, 4, 8])

Understanding split(embed_dim, dim=2):
- Split along dimension 2 (the last dimension)
- Each chunk has size embed_dim = 8
- Splits 24 into three 8s: [Q: 0-8, K: 8-16, V: 16-24]
Step 3: Reshape for Multi-Head (The Tricky Part!)
Now we need to split the embedding dimension across heads:
# Current shape: (batch, seq, embed_dim)
# Target shape: (batch, n_heads, seq, head_dim)
# Step 3a: Reshape to expose head dimension
Q_reshaped = Q.view(batch_size, seq_len, n_heads, head_dim)
print(f"After view: {Q_reshaped.shape}")
print("Interpretation: (batch, seq, n_heads, head_dim)")

After view: torch.Size([2, 4, 2, 4])
Interpretation: (batch, seq, n_heads, head_dim)

Understanding view:

# Original Q: (2, 4, 8)
# We want to split 8 into (2 heads × 4 dims)
# view(2, 4, 2, 4) reshapes without copying data
# Total elements: 2*4*8 = 64 = 2*4*2*4 ✓

Visual representation:

Original: [batch=2, seq=4, embed=8]
  Batch 0, Token 0: [a1, a2, a3, a4, a5, a6, a7, a8]
                     └─ Head 0: a1-a4 ┘ └─ Head 1: a5-a8 ┘

After view: [batch=2, seq=4, heads=2, head_dim=4]
  Batch 0, Token 0, Head 0: [a1, a2, a3, a4]
  Batch 0, Token 0, Head 1: [a5, a6, a7, a8]
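Two properties of view worth confirming: it shares storage with the original tensor (no copy), and head 1 of token 0 really is features 4 through 7 (a minimal sketch):

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 4, 8)
Q_reshaped = Q.view(2, 4, 2, 4)

# Same underlying storage -- view copies nothing
print(Q.data_ptr() == Q_reshaped.data_ptr())           # True

# Head 1 of batch 0, token 0 is exactly features 4..7 of that token
print(torch.equal(Q_reshaped[0, 0, 1], Q[0, 0, 4:8]))  # True
```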
Step 4: Transpose to Put Heads First
# Step 3b: Transpose to move heads before sequence
Q_heads = Q_reshaped.transpose(1, 2)
print(f"After transpose: {Q_heads.shape}")
print("Interpretation: (batch, n_heads, seq, head_dim)")

After transpose: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, head_dim)

Understanding transpose(1, 2):
- Swaps dimensions 1 and 2
- Dimension 1: sequence (size 4)
- Dimension 2: heads (size 2)
- After the swap: (batch, n_heads, seq, head_dim)
Why transpose?
# We want to process each head independently
# Having heads in dimension 1 means we can think of it as:
# "A batch of (batch_size * n_heads) sequences"
# This lets us use the same attention code as before!

Complete Reshape Pipeline
def reshape_for_multihead(tensor, n_heads):
"""
Reshape tensor for multi-head attention
Input: (batch, seq, embed_dim)
Output: (batch, n_heads, seq, head_dim)
where head_dim = embed_dim // n_heads
"""
batch_size, seq_len, embed_dim = tensor.shape
head_dim = embed_dim // n_heads
# Step 1: Reshape to expose heads
# (batch, seq, embed_dim) -> (batch, seq, n_heads, head_dim)
tensor = tensor.view(batch_size, seq_len, n_heads, head_dim)
# Step 2: Move heads dimension forward
# (batch, seq, n_heads, head_dim) -> (batch, n_heads, seq, head_dim)
tensor = tensor.transpose(1, 2)
return tensor
# Apply to Q, K, V
Q_heads = reshape_for_multihead(Q, n_heads)
K_heads = reshape_for_multihead(K, n_heads)
V_heads = reshape_for_multihead(V, n_heads)
print(f"Q_heads shape: {Q_heads.shape}")
print(f"K_heads shape: {K_heads.shape}")
print(f"V_heads shape: {V_heads.shape}")

Q_heads shape: torch.Size([2, 2, 4, 4])
K_heads shape: torch.Size([2, 2, 4, 4])
V_heads shape: torch.Size([2, 2, 4, 4])

Step 5: Compute Attention Per Head
Now we can compute attention exactly as in the batched version! Each head operates independently:
# Compute attention for all heads at once
d_k = K_heads.shape[-1]
scores = (Q_heads @ K_heads.transpose(-2, -1)) / math.sqrt(d_k)
print(f"Scores shape: {scores.shape}")
print("Interpretation: (batch, n_heads, seq, seq)")
# Apply causal mask (see Causal Masking section for details)
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax and apply to values
attention_weights = F.softmax(scores, dim=-1)
output = attention_weights @ V_heads
print(f"Output shape: {output.shape}")
print("Interpretation: (batch, n_heads, seq, head_dim)")

Scores shape: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, seq)
Output shape: torch.Size([2, 2, 4, 4])
Interpretation: (batch, n_heads, seq, head_dim)

What happened?
# For each batch, for each head, for each sequence position:
# output[b, h, i] = weighted_sum(V_heads[b, h, :])
#
# We computed attention for all heads simultaneously!
# Head 0 and Head 1 each get their own attention patterns

Step 6: Concatenate Heads Back Together
# Current shape: (batch, n_heads, seq, head_dim)
# Target shape: (batch, seq, embed_dim)
# Step 6a: Move sequence dimension back
output_transposed = output.transpose(1, 2)
print(f"After transpose: {output_transposed.shape}")
print("Interpretation: (batch, seq, n_heads, head_dim)")
# Step 6b: Merge heads back into single dimension
output_concat = output_transposed.contiguous().view(batch_size, seq_len, embed_dim)
print(f"After concatenation: {output_concat.shape}")
print("Interpretation: (batch, seq, embed_dim)")

After transpose: torch.Size([2, 4, 2, 4])
Interpretation: (batch, seq, n_heads, head_dim)
After concatenation: torch.Size([2, 4, 8])
Interpretation: (batch, seq, embed_dim)

Understanding contiguous():
# After transpose, tensor data may not be contiguous in memory
# view() requires contiguous memory
# contiguous() creates a contiguous copy if needed
# Example:
# Without contiguous: memory layout = [h0_t0, h1_t0, h0_t1, h1_t1, ...]
# After contiguous: memory layout = [t0_h0, t0_h1, t1_h0, t1_h1, ...]

Why is this necessary?
# transpose() creates a VIEW (no data copy, just changes indexing)
# view() requires actual contiguous memory
# contiguous() ensures data is laid out the way view() expects

Visual representation of concatenation:

Before: (batch, seq, n_heads, head_dim)
  Token 0: Head0=[a,b,c,d], Head1=[e,f,g,h]

After: (batch, seq, embed_dim)
  Token 0: [a,b,c,d,e,f,g,h] ← Heads concatenated
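The failure mode is easy to reproduce: calling view directly on a transposed tensor raises a RuntimeError, and contiguous() fixes it (a minimal sketch):

```python
import torch

t = torch.randn(2, 2, 4, 4).transpose(1, 2)  # (batch, seq, heads, head_dim), non-contiguous view

print(t.is_contiguous())  # False
try:
    t.view(2, 4, 8)       # view demands contiguous memory
except RuntimeError:
    print("view failed on non-contiguous tensor")

merged = t.contiguous().view(2, 4, 8)  # contiguous() copies, then view works
print(merged.shape)  # torch.Size([2, 4, 8])
```

(An alternative is t.reshape(2, 4, 8), which copies under the hood when needed; the explicit contiguous().view() makes the copy visible.)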
Complete Multi-Head Attention
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, n_heads):
super().__init__()
assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
self.embed_dim = embed_dim
self.n_heads = n_heads
self.head_dim = embed_dim // n_heads
# Single projection for Q, K, V
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
# Output projection
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq, embed_dim)
mask: Optional (seq, seq) causal mask
Returns:
output: (batch, seq, embed_dim)
attention_weights: (batch, n_heads, seq, seq)
"""
batch_size, seq_len, embed_dim = x.shape
# Project to Q, K, V
qkv = self.qkv_proj(x) # (batch, seq, 3*embed_dim)
Q, K, V = qkv.split(self.embed_dim, dim=2)
# Reshape for multi-head: (batch, seq, embed) -> (batch, heads, seq, head_dim)
Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
# Compute attention
d_k = self.head_dim
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
# Apply mask if provided
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = attention_weights @ V # (batch, heads, seq, head_dim)
# Concatenate heads: (batch, heads, seq, head_dim) -> (batch, seq, embed)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
# Final projection
output = self.out_proj(output)
return output, attention_weights
# Test it!
mha = MultiHeadAttention(embed_dim=8, n_heads=2)
x = torch.randn(2, 4, 8)
mask = torch.tril(torch.ones(4, 4))
output, attn = mha(x, mask)
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attn.shape}")
print(f"\nHead 0 attention pattern (batch 0):\n{attn[0, 0]}")
print(f"\nHead 1 attention pattern (batch 0):\n{attn[0, 1]}")

Output shape: torch.Size([2, 4, 8])
Attention shape: torch.Size([2, 2, 4, 4])
Head 0 attention pattern (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
[0.4782, 0.5218, 0.0000, 0.0000],
[0.3156, 0.3567, 0.3277, 0.0000],
[0.2543, 0.2789, 0.2234, 0.2434]])
Head 1 attention pattern (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
[0.5123, 0.4877, 0.0000, 0.0000],
[0.3421, 0.3234, 0.3345, 0.0000],
[0.2678, 0.2456, 0.2567, 0.2299]])

Each head learns its own attention pattern. In practice, different heads tend to specialize:
- Some focus on local context (adjacent words)
- Some capture long-range dependencies (pronouns → nouns)
- Some track syntactic structure (subject-verb agreement)
Putting It All Together
Now we can assemble a standard transformer block:
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, n_heads, mlp_ratio=4, dropout=0.1):
super().__init__()
# Layer normalization (before attention)
self.ln1 = nn.LayerNorm(embed_dim)
# Multi-head attention
self.attention = MultiHeadAttention(embed_dim, n_heads)
# Layer normalization (before MLP)
self.ln2 = nn.LayerNorm(embed_dim)
# MLP (feedforward network)
mlp_hidden = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, embed_dim),
nn.Dropout(dropout)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: (batch, seq, embed_dim)
mask: Optional (seq, seq) causal mask
Returns:
output: (batch, seq, embed_dim)
"""
# Attention block with residual connection
attn_output, attn_weights = self.attention(self.ln1(x), mask)
x = x + self.dropout(attn_output)
# MLP block with residual connection
mlp_output = self.mlp(self.ln2(x))
x = x + mlp_output
return x, attn_weights
# Create a small model
block = TransformerBlock(embed_dim=64, n_heads=4)
# Test input
x = torch.randn(2, 10, 64) # batch=2, seq=10, embed=64
mask = torch.tril(torch.ones(10, 10))
output, attn = block(x, mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Same shape! Ready to stack more blocks.")

Input shape: torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])
Same shape! Ready to stack more blocks.

A note on these choices: Pre-normalization (LayerNorm before attention/MLP) stabilizes training in deep models compared to post-norm. GELU tends to outperform ReLU for language tasks. The 4x MLP expansion ratio is a common balance between parameter count and capacity. Most transformers since GPT-2 use roughly this template.
Interactive Experimentation
To see this in action, we can train on a simple copy task and look at the resulting attention patterns:
# Create a simple sequence task: copy a pattern
def create_copy_dataset(n_samples, seq_len, vocab_size):
"""
Create sequences where the task is to copy the input
Input: [3, 1, 4, 1, 5]
Target: [3, 1, 4, 1, 5]
"""
x = torch.randint(0, vocab_size, (n_samples, seq_len))
y = x.clone()
return x, y
# Generate data
vocab_size = 10
seq_len = 8
x_train, y_train = create_copy_dataset(100, seq_len, vocab_size)
print(f"Sample input: {x_train[0]}")
print(f"Sample target: {y_train[0]}")

Sample input: tensor([3, 1, 4, 1, 5, 9, 2, 6])
Sample target: tensor([3, 1, 4, 1, 5, 9, 2, 6])

Training Loop (Simplified)
# Simple model: embedding + attention + output
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size, embed_dim, n_heads, seq_len):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Parameter(torch.randn(1, seq_len, embed_dim))
self.transformer = TransformerBlock(embed_dim, n_heads)
self.output = nn.Linear(embed_dim, vocab_size)
self.seq_len = seq_len
def forward(self, x):
# Embed tokens and add positional embeddings
x = self.embed(x) + self.pos_embed
# Apply transformer with causal mask
mask = torch.tril(torch.ones(self.seq_len, self.seq_len, device=x.device))
x, attn = self.transformer(x, mask)
# Project to vocabulary
logits = self.output(x)
return logits, attn
# Create model
model = SimpleTransformer(vocab_size=10, embed_dim=32, n_heads=4, seq_len=8)
# Quick training demo (not fully optimized)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print("Training for a few steps...")
for step in range(50):
# Get batch
batch_x = x_train[:32]
batch_y = y_train[:32]
# Forward pass
logits, attn = model(batch_x)
# Compute loss (cross entropy)
loss = F.cross_entropy(logits.view(-1, vocab_size), batch_y.view(-1))
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 10 == 0:
print(f"Step {step:3d}, Loss: {loss.item():.4f}")
print("\nTraining complete!")

Training for a few steps...
Step 0, Loss: 2.3456
Step 10, Loss: 1.8923
Step 20, Loss: 1.4567
Step 30, Loss: 1.1234
Step 40, Loss: 0.8901
Training complete!

Inspecting Attention Patterns
# Get attention for a sample
model.eval()
with torch.no_grad():
sample_x = x_train[0:1] # Take first sample
logits, attn = model(sample_x)
# Visualize attention patterns
print("Sample sequence:")
print(sample_x[0])
print("\nAttention patterns for each head:")
for head_idx in range(4):
print(f"\nHead {head_idx}:")
attn_pattern = attn[0, head_idx].numpy()
# Simple ASCII visualization
print(" ", " ".join(f"T{i}" for i in range(8)))
for i in range(8):
row = " ".join(f"{attn_pattern[i,j]:.2f}" if j <= i else " --- " for j in range(8))
print(f"Token {i}: {row}")

Sample sequence:
tensor([3, 1, 4, 1, 5, 9, 2, 6])
Attention patterns for each head:
Head 0:
T0 T1 T2 T3 T4 T5 T6 T7
Token 0: 1.00 --- --- --- --- --- --- ---
Token 1: 0.52 0.48 --- --- --- --- --- ---
Token 2: 0.34 0.33 0.33 --- --- --- --- ---
Token 3: 0.25 0.26 0.24 0.25 --- --- --- ---
Token 4: 0.21 0.19 0.20 0.21 0.19 --- --- ---
Token 5: 0.17 0.16 0.18 0.16 0.17 0.16 --- ---
Token 6: 0.15 0.14 0.14 0.15 0.14 0.14 0.14 ---
Token 7: 0.13 0.12 0.13 0.13 0.12 0.13 0.12 0.12
Head 1:
T0 T1 T2 T3 T4 T5 T6 T7
Token 0: 1.00 --- --- --- --- --- --- ---
Token 1: 0.48 0.52 --- --- --- --- --- ---
Token 2: 0.32 0.35 0.33 --- --- --- --- ---
Token 3: 0.24 0.25 0.26 0.25 --- --- --- ---
Token 4: 0.19 0.20 0.20 0.21 0.20 --- --- ---
Token 5: 0.16 0.17 0.16 0.17 0.17 0.17 --- ---
Token 6: 0.14 0.14 0.15 0.14 0.14 0.15 0.14 ---
Token 7: 0.12 0.13 0.12 0.13 0.13 0.12 0.13 0.12
...

What we observe:
- Each head learns slightly different patterns
- Earlier positions get more uniform attention
- Later positions show more variation
- The causal mask is clearly visible (upper triangle is empty)
Key Takeaways
To summarize what we covered:
- Attention is weighted averaging: Q and K determine the weights, V provides the values to average
- Scaling matters: Dividing by √d_k keeps gradients stable
- Causal masking: Setting future positions to -∞ enforces the autoregressive property
- Multi-head reshaping: (batch, seq, embed) → view(batch, seq, heads, head_dim) → transpose(1, 2) → (batch, heads, seq, head_dim)
- contiguous() is necessary: After transpose, use contiguous() before view()
- PyTorch broadcasting: Linear layers automatically handle batch dimensions
What’s Next?
To go from here to a full transformer, you’d want to:
- Add positional encoding (we used learned embeddings, but sinusoidal is common)
- Stack multiple layers (GPT-2 has 12-48 layers)
- Add proper training infrastructure (gradient clipping, learning rate scheduling)
- Scale up (bigger models, more data, GPUs)
The attention mechanism covered here is the same one used in production language models — the rest is largely scaling and engineering.
Exercises for Understanding
Some things worth trying:
- Remove the causal mask - what happens to the attention patterns? (Revisit Causal Masking)
- Use 1 head vs 8 heads - does it learn faster? Better? (See why head_dim=64)
- Visualize attention over training - do patterns emerge?
- Try different head dimensions - how does 16 heads of size 4 compare to 4 heads of size 16?
- Add a task - make it learn sorting, reversal, or arithmetic
All the code above should be straightforward to modify and extend.
Appendix
Virtual Environment Setup
# Create virtual environment
python -m venv transformer_env
# Activate (Linux/Mac)
source transformer_env/bin/activate
# Activate (Windows)
transformer_env\Scripts\activate
# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install matplotlib numpy
# Verify installation
python -c "import torch; print(f'PyTorch {torch.__version__} installed successfully')"

For GPU support:

# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

REPL Setup
Imports and seed for an interactive session:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# For reproducibility
torch.manual_seed(42)
# Use CPU for pedagogy (easier to inspect tensors)
device = 'cpu'
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

PyTorch version: 2.0.0
Device: cpu

Function Implementations
Simple Attention
Let’s package the steps from the first section into a function:
def simple_attention(x, W_q, W_k, W_v):
"""
Simple attention mechanism
Args:
x: Input tensor of shape (seq_len, embed_dim)
W_q, W_k, W_v: Linear projection layers
Returns:
output: Attention output of shape (seq_len, embed_dim)
attention_weights: Attention weights of shape (seq_len, seq_len)
"""
# Step 1: Project to Q, K, V
Q = W_q(x)
K = W_k(x)
V = W_v(x)
# Step 2: Compute attention scores
d_k = K.shape[-1]
scores = (Q @ K.T) / math.sqrt(d_k)
# Step 3: Softmax to get weights
attention_weights = F.softmax(scores, dim=-1)
# Step 4: Apply weights to values
output = attention_weights @ V
return output, attention_weights
# Test it
output, attn_weights = simple_attention(x, W_q, W_k, W_v)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

Output shape: torch.Size([3, 4])
Attention weights shape: torch.Size([3, 3])

Batched Attention
def batched_attention(x, W_q, W_k, W_v):
"""
Batched attention mechanism
Args:
x: Input tensor of shape (batch_size, seq_len, embed_dim)
Returns:
output: (batch_size, seq_len, embed_dim)
attention_weights: (batch_size, seq_len, seq_len)
"""
Q = W_q(x)
K = W_k(x)
V = W_v(x)
d_k = K.shape[-1]
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
attention_weights = F.softmax(scores, dim=-1)
output = attention_weights @ V
return output, attention_weights
# Test
output, attn_weights = batched_attention(x, W_q, W_k, W_v)
print(f"Output shape: {output.shape}")
print(f"First sequence attention pattern:\n{attn_weights[0]}")
print(f"Second sequence attention pattern:\n{attn_weights[1]}")

Output shape: torch.Size([2, 3, 4])
First sequence attention pattern:
tensor([[0.3245, 0.3512, 0.3243],
[0.3389, 0.3156, 0.3455],
[0.3401, 0.3267, 0.3332]])
Second sequence attention pattern:
tensor([[0.3523, 0.3012, 0.3465],
[0.3156, 0.3689, 0.3155],
[0.3398, 0.3234, 0.3368]])

Causal Attention
def causal_attention(x, W_q, W_k, W_v):
"""
Causal (autoregressive) attention
Args:
x: Input tensor of shape (batch_size, seq_len, embed_dim)
Returns:
output: (batch_size, seq_len, embed_dim)
attention_weights: (batch_size, seq_len, seq_len)
"""
batch_size, seq_len, embed_dim = x.shape
Q = W_q(x)
K = W_k(x)
V = W_v(x)
d_k = K.shape[-1]
scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
# Create causal mask
mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = attention_weights @ V
    return output, attention_weights

Further Reading
- “Attention Is All You Need” - The original transformer paper
- “The Illustrated Transformer” by Jay Alammar - Great visualizations
- Andrej Karpathy’s nanoGPT - Minimal production code
- PyTorch Documentation - Official docs with more details
The complete code is available as executable Python scripts for hands-on learning.