Multi-Head Attention (MHA) splits the model's embedding dimension across several parallel attention heads. This lets the model jointly attend to information from different representation subspaces.
The Splitting Logic
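Instead of computing a single attention function over the full d_model-dimensional queries, keys, and values, MHA linearly projects them h times with different learned projections. Queries and keys are projected to d_k dimensions and values to d_v dimensions. Each head then runs attention independently on its own projected inputs.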
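Written out, this is the standard multi-head attention formulation from the original Transformer paper ("Attention Is All You Need"). Here W_i^Q, W_i^K, W_i^V are the projections for head i and W^O is the output projection. The PyTorch code in this section packs the h per-head projections for each of Q, K, and V into a single d_model × d_model linear layer and splits the result into heads afterwards.

```latex
\begin{aligned}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O \\
\mathrm{head}_i &= \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V) \\
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \\
W_i^Q, W_i^K &\in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad
W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}, \quad
W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}
\end{aligned}
```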
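Let h be the number of heads and d_model the embedding dimension. We set:
d_k = d_v = d_model / h
so d_model must be divisible by h. Concatenating the h head outputs then yields a d_model-dimensional vector again.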
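A minimal sketch of this constraint in plain Python. `head_dim` is a hypothetical helper introduced for illustration, not part of PyTorch or any other library. With d_model = 512 and h = 8 (the base configuration of the original Transformer), each head works in 64 dimensions.

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Return d_k = d_v = d_model / h, rejecting non-divisible configurations."""
    if d_model % num_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by num_heads={num_heads}"
        )
    return d_model // num_heads


print(head_dim(512, 8))   # 64
print(head_dim(768, 12))  # 64
```

A configuration such as `head_dim(512, 6)` raises `ValueError`, which mirrors the `assert d_model % num_heads == 0` check in the PyTorch module.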
Tracking Tensor Shapes
Here is how the shapes change during the forward pass of MHA:
- Input Matrix: X of shape [Batch, SeqLength, d_model].
- Linear Projection: Project X to get queries Q, keys K, and values V, each of shape [Batch, SeqLength, d_model].
- Split into Heads: Reshape to [Batch, SeqLength, h, d_k], then transpose to [Batch, h, SeqLength, d_k] so that all heads are computed in parallel.
- Attention Computation: Compute the attention weights softmax(Q K^T / sqrt(d_k)) for every head, giving shape [Batch, h, SeqLength, SeqLength]. Multiply them by the values to get the context: [Batch, h, SeqLength, d_k].
- Concatenation: Transpose back to [Batch, SeqLength, h, d_k] and flatten the last two dimensions to get [Batch, SeqLength, d_model].
- Output Projection: Pass the result through the final linear layer W_o of shape [d_model, d_model], which mixes the outputs of the heads: [Batch, SeqLength, d_model].
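The shape bookkeeping above can be traced with a short NumPy sketch. It is an illustration under stated assumptions: random matrices stand in for the learned projections (named W_q, W_k, W_v, W_o to mirror the PyTorch code), bias terms are omitted (`nn.Linear` would also add a bias and store its weight transposed), and the sizes B=2, T=5, d_model=16, h=4 are arbitrary.

```python
import numpy as np

B, T, d_model, h = 2, 5, 16, 4
d_k = d_model // h  # per-head dimension: 4

rng = np.random.default_rng(0)
x = rng.standard_normal((B, T, d_model))  # input X: [B, T, d_model]

# Random stand-ins for the learned projection matrices (no bias terms)
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))


def split_heads(t):
    # [B, T, d_model] -> [B, T, h, d_k] -> [B, h, T, d_k]
    return t.reshape(B, T, h, d_k).transpose(0, 2, 1, 3)


q = split_heads(x @ W_q)
k = split_heads(x @ W_k)
v = split_heads(x @ W_v)
print("q:", q.shape)  # (2, 4, 5, 4)

# Scaled dot-product scores and row-wise softmax: [B, h, T, T]
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
print("weights:", weights.shape)  # (2, 4, 5, 5)

context = weights @ v  # [B, h, T, d_k]
print("context:", context.shape)  # (2, 4, 5, 4)

# Concatenate heads: [B, h, T, d_k] -> [B, T, h, d_k] -> [B, T, d_model]
concat = context.transpose(0, 2, 1, 3).reshape(B, T, d_model)
out = concat @ W_o  # output projection: [B, T, d_model]
print("out:", out.shape)  # (2, 5, 16)
```

Because the split is only a reshape plus a transpose, undoing both recovers the original tensor exactly; the head split adds no parameters of its own.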
PyTorch Implementation
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.h = num_heads
        self.d_k = d_model // num_heads  # per-head dimension (d_k = d_v)
        # Learned projections for queries, keys, values, and the output
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.size()  # Batch, SeqLength, d_model
        # Project and reshape to [B, h, T, d_k]
        q = self.w_q(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores: [B, h, T, T]
        scores = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        # Apply causal mask here if needed...
        weights = torch.softmax(scores, dim=-1)
        context = weights @ v  # Shape: [B, h, T, d_k]
        # Concatenate heads: [B, h, T, d_k] -> [B, T, h, d_k] -> [B, T, d_model]
        context = context.transpose(1, 2).contiguous().view(B, T, C)
        # Output projection mixes information across heads
        return self.w_o(context)
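The forward pass above leaves the causal mask as a placeholder. One common approach, sketched here in NumPy under the assumption of a decoder-style model where query position i may only attend to key positions j ≤ i, is to set every score above the main diagonal to -inf before the softmax so that those weights become exactly zero. `causal_softmax` is a hypothetical helper name used only for this illustration.

```python
import numpy as np


def causal_softmax(scores):
    """Row-wise softmax over the last axis with a causal (lower-triangular) mask.

    scores: array of shape [..., T, T]; entry [i, j] is query i attending to key j.
    """
    T = scores.shape[-1]
    # True strictly above the main diagonal, i.e. keys j > i that lie in the "future"
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)  # mask broadcasts over leading axes
    # Subtract the row max for stability; the diagonal is never masked, so it is finite
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)  # exp(-inf) == 0, so masked positions get zero weight
    return weights / weights.sum(axis=-1, keepdims=True)


scores = np.zeros((1, 1, 4, 4))  # uniform scores, shape [B, h, T, T]
print(causal_softmax(scores)[0, 0].round(3))
```

With uniform scores, row i spreads its weight evenly over positions 0 through i and assigns zero weight to later positions. In the PyTorch module, the equivalent step at the placeholder comment would be `scores = scores.masked_fill(future, float('-inf'))` with `future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)`. PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)`, which fuses the scaling, masking, softmax, and value multiplication.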