Inference & Generation

Load a trained checkpoint, control the sampling, and generate text.

Training taught the model to predict the next token. To generate text, we repeat that prediction, extending the sequence one token at a time. In generate.py, we load a saved checkpoint, implement this loop, add sampling controls, and see what the model generates.

Load the Checkpoint

To generate text, we first reconstruct the trained model from its checkpoint: load the config, restore the weights, and switch to eval mode.

# generate.py
import torch
import tiktoken
 
from config import GPTConfig, get_device
from model import GPT
 
 
device = get_device()
checkpoint = torch.load("out/ckpt.pt", map_location=device)
 
config = GPTConfig(**checkpoint["config"])
model = GPT(config).to(device)
model.load_state_dict(checkpoint["model"])
model.eval()
 
enc = tiktoken.get_encoding("gpt2")

model.eval() disables dropout, which is only needed during training as regularization. Since generation never updates weights, we also turn off gradient tracking with @torch.inference_mode() on the generation method in the next step.
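The effect of eval mode is easy to see in isolation. This standalone sketch (not part of generate.py) runs a single Dropout layer in both modes:

```python
import torch
import torch.nn as nn

# In train mode, Dropout(p=0.5) zeroes roughly half the activations and
# scales the survivors by 1/(1-p) = 2. In eval mode it is a no-op.
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()
print(drop(x))  # some entries zeroed, survivors scaled to 2.0

drop.eval()
print(drop(x))  # identical to x: dropout is disabled
```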

Implement generate()

Generation works one token at a time: forward pass, take the logits at the last position, pick a token, append it, and repeat. The simplest way to pick a token is greedy decoding, which always takes the highest-scoring one. We implement this as a method on the model class since it uses the forward pass and context length directly.

# model.py
import torch
import torch.nn as nn
 
 
class GPT(nn.Module):
    ...
    @torch.inference_mode()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.context_length :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            idx = torch.cat((idx, next_token), dim=1)
 
        return idx

As the sequence grows, it can exceed the model's context window, so we crop the input to context_length tokens each iteration. Only the logits at the final position predict the next token, so we take logits[:, -1, :] and pick from there.
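A quick shape check makes the indexing concrete. The sketch below uses random logits as a stand-in for the model's output, assuming GPT-2's vocabulary size of 50257:

```python
import torch

B, T, vocab_size = 1, 8, 50257       # batch, sequence length, GPT-2 vocab size
logits = torch.randn(B, T, vocab_size)  # stand-in for the model's forward pass

last = logits[:, -1, :]              # (B, vocab_size): scores at the final position
next_token = torch.argmax(last, dim=-1, keepdim=True)  # (B, 1): greedy pick

print(last.shape, next_token.shape)
```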

Control the Sampling

Since greedy decoding always picks the highest-scoring token, the output is deterministic and often repetitive. To introduce controlled randomness, we extend generate() with two controls: temperature and top-k.

Temperature divides all logits by a single value before softmax. Dividing by a value below 1 amplifies the gaps between scores, so softmax pushes most probability toward the top few tokens. Dividing by a value above 1 shrinks the gaps, spreading probability more evenly and giving lower-scoring tokens a real chance. At 0, we skip softmax entirely and fall back to argmax.
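A small worked example shows the effect. With three made-up logits, a lower temperature concentrates probability on the top token while a higher one flattens the distribution:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])  # made-up scores for three tokens

for t in (0.5, 1.0, 2.0):
    # Dividing by t < 1 widens the gaps before softmax; t > 1 narrows them.
    probs = F.softmax(logits / t, dim=-1)
    print(f"temperature={t}: {probs.tolist()}")
```

The top token's probability falls as temperature rises, which is why low temperatures produce conservative, repetitive text and high temperatures produce more varied (and riskier) text.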

Top-k keeps only the k highest-scoring logits and sets everything else to negative infinity. After softmax, those masked positions get zero probability, so the model only samples from the top k candidates.
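The masking idiom used in generate() can be seen on a toy five-token vocabulary, keeping only the top 2 logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, 1.0, 2.5, 0.5, 2.0]])  # one position, 5-token vocab
k = 2

# Everything below the k-th highest logit is set to -inf, so softmax
# assigns it exactly zero probability.
values, _ = torch.topk(logits, k)
masked = logits.clone()
masked[masked < values[:, [-1]]] = -float("inf")

probs = F.softmax(masked, dim=-1)
print(probs)  # nonzero only at the two highest-scoring indices (0 and 2)
```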

# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class GPT(nn.Module):
    ...
    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.context_length :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
 
            if temperature == 0.0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
 
                if top_k is not None:
                    values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < values[:, [-1]]] = -float("inf")
 
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
 
            idx = torch.cat((idx, next_token), dim=1)
 
        return idx

Back in generate.py, we encode a prompt into token IDs, call model.generate() with temperature and top-k values, and decode the output back to text.

# generate.py
prompt = "Machine learning is"
idx = torch.tensor([enc.encode(prompt)], device=device)
 
out = model.generate(
    idx,
    max_new_tokens=200,
    temperature=0.8,
    top_k=50,
)
 
# out is (1, seq) since we generate a single sequence; grab the first row
print(enc.decode(out[0].tolist()))

If changing these values makes no noticeable difference, or if every setting produces equally broken output, the checkpoint is likely still undertrained.

Milestone

generate.py can now load a trained checkpoint, encode a prompt, and generate text from it.

The pretraining pipeline is now complete: raw text becomes tokens, the model trains on those tokens, and the trained checkpoint generates text from a prompt.

Turning that base model into a chatbot involves post-training, starting with supervised fine-tuning (SFT), which reuses the same next-token training on instruction-and-response data, often followed by preference tuning such as RLHF. To explore that pipeline, see Karpathy's nanochat.