Vamsi Cheruku.
attention · 2026-05-20

Multi-Head Attention Dimensional Splitting

Tensor transformations, shape changes, and the concatenation and output projection in multi-head attention.

attention multi-head tensors

Multi-Head Attention (MHA) splits the model's embedding dimension across several parallel heads, allowing the model to attend jointly to information from different representation subspaces.

The Splitting Logic

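Instead of computing a single attention function over the full d_model-dimensional queries, keys, and values, we linearly project them h times, with different learned projections, down to d_k dimensions (d_v for the values) and run attention in each of these h heads in parallel.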
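The "h different projections" are usually not h separate layers: one d_model × d_model linear layer whose output is split into h chunks of size d_k gives exactly the same result as h independent d_model × d_k projections. The sketch below checks this numerically; the sizes (d_model = 16, h = 4) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_model, h = 2, 5, 16, 4
d_k = d_model // h

x = torch.randn(B, T, d_model)
w_q = nn.Linear(d_model, d_model)  # one fused projection shared by all heads

# Fused path: project once, then split the last dim into h chunks of size d_k
fused = w_q(x).view(B, T, h, d_k)  # [B, T, h, d_k]

# Per-head path: head i uses output rows [i*d_k, (i+1)*d_k) of the weight and bias
per_head = []
for i in range(h):
    rows = slice(i * d_k, (i + 1) * d_k)
    W_i = w_q.weight[rows]            # [d_k, d_model] (nn.Linear stores [out, in])
    b_i = w_q.bias[rows]              # [d_k]
    per_head.append(x @ W_i.T + b_i)  # [B, T, d_k]
separate = torch.stack(per_head, dim=2)  # [B, T, h, d_k]

print(torch.allclose(fused, separate, atol=1e-5))
```

This is why the implementation can keep a single `nn.Linear(d_model, d_model)` per projection and split heads with a `view`.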

Let h be the number of heads and d_model the embedding dimension. We set the per-head dimensions to the following, which requires d_model to be divisible by h:

d_k = d_v = d_model / h
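As a worked example, the base configuration of the original Transformer uses d_model = 512 and h = 8, giving d_k = d_v = 64. The small helper below (a hypothetical name, `head_dims`) computes the per-head size and rejects sizes that do not split evenly.

```python
def head_dims(d_model: int, h: int) -> int:
    """Return d_k = d_v = d_model / h, rejecting sizes that do not split evenly."""
    if d_model % h != 0:
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h

d_k = head_dims(512, 8)
print(d_k)      # 64
print(8 * d_k)  # 512: concatenating the h heads restores d_model
```

Because h · d_k = d_model, the total size of the projections is the same as for single-head attention over d_model; the heads only partition it.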

Tracking Tensor Shapes

Here is how the shapes change during the forward pass of MHA:

  1. Input: X of shape [Batch, SeqLength, d_model].

  2. Linear Projection: Apply the projections W_q, W_k, and W_v to X to get queries Q, keys K, and values V, each of shape [Batch, SeqLength, d_model].

  3. Split into Heads: Reshape to [Batch, SeqLength, h, d_k], then transpose to [Batch, h, SeqLength, d_k] so that the head dimension acts like a batch dimension and all heads are computed in parallel.

  4. Attention Computation: Compute the attention weights softmax(Q K^T / sqrt(d_k)) of shape [Batch, h, SeqLength, SeqLength], then multiply them by V to get the context of shape [Batch, h, SeqLength, d_k].

  5. Concatenation: Transpose back to [Batch, SeqLength, h, d_k], then merge the last two dimensions (h · d_k = d_model) to get [Batch, SeqLength, d_model].

  6. Output Projection: Pass the result through a final linear layer W_o of shape [d_model, d_model], which mixes the heads' outputs: [Batch, SeqLength, d_model].
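The six steps can be traced directly in PyTorch. This sketch uses small, arbitrary sizes (B = 2, T = 10, d_model = 32, h = 4, so d_k = 8) and freshly initialized layers; it only demonstrates the shape bookkeeping, not a trained model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_model, h = 2, 10, 32, 4
d_k = d_model // h

x = torch.randn(B, T, d_model)  # 1. [B, T, d_model]
w_q, w_k, w_v, w_o = [nn.Linear(d_model, d_model) for _ in range(4)]

q, k, v = w_q(x), w_k(x), w_v(x)  # 2. each [B, T, d_model]

def split_heads(t):
    # 3. [B, T, d_model] -> [B, T, h, d_k] -> [B, h, T, d_k]
    return t.view(B, T, h, d_k).transpose(1, 2)

q, k, v = split_heads(q), split_heads(k), split_heads(v)

weights = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # 4. [B, h, T, T]
context = weights @ v                                                  # 4. [B, h, T, d_k]

# 5. [B, h, T, d_k] -> [B, T, h, d_k] -> [B, T, d_model]
# reshape() copies when needed, equivalent to .contiguous().view()
merged = context.transpose(1, 2).reshape(B, T, d_model)
out = w_o(merged)  # 6. [B, T, d_model]

for name, t in [("q (split)", q), ("weights", weights), ("context", context),
                ("merged", merged), ("out", out)]:
    print(f"{name:10s} {tuple(t.shape)}")
```

Note that only step 4 produces a T × T tensor per head, which is where the memory cost of attention grows quadratically with sequence length.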

PyTorch Implementation

import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.h = num_heads
        self.d_k = d_model // num_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
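    def forward(self, x, mask=None):
        B, T, C = x.size()  # Batch, SeqLength, d_model

        # Project, then split d_model into h heads:
        # [B, T, d_model] -> [B, T, h, d_k] -> [B, h, T, d_k]
        q = self.w_q(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(B, T, self.h, self.d_k).transpose(1, 2)

        # Scaled dot-product attention for all heads at once: [B, h, T, T]
        scores = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # mask must broadcast to [B, h, T, T]; positions where mask == 0
            # (e.g. future tokens under a causal mask) get zero attention weight
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        context = weights @ v  # Shape: [B, h, T, d_k]

        # Concatenate heads: [B, h, T, d_k] -> [B, T, h, d_k] -> [B, T, d_model].
        # transpose() makes the tensor non-contiguous, so .contiguous() is needed before .view()
        context = context.transpose(1, 2).contiguous().view(B, T, C)
        return self.w_o(context)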
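A causal mask for the attention step is a lower-triangular boolean T × T matrix that broadcasts over the batch and head dimensions. The sketch below applies one to already-split tensors of shape [B, h, T, d_k] and compares the result with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (assumes PyTorch 2.0 or later, where that function exists); the random tensors stand in for real projected queries, keys, and values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, T, d_k = 2, 4, 6, 8
q, k, v = (torch.randn(B, h, T, d_k) for _ in range(3))

# Manual scaled dot-product attention with a causal mask
scores = q @ k.transpose(-2, -1) / d_k ** 0.5            # [B, h, T, T]
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True = may attend
scores = scores.masked_fill(~causal, float("-inf"))      # block future positions
manual = torch.softmax(scores, dim=-1) @ v               # [B, h, T, d_k]

# PyTorch's fused implementation on the same head-split tensors
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

Because the built-in function accepts the same [B, h, T, d_k] layout, it can replace the manual score, mask, and softmax lines once the heads are split.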
