Vamsi Cheruku.
attention · 2026-05-19

Causal Masking and Attention Bounds

How to prevent looking ahead during training using a lower-triangular causal mask matrix.

attention masking gpt

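In autoregressive decoder-only models, causal masking prevents the query at sequence position t from attending to keys at positions greater than t. During training the whole sequence is processed in parallel, so without the mask the representation at position t could see token t+1, which is exactly the token it is being trained to predict.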

Mathematical Formulation

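Let the attention score matrix before softmax be S = (Q K^T) / sqrt(d_k), where T is the sequence length and d_k is the key dimension, so S has shape [T, T]. Row i indexes the query position and column j the key position. The causal mask matrix M, also of shape [T, T], is defined as: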

M_ij = 0        if i >= j
M_ij = -inf     if i < j
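A minimal sketch of building M directly in this additive form, assuming PyTorch and a toy length T = 4. The helper name `causal_additive_mask` is hypothetical; the `triu(full(-inf), diagonal=1)` idiom fills the strict upper triangle (j > i) with -inf and leaves zeros on and below the diagonal:

```python
import torch


def causal_additive_mask(T: int) -> torch.Tensor:
    """Hypothetical helper: M[i, j] = 0 if i >= j, -inf if i < j."""
    # Start from an all -inf matrix and keep only entries strictly above the
    # main diagonal (diagonal=1); triu zeroes everything on and below it.
    return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)


M = causal_additive_mask(4)
print(M)  # row i = query position, column j = key position
```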

The masked attention weights are calculated as:

A = Softmax(S + M), applied row-wise over the key index j:

A_ij = exp(S_ij + M_ij) / sum_k exp(S_ik + M_ik)

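When M_ij = -inf, exp(S_ij + M_ij) = exp(-inf) = 0, so the softmax weight A_ij = 0 for every future token j > i. Because M_ii = 0, each row keeps at least one finite score, so the denominator is always positive and no row degenerates to NaN.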
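To make the row-wise softmax concrete, here is a small sketch on random toy scores, assuming PyTorch, a single head, and T = 4 (the scores `S` are made up for illustration). It computes A both with `F.softmax` and with the explicit exp / sum formula, and shows that a row with no finite entry would produce NaN:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4
S = torch.randn(T, T)  # toy pre-softmax scores for one head
# Additive causal mask: 0 on/below the diagonal, -inf above it
M = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Row-wise softmax over the key index j
A = F.softmax(S + M, dim=-1)

# Same computation written out: A_ij = exp(S_ij + M_ij) / sum_k exp(S_ik + M_ik)
E = torch.exp(S + M)  # exp(-inf) = 0 for every j > i
A_manual = E / E.sum(dim=-1, keepdim=True)

# A fully masked row would give 0/0 = NaN; the unmasked diagonal (M_ii = 0) avoids this
all_masked = F.softmax(torch.full((T,), float("-inf")), dim=-1)

print(A)
print(all_masked)
```

Row 0 can only attend to itself, so its single finite score receives all of the probability mass.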

PyTorch Tensor Operation

Here is how the causal mask is applied to the attention scores using PyTorch tensor operations:

```python
import torch
import torch.nn.functional as F

# B = batch size, H = heads, T = seq length, D = head dimension
B, H, T, D = 2, 8, 1024, 64
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)

# Calculate scaled scores, shape [B, H, T, T]
scores = q @ k.transpose(-2, -1) * (1.0 / (D ** 0.5))

# Create the lower-triangular mask on the same device as the scores
# (True where key position j <= query position i)
tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=scores.device))
# Fill future positions (j > i) with -inf; the [T, T] mask broadcasts over B and H
masked_scores = scores.masked_fill(~tril, float('-inf'))

# Softmax converts -inf elements to 0.0
weights = F.softmax(masked_scores, dim=-1)
```
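The same operation can be packaged in an `nn.Module` so the mask is built once and moves with the model across devices. The sketch below is a hypothetical minimal module (the class name `CausalAttention` and the value tensor `v` are assumptions; the snippet above only defines `q` and `k`). It then cross-checks the result against PyTorch's fused `F.scaled_dot_product_attention` (PyTorch 2.0 or later) and verifies causality by perturbing keys and values after position t:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalAttention(nn.Module):
    """Hypothetical minimal module: takes q, k, v directly (no projections)."""

    def __init__(self, max_len: int):
        super().__init__()
        # Lower-triangular boolean mask, True where key j <= query i.
        # As a buffer it follows .to(device); persistent=False keeps it out of state_dict.
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("tril", mask, persistent=False)

    def forward(self, q, k, v):
        T = q.size(-2)
        scores = q @ k.transpose(-2, -1) * (1.0 / (q.size(-1) ** 0.5))
        # Slice the precomputed mask to the current length; it broadcasts over B and H
        scores = scores.masked_fill(~self.tril[:T, :T], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return weights @ v


torch.manual_seed(0)
B, H, T, D = 2, 8, 16, 64  # small T keeps the checks fast
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)  # assumed value tensor, not part of the snippet above

attn = CausalAttention(max_len=1024)
out = attn(q, k, v)

# Cross-check against PyTorch's fused kernel; is_causal=True applies the same
# lower-triangular mask when query and key lengths match.
out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out, out_fused, atol=1e-5, rtol=1e-4))

# Causality check: replacing keys/values after position t must not change outputs at 0..t.
t = 7
k2, v2 = k.clone(), v.clone()
k2[:, :, t + 1:] = torch.randn_like(k2[:, :, t + 1:])
v2[:, :, t + 1:] = torch.randn_like(v2[:, :, t + 1:])
out2 = attn(q, k2, v2)
print(torch.allclose(out[:, :, : t + 1], out2[:, :, : t + 1]))
```

Precomputing the mask up to `max_len` and slicing it per call avoids reallocating a [T, T] tensor on every forward pass; in practice the fused `is_causal=True` path is generally preferable, since it can skip materializing the full score matrix.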
