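Autoregressive language models such as GPT, LLaMA, and Mistral use a decoder-only architecture. It keeps only the decoder stack of the original encoder-decoder Transformer: the encoder is removed, and with it each block's cross-attention sublayer. What remains is causally masked self-attention followed by a feed-forward network, so the prediction at each position can depend only on the tokens before it.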
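Causal masking can be illustrated with single-head scaled dot-product attention on a toy input. The sketch below is a minimal illustration, not any particular model's code, and the function name `causal_attention` is hypothetical. Scores above the diagonal are set to negative infinity before the softmax, so position i receives zero weight from every later position j > i.

```python
import math

import torch


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: tensors of shape (seq_len, d_head).
    Returns (output, weights), where weights[i, j] == 0 for every j > i.
    """
    seq_len, d_head = q.shape
    scores = q @ k.transpose(0, 1) / math.sqrt(d_head)
    # True above the main diagonal marks "future" positions that must be hidden.
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
x = torch.randn(5, 8)  # 5 tokens, 8-dim toy embeddings (q = k = v = x for brevity)
out, w = causal_attention(x, x, x)
print(w)  # lower-triangular: row i has non-zero weights only in columns 0..i
```

The first token can only attend to itself, so its output is exactly its own value vector; the diagonal is never masked, which keeps every softmax row well defined.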
Block Structure
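A single GPT-2-style (Pre-LN) decoder block applies, in order:
- Pre-Layer Normalization (Pre-LN): LayerNorm applied to the block input before self-attention.
- Causal Multi-Head Self-Attention: each position attends only to itself and earlier positions; a mask hides all later tokens.
- Residual Connection: adds the block input to the attention output.
- Pre-Layer Normalization (Pre-LN): LayerNorm applied to the result of the first residual add, before the FFN.
- Feed-Forward Network (FFN): applied independently at each position. In GPT-2 it is two linear layers (d_model → 4·d_model → d_model) with a GELU in between; LLaMA and Mistral use a gated SwiGLU variant with three projections.
- Residual Connection: adds the un-normalized input of the FFN sublayer (the output of the first residual add) to the FFN output.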
```text
Input ─┬─► LayerNorm ─► Causal Attention ─► (+) ─┬─► LayerNorm ─► FFN ─► (+) ─► Output
       │                                     ▲   │                        ▲
       └─────────────────────────────────────┘   └────────────────────────┘
```
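The FFN bullet mentions SwiGLU as an alternative to GELU. In LLaMA-style models the gated FFN uses three bias-free projections: the input is projected twice, one branch passes through SiLU and gates the other elementwise, and a third projection maps the result back to d_model. The sketch below is a minimal version of that idea; the class name `SwiGLUFeedForward` and the sizes in the usage line are hypothetical, not taken from any specific model's configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFeedForward(nn.Module):
    """Gated FFN: w_down(silu(w_gate(x)) * w_up(x)), with no biases."""

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x):
        # SiLU(x) = x * sigmoid(x); the gate branch scales the up branch elementwise.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


ffn = SwiGLUFeedForward(d_model=64, d_hidden=172)  # hypothetical sizes
y = ffn(torch.randn(2, 10, 64))
print(y.shape)  # torch.Size([2, 10, 64])
```

Because there are three weight matrices instead of two, LLaMA shrinks the hidden size to roughly two-thirds of 4·d_model so the parameter count stays close to that of a standard 4× GELU FFN.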
Parameter Configurations (GPT-2 reference values)
Hyperparameters of the two smallest GPT-2 models:
| Parameter | Symbol | GPT-2 Small | GPT-2 Medium | Description |
|---|---|---|---|---|
| Layers | n_layer | 12 | 24 | Number of stacked decoder blocks |
| Model dimension | d_model | 768 | 1024 | Size of the hidden-state vector at each position |
| Heads | n_head | 12 | 16 | Number of attention heads (head size d_model / n_head = 64 in both) |
| Context length (block size) | n_ctx | 1024 | 1024 | Maximum number of tokens in the context window |
| Vocabulary size | n_vocab | 50257 | 50257 | Number of tokens in the byte-level BPE vocabulary |
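These hyperparameters are enough to count a model's parameters. Per block, attention holds a fused QKV projection (d_model × 3·d_model plus bias) and an output projection (d_model × d_model plus bias), the FFN holds two d_model × 4·d_model projections with biases, and each of the two LayerNorms has a scale and a shift vector, for 12·d_model² + 13·d_model in total. The sketch below adds token and position embeddings plus GPT-2's final LayerNorm after the last block, and assumes, as GPT-2 does, that the output projection is tied to the token embedding and so adds no parameters. The helper name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(n_layer, d_model, n_ctx, n_vocab):
    """Count parameters of a GPT-2-style model with a tied output projection."""
    token_emb = n_vocab * d_model
    pos_emb = n_ctx * d_model
    attn = (d_model * 3 * d_model + 3 * d_model) + (d_model * d_model + d_model)
    ffn = (d_model * 4 * d_model + 4 * d_model) + (4 * d_model * d_model + d_model)
    layer_norms = 2 * (2 * d_model)  # ln1 and ln2, each with weight and bias
    per_block = attn + ffn + layer_norms  # equals 12*d^2 + 13*d
    final_ln = 2 * d_model
    return token_emb + pos_emb + n_layer * per_block + final_ln


small = gpt2_param_count(n_layer=12, d_model=768, n_ctx=1024, n_vocab=50257)
medium = gpt2_param_count(n_layer=24, d_model=1024, n_ctx=1024, n_vocab=50257)
print(f"{small:,} {medium:,}")  # 124,439,808 354,823,168
```

In GPT-2 Small the token and position embeddings account for nearly a third of all parameters; the share falls as d_model and n_layer grow, since block parameters scale with n_layer·d_model².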
PyTorch Block Definition
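```python
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Position-wise FFN: expand to 4*d_model, apply GELU, project back."""

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-LN decoder block: x + attn(ln1(x)), then x + ffwd(ln2(x))."""

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        # PyTorch's built-in attention stands in for a custom MultiHeadAttention class.
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffwd = FeedForward(d_model, dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        T = x.size(1)
        # Causal mask: True above the diagonal blocks attention to future positions.
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + attn_out                # residual around attention
        x = x + self.ffwd(self.ln2(x))  # residual around FFN
        return x
```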
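Blocks become a language model once they are stacked between an embedding layer and an output projection: token and position embeddings, n_layer decoder blocks, a final LayerNorm, and a linear map onto the vocabulary. The self-contained sketch below substitutes PyTorch's `nn.TransformerEncoderLayer` (with `norm_first=True` for Pre-LN, GELU, and a causal mask) for a hand-written block; an "encoder" layer run with a causal mask has the same self-attention plus FFN structure and no cross-attention. The class name `TinyGPT`, the toy sizes, and the weight tying (which GPT-2 uses) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class TinyGPT(nn.Module):
    """Decoder-only LM: embeddings -> n_layer Pre-LN blocks -> LayerNorm -> vocab logits."""

    def __init__(self, n_vocab, n_ctx, d_model, n_head, n_layer, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(n_vocab, d_model)
        self.pos_emb = nn.Embedding(n_ctx, d_model)
        # Self-attention + GELU FFN (4x expansion), LayerNorm applied before each sublayer.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model, n_head, dim_feedforward=4 * d_model, dropout=dropout,
                activation="gelu", batch_first=True, norm_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, n_vocab, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying, as in GPT-2

    def forward(self, idx):
        # idx: (batch, seq_len) integer token ids, with seq_len <= n_ctx
        T = idx.size(1)
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=idx.device), diagonal=1)
        for block in self.blocks:
            x = block(x, src_mask=causal)
        return self.lm_head(self.ln_f(x))  # (batch, seq_len, n_vocab)


torch.manual_seed(0)
model = TinyGPT(n_vocab=100, n_ctx=16, d_model=32, n_head=4, n_layer=2).eval()
idx = torch.randint(0, 100, (1, 8))
logits = model(idx)
print(logits.shape)  # torch.Size([1, 8, 100])

# Causality check: changing the last token must not change earlier positions' logits.
idx2 = idx.clone()
idx2[0, -1] = (idx2[0, -1] + 1) % 100
print(torch.allclose(model(idx2)[:, :-1], logits[:, :-1], atol=1e-5))
```

The causality check is a quick way to catch a missing or inverted mask: without the mask, editing the last token would shift the logits at every earlier position.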