This chapter makes the model learn. In train.py, we choose an optimizer and a learning-rate schedule, run a single update step, expand it into a full training loop with validation, and add checkpointing so training can be paused and resumed.
Set Up Optimization
After each forward pass, backpropagation computes a gradient for every weight: a signal that says which direction to adjust it and by how much to reduce the loss. The optimizer uses these gradients to update the weights. The simplest update rule is w = w - lr * gradient.
AdamW improves on this by adapting the step size per weight based on its gradient history: weights with large, noisy gradients get smaller steps to avoid overshooting, while weights with small, steady gradients get larger ones. It also applies weight decay, shrinking all weights slightly toward zero after each update to keep the model from relying too heavily on any single parameter.
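The contrast between the plain rule and AdamW can be sketched on a single toy parameter. The tensor value and learning rate here are illustrative only, not the settings used for the model:

```python
import torch

# A toy parameter and a loss whose gradient we can compute by hand.
w = torch.tensor([2.0], requires_grad=True)
lr = 0.1

# Plain gradient descent: the same rule for every weight.
loss = (w ** 2).sum()          # gradient is 2 * w = 4.0
loss.backward()
with torch.no_grad():
    w -= lr * w.grad           # w = 2.0 - 0.1 * 4.0 = 1.6
w.grad.zero_()

# AdamW applies the same idea but rescales each step from gradient
# history and applies decoupled weight decay toward zero.
opt = torch.optim.AdamW([w], lr=lr, weight_decay=0.1)
loss = (w ** 2).sum()
loss.backward()
opt.step()                     # adaptive step, then slight decay
```

On the very first AdamW step, the bias-corrected momentum and variance estimates cancel to roughly a unit-scale step, so the update is close to the learning rate itself regardless of the raw gradient's magnitude.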
The learning rate lr controls how much weights change per step. Rather than holding it fixed, we vary it over training with a schedule. Linear warmup ramps the learning rate from near zero up to the peak over the first few hundred steps, giving the optimizer time to calibrate before making large updates. Cosine decay then smoothly reduces it to one-tenth of the peak. As the weights improve, they need finer updates to keep learning without undoing previous progress.
The peak learning rate is the main hyperparameter to get right. The practical workflow is to start from a published rate at your model size, then sweep nearby values in short runs. For reference, GPT-3 used 6e-4 for its 125M model, 3e-4 for 350M, 2e-4 for 1.3B, and 6e-5 for 175B. Pythia's smaller models follow the same trend: 1e-3 for 70M, 6e-4 for 160M, 3e-4 for 410M. Once you have a starting point, sweep nearby rates for a few hundred steps and pick the one where loss falls fastest without becoming unstable.
Warmup keeps the learning rate low at the start while Adam calibrates its per-weight estimates from real gradients. Without it, the full rate hits before those estimates are reliable, so early updates are large and poorly aimed. Like the learning rate, the warmup length is typically borrowed from a similar training run. If the loss becomes unstable as the rate approaches its peak, try extending the warmup, but going much longer means fewer steps at the peak rate where most learning happens.
Each training step computes a gradient from a batch of tokens. The size of that batch is the next choice to get right. With too few tokens, the gradient reflects noise from the specific batch more than the actual patterns in the data. With too many, the noise has already been averaged out and further improvements cost disproportionately more compute. For this model size, the threshold is around 500,000 tokens per step. We use 524,288.
GPT-3 used 0.5M tokens per batch for its 125M model, informed by the critical batch size framework (McCandlish et al., 2018), which measures where gradient noise stops being the bottleneck. Most reproductions at this scale, including nanoGPT and llm.c, adopted the same value.
Typically, an optimizer step needs more tokens than a single forward pass can fit in memory. Gradient accumulation solves this by splitting each step into multiple smaller forward passes. Each pass processes a micro-batch and computes gradients without updating the weights. After grad_accum_steps passes, the optimizer applies a single update using the accumulated gradients.
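The equivalence between one large batch and several accumulated micro-batches can be checked on a toy linear model; the sizes here are arbitrary, chosen only to make the comparison cheap:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
x = torch.randn(4, 8)
y = torch.randn(4, 1)
loss_fn = torch.nn.MSELoss()

# Full batch: one backward pass over all four samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: two micro-batches of two, each loss scaled by 1/2
# before backward(), mirroring loss / grad_accum_steps.
model.zero_grad()
for xb, yb in zip(x.split(2), y.split(2)):
    (loss_fn(model(xb), yb) / 2).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True
```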
The micro-batch size, the number of sequences per forward pass, depends on available memory. With a micro-batch of 16 samples and a context length of 1024, each pass processes 16,384 tokens. Reaching 524,288 per step takes 32 passes. Across the full dataset of roughly 2.5 billion tokens (the Chinchilla recommendation of 20 tokens per parameter), training takes about 4,768 of these steps. That determines max_steps and the window the learning-rate schedule decays over.
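The arithmetic above can be sanity-checked in a few lines, assuming the context length of 1024 from the config:

```python
context_length = 1024            # from the model config
micro_batch_size = 16
total_batch_size = 524_288
dataset_tokens = 2_500_000_000   # ~20 tokens per parameter at 125M

tokens_per_micro_batch = micro_batch_size * context_length
grad_accum_steps = total_batch_size // tokens_per_micro_batch
max_steps = dataset_tokens // total_batch_size

print(tokens_per_micro_batch)  # 16384
print(grad_accum_steps)        # 32
print(max_steps)               # 4768
```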
weight_decay=0.1 and betas=(0.9, 0.95) are common decoder-only pretraining defaults, used in GPT-3 and OLMo among others. The β2 of 0.95 gives Adam a shorter memory for gradient variance than PyTorch's default of 0.999, making it adapt faster as training progresses. Both values matter less to tune than the learning rate and batch size.
# train.py
import math
import torch
total_batch_size = 524_288
micro_batch_size = 16
grad_accum_steps = total_batch_size // (micro_batch_size * config.context_length)
max_lr = 6e-4
min_lr = 6e-5
warmup_steps = 200
max_steps = 4_768
weight_decay = 0.1
def get_lr(step: int) -> float:
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=max_lr,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)
For the lighter model (~30M parameters), we follow llama2.c's training recipe: total_batch_size=131_072, max_lr=5e-4, min_lr=5e-5. That gives 8 accumulation steps at micro-batch 16. From the data budget of ~600M tokens (20 per parameter), training takes roughly 4,600 steps. The warmup, weight decay, betas, and remaining settings stay the same.
Run One Training Step
A single training step processes total_batch_size tokens and produces one weight update. If you understand this block of code, the full loop is just repetition around it.
# train.py
model.train()
optimizer.zero_grad()
train_loss = 0.0
for micro_step in range(grad_accum_steps):
    x, y = get_batch("train", batch_size=micro_batch_size, seq_len=config.context_length, device=device)
    _, loss = model(x, y)
    loss = loss / grad_accum_steps
    train_loss += loss.item()
    loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
model.train() enables training behavior like dropout, and zero_grad() clears gradients from the previous step. The loop runs grad_accum_steps micro-batches, each through a forward and backward pass. It divides each loss by grad_accum_steps before backward() so the accumulated gradients average out to the same result as processing all 524,288 tokens at once. After the loop, clip_grad_norm_() caps the gradient norm at 1.0 and optimizer.step() applies the update.
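The clipping call is worth a closer look: clip_grad_norm_ rescales the gradient only when its global norm exceeds the cap, and returns the pre-clip norm. A toy sketch with a hand-set gradient:

```python
import torch

# A parameter with an artificially large gradient.
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([30.0, 40.0, 0.0])   # global norm = 50

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(total_norm.item())     # 50.0 — the norm before clipping
print(p.grad.norm().item())  # ~1.0 — gradient rescaled to the cap
```

Monitoring the returned pre-clip norm is a common way to spot instability: a sudden spike often precedes a loss divergence.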
Build the Training Loop
The full loop repeats that training step for max_steps iterations, adjusting the learning rate at each one and running validation every 200 steps to log progress and save checkpoints.
# train.py
eval_interval = 200
eval_batches = 20
best_val_loss = float("inf")
@torch.no_grad()
def estimate_val_loss():
    model.eval()
    vals = torch.zeros(eval_batches)
    for k in range(eval_batches):
        x, y = get_batch("val", batch_size=micro_batch_size, seq_len=config.context_length, device=device)
        _, loss = model(x, y)
        vals[k] = loss.item()
    model.train()
    return vals.mean().item()

for step in range(max_steps):
    lr = get_lr(step)
    for group in optimizer.param_groups:
        group["lr"] = lr
    optimizer.zero_grad()
    train_loss = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = get_batch("train", batch_size=micro_batch_size, seq_len=config.context_length, device=device)
        _, loss = model(x, y)
        loss = loss / grad_accum_steps
        train_loss += loss.item()
        loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if step % eval_interval == 0:
        val_loss = estimate_val_loss()
        print(
            f"step {step:5d} | "
            f"train {train_loss:.4f} | "
            f"val {val_loss:.4f} | "
            f"lr {lr:.6f}"
        )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(step, best_val_loss)
estimate_val_loss() switches to eval mode and measures loss on the validation split without affecting training. The @torch.no_grad() decorator tells PyTorch not to build computation graphs during these forward passes, since no weight updates happen during evaluation. Without it, every eval pass would allocate memory for a graph that is never used.
The initial loss should be near ln(50257) ≈ 10.8, the expected value when the model assigns equal probability across GPT-2's vocabulary. It should decrease over the first few hundred steps.
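That sanity check is easy to reproduce: the uniform-prediction loss is ln(vocab_size), and an all-zero logit vector gives the same number. The values below assume GPT-2's 50,257-token vocabulary:

```python
import math
import torch

vocab_size = 50_257  # GPT-2 BPE vocabulary

# Expected cross-entropy under a uniform prediction: ln(vocab_size).
print(math.log(vocab_size))  # ~10.82

# The same number from an actual uniform-logit forward pass.
logits = torch.zeros(1, vocab_size)  # equal score for every token
target = torch.tensor([0])           # any target gives the same loss
loss = torch.nn.functional.cross_entropy(logits, target)
print(loss.item())                   # ~10.82
```

If the very first logged loss is far above this value, the initialization or data pipeline is usually at fault rather than the optimizer settings.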
Save and Resume Training
A checkpoint captures the full training state so a run can resume after an interruption or the model can be loaded for generation.
# train.py
from dataclasses import asdict
from pathlib import Path
out_dir = Path("out")
out_dir.mkdir(exist_ok=True)
checkpoint_path = out_dir / "ckpt.pt"
def save_checkpoint(step: int, best_val_loss: float):
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "config": asdict(config),
            "step": step,
            "best_val_loss": best_val_loss,
        },
        checkpoint_path,
    )
The checkpoint saves the optimizer state alongside the model weights. AdamW builds up momentum and variance estimates that tune its step size to each parameter individually, and losing that history would force it to recalibrate from scratch.
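What "optimizer state" means concretely can be seen in a small round-trip sketch, using a toy linear layer in place of the real model and an in-memory buffer in place of the checkpoint file:

```python
import io
import torch

# Toy stand-ins for the real model and optimizer.
model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One update so AdamW has momentum/variance buffers worth saving.
model(torch.randn(2, 4)).sum().backward()
opt.step()

buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": opt.state_dict()}, buffer)
buffer.seek(0)

# A fresh model/optimizer pair regains those buffers on load.
model2 = torch.nn.Linear(4, 4)
opt2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)
ckpt = torch.load(buffer, weights_only=False)
model2.load_state_dict(ckpt["model"])
opt2.load_state_dict(ckpt["optimizer"])

state = opt2.state_dict()["state"]
print(sorted(state[0].keys()))  # includes 'exp_avg' and 'exp_avg_sq'
```

The exp_avg and exp_avg_sq tensors are exactly the per-parameter momentum and variance history described above.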
# train.py — resume path
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
config = GPTConfig(**checkpoint["config"])
model = GPT(config).to(device)
model.load_state_dict(checkpoint["model"])
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=max_lr,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)
optimizer.load_state_dict(checkpoint["optimizer"])
start_step = checkpoint["step"] + 1
best_val_loss = checkpoint["best_val_loss"]
The optimizer is constructed with the model's parameters first, so load_state_dict can map the saved momentum and variance estimates back to each parameter. The training loop then starts from range(start_step, max_steps) instead of range(max_steps).
train.py can now train the model end to end and save checkpoints along the way.
In the next chapter, we load a checkpoint and generate text from the trained model.