Next-Token Prediction
Before fine-tuning, it is essential to understand what a language model actually learns during pre-training. The core objective is next-token prediction (NTP): given a sequence of tokens, predict the next one. Every model from GPT-2 to Llama 3 is pre-trained with this single objective.
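The Probability Model
A language model assigns a probability to every possible sequence of tokens \(x_1, x_2, \ldots, x_T\). Using the chain rule of probability, any joint distribution can be factored autoregressively:
\[
P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_{<t}), \qquad x_{<t} = (x_1, \ldots, x_{t-1})
\]
A neural language model with parameters \(\theta\) approximates each conditional:
\[
P(x_t \mid x_{<t}) \approx P_\theta(x_t \mid x_{<t})
\]
This is a categorical distribution over the vocabulary \(\mathcal{V}\) (e.g., 32,000 tokens for Llama 2 and 128,256 for Llama 3). The model outputs a vector of logits \(\mathbf{z}_t \in \mathbb{R}^{|\mathcal{V}|}\), which are converted to probabilities via softmax:
\[
P_\theta(x_t = v \mid x_{<t}) = \frac{\exp(z_{t,v})}{\sum_{v' \in \mathcal{V}} \exp(z_{t,v'})}
\]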
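As a concrete check of the softmax step, the sketch below uses a made-up 5-token vocabulary with arbitrary logit values. These are an assumption for illustration, not output from a real model. It computes the probabilities by hand and compares them with torch.softmax.

```python
import torch

# Hypothetical logits z_t over a toy 5-token vocabulary (values are arbitrary)
logits = torch.tensor([2.0, 1.0, 0.1, -1.0, 0.5])

# Softmax by hand: exponentiate each logit, then normalize so the entries sum to 1
probs_manual = torch.exp(logits) / torch.exp(logits).sum()

# Built-in softmax, implemented in a numerically stable way
probs = torch.softmax(logits, dim=-1)

print(probs)
print(probs.sum().item())                   # 1.0 up to floating-point error
print(torch.allclose(probs, probs_manual))  # True
```

The largest logit always receives the largest probability. Adding the same constant to every logit leaves the distribution unchanged.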
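Autoregressive generation: at inference time, the model generates one token at a time. Each new token is appended to the context and fed back in to predict the next one. Starting from a prompt \(x_1, \ldots, x_T\), sampling proceeds as
\[
x_t \sim P_\theta(\cdot \mid x_{<t}), \qquad t = T+1, T+2, \ldots
\]
or, for greedy decoding, \(x_t = \arg\max_{v \in \mathcal{V}} P_\theta(v \mid x_{<t})\).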
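The loop below sketches this process with a stand-in model: a hypothetical random bigram table W, whose next-token logits depend only on the last token. A real Transformer conditions on the whole prefix \(x_{<t}\), but the append-and-feed-back structure is the same. Greedy decoding picks the most likely token at each step, while sampling draws from the distribution instead.

```python
import torch

torch.manual_seed(0)
VOCAB_SIZE = 8

# Hypothetical stand-in model: a random bigram table of logits.
# Row i holds the next-token logits given that the last token is i.
W = torch.randn(VOCAB_SIZE, VOCAB_SIZE)


def next_token_logits(context: list[int]) -> torch.Tensor:
    # A real language model attends to the full context; this toy uses only the last token
    return W[context[-1]]


def generate_toy(prompt: list[int], max_new_tokens: int, greedy: bool = True) -> list[int]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        probs = torch.softmax(next_token_logits(tokens), dim=-1)
        if greedy:
            next_id = int(torch.argmax(probs))                      # most likely token
        else:
            next_id = int(torch.multinomial(probs, num_samples=1))  # sample x_t ~ P
        tokens.append(next_id)  # the new token becomes part of the context
    return tokens


out = generate_toy([3], max_new_tokens=5)
print(out)
```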
The Training Objective: Cross-Entropy Loss
Given a training corpus of documents, each document is treated as a sequence of tokens. The model is trained to maximize the log-likelihood of observed tokens, which is equivalent to minimizing the cross-entropy loss.
For a single document of length \(T\), the loss is the average negative log-likelihood per token:
\[
\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})
\]
Each term \(-\log P_\theta(x_t \mid x_{<t})\) measures how surprised the model is when it sees the actual next token \(x_t\). A perfect model would assign probability 1 to the correct token, giving a loss of 0.
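Over a dataset \(\mathcal{D} = \{d^{(1)}, d^{(2)}, \ldots, d^{(N)}\}\) of \(N\) documents, the per-document losses are averaged:
\[
\mathcal{L}_{\mathcal{D}}(\theta) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{L}^{(n)}(\theta)
\]
where \(\mathcal{L}^{(n)}\) is the single-document loss applied to \(d^{(n)}\). In practice, implementations (including the code below) average over all tokens in a batch rather than per document. The two coincide when every document has the same length.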
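To make the averaging concrete, here is a small worked example. It uses made-up probabilities that a hypothetical model assigns to the observed tokens of two short documents. Each term is the surprisal \(-\log P_\theta(x_t \mid x_{<t})\), so a probability of 1 would contribute 0.

```python
import math

# Hypothetical probabilities P_theta(x_t | x_<t) assigned to the observed tokens
doc_probs = [
    [0.9, 0.5, 0.25, 0.8],  # document 1 (T = 4)
    [0.6, 0.1],             # document 2 (T = 2)
]


def doc_loss(probs):
    # Mean surprisal -log P over the T positions of one document
    return -sum(math.log(p) for p in probs) / len(probs)


per_doc = [doc_loss(p) for p in doc_probs]
dataset_loss = sum(per_doc) / len(per_doc)  # average over the N documents

print([round(l, 4) for l in per_doc])
print(round(dataset_loss, 4))
```

Document 2 has the higher loss because its second token received only probability 0.1.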
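Connection to Perplexity
Perplexity (PPL) is the standard evaluation metric for language models. It is directly tied to the NTP loss, as its exponential:
\[
\mathrm{PPL} = \exp\big(\mathcal{L}(\theta)\big) = \exp\Big(-\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})\Big)
\]
Intuitively, a perplexity of \(k\) means the model is "as confused as if choosing uniformly among \(k\) options" at each step. Lower is better, and the minimum possible value is 1.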
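The uniform-guessing interpretation can be checked directly. A hypothetical model that spreads probability evenly over \(k\) tokens has loss \(\log k\) at every position, so its perplexity is exactly \(k\). The probabilities below are made up for illustration.

```python
import math


def perplexity(token_probs):
    # PPL = exp(mean negative log-likelihood of the observed tokens)
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)


k = 50
uniform = [1 / k] * 10             # uniform guess over k options at each of 10 steps
print(perplexity(uniform))         # 50.0 up to floating-point error

confident = [0.9, 0.8, 0.95, 0.7]  # hypothetical probabilities from a better model
print(perplexity(confident))       # about 1.2, close to the minimum of 1
```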
Contrast: Masked Encoder Training
Autoregressive next-token prediction is the standard objective for decoder-only models such as GPT and Llama. By contrast, encoder-only models such as BERT are usually trained with masked language modeling (MLM): randomly hide a subset of tokens, then predict the missing tokens from both left and right context.
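If \(\mathcal{M}\) is the set of masked positions and \(x_{\setminus \mathcal{M}}\) denotes the sequence with those positions hidden, the MLM loss is:
\[
\mathcal{L}_{\text{MLM}}(\theta) = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P_\theta(x_i \mid x_{\setminus \mathcal{M}})
\]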
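The masking step can be sketched in a few lines. This toy version makes several assumptions:

- a 100-token vocabulary with a hypothetical [MASK] ID of 99
- random logits standing in for a real encoder
- about 15% of positions masked, as in BERT

It omits BERT's extra rule of sometimes keeping or randomly replacing a selected token instead of writing [MASK].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB_SIZE, SEQ_LEN = 100, 12
MASK_ID = 99      # hypothetical ID reserved for [MASK] in this toy vocabulary
MASK_PROB = 0.15  # fraction of positions to mask, as in BERT

input_ids = torch.randint(0, MASK_ID, (SEQ_LEN,))  # real tokens never equal MASK_ID

# Choose the masked positions M at random (force at least one for the demo)
mask = torch.rand(SEQ_LEN) < MASK_PROB
mask[0] = True

# Corrupted input x_{\M}: masked positions are replaced by [MASK]
corrupted = input_ids.masked_fill(mask, MASK_ID)

# Labels: original token at masked positions, -100 (ignored by cross_entropy) elsewhere
labels = input_ids.masked_fill(~mask, -100)

# Stand-in for the bidirectional encoder's logits on the corrupted input
logits = torch.randn(SEQ_LEN, VOCAB_SIZE)

loss = F.cross_entropy(logits, labels, ignore_index=-100)

# Same value computed directly over the masked positions only
manual = F.cross_entropy(logits[mask], input_ids[mask])
print(loss.item(), manual.item())
```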
The key difference is architectural:
- Causal NTP uses a triangular mask and only looks left, so it supports text generation naturally
- Masked encoder training uses bidirectional context, so it learns strong contextual representations for classification, retrieval, and token labeling
- Encoder models are excellent feature extractors, but they are not usually used as standalone autoregressive generators
So when you hear "masked encoder training," think representation learning with missing-token reconstruction, not step-by-step text generation.
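The architectural difference comes down to the attention mask. In the sketch below, random scores stand in for the query-key products of a real model. The causal mask built with torch.tril removes attention to future positions, while bidirectional attention lets every position see the whole sequence.

```python
import torch

torch.manual_seed(0)
T = 5
scores = torch.randn(T, T)  # hypothetical attention scores (row = query i, column = key j)

# Causal (decoder) mask: position i may attend only to positions j <= i
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
causal_scores = scores.masked_fill(~causal_mask, float("-inf"))
causal_attn = torch.softmax(causal_scores, dim=-1)  # future positions get weight 0

# Bidirectional (encoder) attention: no mask, every position sees every other
bidir_attn = torch.softmax(scores, dim=-1)

print(causal_mask.int())
print(causal_attn)
```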
Why Does This Work?
Training with next-token prediction on large corpora forces the model to:
- Learn syntax and grammar — token sequences must be grammatically plausible
- Learn factual knowledge — predicting "The capital of France is ___" requires knowing "Paris"
- Learn reasoning patterns — math or logic examples appear in text and must be predicted correctly
- Learn long-range dependencies — the Transformer's attention lets each prediction attend to all prior tokens
This is why a model pre-trained purely on NTP can then be fine-tuned for specific tasks with relatively few examples.
Implementing the Next-Token Prediction Loss from Scratch
Let us implement the NTP loss manually to build intuition before using high-level trainers. If you are a beginner, you can skip this part and go directly to the Hugging Face Transformers section below.
Minimal example with pure PyTorch
```python
import torch
import torch.nn.functional as F

# Toy example
# Suppose we have a tiny vocabulary of 5 tokens and a sequence of length 4
# tokens: [2, 0, 3, 1] → input = [2, 0, 3], target = [0, 3, 1]
vocab_size = 5
seq_len = 3  # we predict 3 positions

# Simulated logits from the model (shape: [seq_len, vocab_size])
torch.manual_seed(42)
logits = torch.randn(seq_len, vocab_size)

# The correct next tokens for each position
targets = torch.tensor([0, 3, 1])  # shape: [seq_len]

# Cross-entropy loss: equivalent to -log P(correct token)
# F.cross_entropy applies log-softmax internally
loss_per_token = F.cross_entropy(logits, targets, reduction="none")
print("Per-token losses:", loss_per_token)

loss = loss_per_token.mean()
print(f"Average NTP loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
```
The output is deterministic because of torch.manual_seed(42). Rerunning the script prints the same per-token losses, average loss, and perplexity.
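To confirm the comment that F.cross_entropy is just \(-\log P(\text{correct token})\), the sketch below recomputes the per-token losses from log_softmax and compares them with the built-in function. It uses the same toy shapes as above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
logits = torch.randn(3, 5)         # 3 positions, 5-token vocabulary
targets = torch.tensor([0, 3, 1])

# Manual version: log-softmax over the vocabulary, then pick each target's log-prob
log_probs = F.log_softmax(logits, dim=-1)                 # [3, 5]
manual = -log_probs[torch.arange(len(targets)), targets]  # [3]

builtin = F.cross_entropy(logits, targets, reduction="none")
print(torch.allclose(manual, builtin))  # True
```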
Shift-by-one: input vs. target in practice
The key implementation detail: the target at position \(t\) is the input at position \(t+1\). This is done by shifting the token sequence by one.
```python
from typing import Optional

import torch
import torch.nn.functional as F


def ntp_loss(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    pad_token_id: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute causal language modeling (NTP) loss.

    Args:
        logits: Model output, shape [batch, seq_len, vocab_size]
        input_ids: Token IDs, shape [batch, seq_len]
        pad_token_id: If given, target positions holding this ID are ignored

    Returns:
        Scalar loss (mean cross-entropy over all non-padding target positions)
    """
    # Shift: predict position t+1 using logits at position t
    # logits at positions 0..T-2 should predict tokens at positions 1..T-1
    shift_logits = logits[:, :-1, :].contiguous()  # [B, T-1, V]
    shift_labels = input_ids[:, 1:].contiguous()   # [B, T-1]

    # Mark padding targets with -100, which F.cross_entropy ignores
    if pad_token_id is not None:
        shift_labels = shift_labels.masked_fill(shift_labels == pad_token_id, -100)

    # Flatten batch and time dimensions for cross_entropy
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),  # [B*(T-1), V]
        shift_labels.view(-1),                          # [B*(T-1)]
        ignore_index=-100,
    )
    return loss


# Demo with a batch of 2 sequences of length 6
torch.manual_seed(0)
B, T, V = 2, 6, 32000  # batch, seq_len, vocab_size
dummy_logits = torch.randn(B, T, V)
dummy_input_ids = torch.randint(0, V, (B, T))

loss = ntp_loss(dummy_logits, dummy_input_ids)
# Random N(0, 1) logits give roughly ln(32000) + 0.5 ≈ 10.9 on average;
# a perfectly uniform predictor would give exactly ln(32000) ≈ 10.37
print(f"NTP loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.1f}")
```
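The slicing in ntp_loss is easy to get wrong by one position. As a sanity check, the sketch below computes the same loss two ways on toy sizes with random logits: once with an explicit loop over positions, and once with the vectorized shift. It then compares the results.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 6, 50
logits = torch.randn(B, T, V)
input_ids = torch.randint(0, V, (B, T))

# Explicit loop: logits at position t are scored against the token at t + 1
losses = []
for b in range(B):
    for t in range(T - 1):
        log_probs = F.log_softmax(logits[b, t], dim=-1)
        losses.append(-log_probs[input_ids[b, t + 1]])
loop_loss = torch.stack(losses).mean()

# Vectorized shift-by-one: drop the last logit and the first token
vec_loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, V),
    input_ids[:, 1:].reshape(-1),
)
print(loop_loss.item(), vec_loss.item())  # identical up to float rounding
```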
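What the Gradient Does
During backpropagation, consider the per-token loss \(\ell_t = -\log P_\theta(x_t \mid x_{<t})\). Its gradient with respect to the logit \(z_{t,v}\) is:
\[
\frac{\partial \ell_t}{\partial z_{t,v}} = P_\theta(v \mid x_{<t}) - \mathbb{1}[v = x_t]
\]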
This means:
- For the correct token \(v = x_t\): the gradient is \(P_\theta(x_t \mid x_{<t}) - 1\), which is negative → gradient descent pushes this logit up
- For every other token \(v \neq x_t\): the gradient is \(P_\theta(v \mid x_{<t})\), which is positive → gradient descent pushes those logits down
The model learns by repeatedly increasing the probability of observed tokens. At the same time, it decreases the probability of all other tokens, each in proportion to the probability the model currently assigns it.
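The formula is easy to verify with autograd. The sketch below uses arbitrary random logits for a single position and an arbitrary target index. It backpropagates the cross-entropy and compares the result with softmax minus a one-hot vector.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 6
z = torch.randn(V, requires_grad=True)  # logits z_t for one position
target = torch.tensor(2)                # observed next token x_t (arbitrary choice)

loss = F.cross_entropy(z.unsqueeze(0), target.unsqueeze(0))
loss.backward()

# Analytic gradient: softmax(z) - one_hot(x_t)
expected = torch.softmax(z.detach(), dim=-1) - F.one_hot(target, V).float()
print(z.grad)
print(torch.allclose(z.grad, expected, atol=1e-6))  # True
```

Only the target entry of the gradient is negative. The entries sum to zero because the probabilities sum to one.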
Training NTP with Hugging Face Transformers
The manual PyTorch code above builds intuition, but in practice we use the Hugging Face Trainer API. This section walks through a complete pipeline: load a model, prepare the data, configure and run training, generate text, and save the result.
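Step 1: Install dependencies
Install the required packages with pip install torch transformers datasets accelerate. Recent versions of transformers need accelerate for the Trainer API.
Step 2: Load a pre-trained model and tokenizer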
We start from a pre-trained GPT-2 checkpoint. Even when the goal is continued pre-training on domain-specific text, initializing from an existing checkpoint is much cheaper than training from scratch.
```python
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

model_name = "gpt2"  # 124M parameters
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 has no padding token; reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Vocab size: {tokenizer.vocab_size}")
```
Training from scratch
To train a randomly initialized model instead (as done in the workshop notebook), replace from_pretrained with a fresh config:
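Here is a minimal sketch, assuming you want the standard GPT-2 small shape. GPT2Config() with default arguments describes the same architecture as the gpt2 checkpoint. Passing it to GPT2LMHeadModel builds the model with random weights, and nothing is downloaded.

```python
from transformers import GPT2Config, GPT2LMHeadModel

# GPT2Config() with default arguments describes the GPT-2 small architecture
# (12 layers, 12 heads, 768-dim hidden states, 50,257-token vocabulary)
config = GPT2Config()

# Instantiating the class directly gives randomly initialized weights
model = GPT2LMHeadModel(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")
```

The tokenizer is still loaded with AutoTokenizer.from_pretrained, and the model still needs model.to(device) as in Step 2. A randomly initialized model needs far more training steps than a pre-trained one before it produces fluent text.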
Step 3: Prepare the dataset
NTP training requires long, contiguous chunks of tokens. The standard recipe is:
- Tokenize every document
- Concatenate all token IDs into one long stream
- Slice the stream into fixed-length blocks
```python
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
BLOCK_SIZE = 256


def tokenize_and_chunk(examples):
    # Tokenize every document in the batch (no truncation: we chunk ourselves)
    tokenized = tokenizer(examples["text"], truncation=False)

    # Concatenate all token IDs into one long stream
    all_ids = []
    for ids in tokenized["input_ids"]:
        all_ids.extend(ids)

    # Slice into full blocks; the tail shorter than BLOCK_SIZE is dropped
    # (this happens once per map() batch of 1000 rows)
    total_length = (len(all_ids) // BLOCK_SIZE) * BLOCK_SIZE
    chunks = [
        all_ids[i : i + BLOCK_SIZE]
        for i in range(0, total_length, BLOCK_SIZE)
    ]
    return {"input_ids": chunks}


lm_dataset = dataset.map(
    tokenize_and_chunk,
    batched=True,
    remove_columns=dataset.column_names,
    batch_size=1000,
)
lm_dataset.set_format("torch")
print(f"Training chunks: {len(lm_dataset)} (each {BLOCK_SIZE} tokens)")
```
Why concatenate-then-chunk? Documents vary in length. If we padded each document to the block size individually, short documents would fill much of each batch with padding, wasting computation. Documents longer than the block would also have to be truncated or split. Concatenating documents into a continuous stream and slicing it into equal-length blocks means every position holds a real token.
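A toy comparison makes the padding argument concrete. The sketch below uses four hypothetical tokenized documents and a block size of 8, both arbitrary choices. It counts the wasted or lost tokens under each strategy.

```python
# Hypothetical tokenized documents of uneven length (token IDs are arbitrary)
docs = [list(range(10)), list(range(3)), list(range(7)), list(range(12))]
BLOCK_SIZE = 8
PAD = -1  # placeholder pad ID for this toy example

# Option A: pad (or truncate) each document to BLOCK_SIZE on its own
padded = [(d + [PAD] * BLOCK_SIZE)[:BLOCK_SIZE] for d in docs]
n_pad = sum(row.count(PAD) for row in padded)
n_truncated = sum(max(0, len(d) - BLOCK_SIZE) for d in docs)

# Option B: concatenate everything, then slice into full blocks
stream = [tok for d in docs for tok in d]
n_full = len(stream) // BLOCK_SIZE
chunks = [stream[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE] for i in range(n_full)]
n_dropped = len(stream) - n_full * BLOCK_SIZE

print(f"padding: {n_pad} pad tokens, {n_truncated} tokens truncated")
print(f"chunking: {len(chunks)} blocks, 0 pad tokens, {n_dropped} tail tokens dropped")
```

Here padding wastes 6 of 32 positions and throws away 6 real tokens. Chunking uses all 32 tokens in 4 full blocks.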
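Step 4: Data collator
DataCollatorForLanguageModeling with mlm=False does not shift anything itself. It copies input_ids into labels and sets any padding positions to -100 so they are ignored. The shift-by-one happens inside the model's forward pass, which compares the logits at position \(t\) with the label at position \(t+1\). Because Step 2 sets pad_token to eos_token, any EOS tokens in the data would be masked the same way.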
```python
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # causal LM, not masked LM
)
```
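To see that the shift happens inside the model rather than in the collator, the sketch below builds a tiny randomly initialized GPT-2. Its sizes are arbitrary assumptions chosen to keep the check fast, and nothing is downloaded. The sketch passes labels=input_ids unshifted, then compares the returned loss with a manual shift-by-one cross-entropy.

```python
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)

# Tiny randomly initialized GPT-2; sizes are arbitrary, chosen only to keep the check fast
tiny_config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2)
tiny_model = GPT2LMHeadModel(tiny_config).eval()

toy_ids = torch.randint(0, tiny_config.vocab_size, (2, 10))

with torch.no_grad():
    # Unshifted labels, as the collator produces them when there is no padding
    out = tiny_model(input_ids=toy_ids, labels=toy_ids)

    # Manual shift-by-one cross-entropy on the same logits
    manual_loss = F.cross_entropy(
        out.logits[:, :-1, :].reshape(-1, tiny_config.vocab_size),
        toy_ids[:, 1:].reshape(-1),
    )

print(out.loss.item(), manual_loss.item())  # the two values match
```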
Step 5: Configure training
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="gpt2-ntp",
    max_steps=500,
    per_device_train_batch_size=8,
    learning_rate=5e-4,
    warmup_steps=100,
    logging_steps=10,
    fp16=torch.cuda.is_available(),
    save_strategy="no",
    report_to="none",
    seed=42,
)
```
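| Argument | Purpose |
|---|---|
| max_steps | Total number of optimizer updates. Use num_train_epochs instead for full-epoch training. |
| learning_rate | Peak learning rate, reached at the end of warmup; the default linear scheduler then decays it to 0. 5e-4 is on the high side for continued pre-training from a checkpoint, where values closer to 5e-5 (the TrainingArguments default) are common. |
| warmup_steps | Linear warmup from 0 avoids large early updates that can destabilize training. |
| fp16 | Mixed-precision training: usually faster and lighter on memory on GPUs with Tensor Cores, typically with little or no loss in quality. |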
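To see what learning_rate and warmup_steps do together, the sketch below reproduces the Trainer's default linear schedule (lr_scheduler_type="linear"). It uses transformers' get_linear_schedule_with_warmup with the values from the configuration above. The one-parameter optimizer is a dummy, there only so the scheduler has something to drive.

```python
import torch
from transformers import get_linear_schedule_with_warmup

PEAK_LR, WARMUP, TOTAL = 5e-4, 100, 500

# Dummy parameter and optimizer; no real training happens here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=PEAK_LR)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=WARMUP, num_training_steps=TOTAL
)

lrs = []
for step in range(TOTAL + 1):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate in effect at this step
    optimizer.step()   # no gradients, so parameters are untouched
    scheduler.step()

for step in (0, 50, 100, 300, 500):
    print(step, lrs[step])
```

The learning rate climbs linearly from 0 to 5e-4 over the first 100 steps, then falls linearly back to 0 at step 500.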
Step 6: Train
```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset,
    data_collator=data_collator,
)
trainer.train()
```
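The training loss typically drops quickly at first and then decreases more slowly as the model captures longer-range dependencies. Starting from the pre-trained gpt2 checkpoint, the initial loss is already far lower than for a randomly initialized model, so the curve is flatter than when training from scratch.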
Step 7: Generate text
After training, test the model with autoregressive generation:
```python
def generate(prompt, max_new_tokens=100, temperature=0.7):
    # Turn off dropout; after trainer.train() the model may still be in training mode
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt
    return tokenizer.decode(
        output[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )


prompts = [
    "The patient presented with",
    "Recent studies have shown that",
    "In this paper, we propose",
]
for p in prompts:
    print(f"Prompt: '{p}'")
    print(f"Output: {generate(p)}\n")
```
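The temperature argument rescales the logits before sampling. With do_sample=True, each token is drawn from \(\mathrm{softmax}(\mathbf{z}_t / \tau)\), where \(\tau\) is the temperature. Values below 1 sharpen the distribution toward the most likely token, and values above 1 flatten it. Here is a small sketch with made-up logits:

```python
import torch

# Hypothetical next-token logits over a 4-token vocabulary
logits = torch.tensor([3.0, 2.0, 1.0, 0.0])

probs_by_temperature = {}
for temperature in (0.5, 0.7, 1.0, 2.0):
    # Temperature scaling: divide the logits by tau before the softmax
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_by_temperature[temperature] = probs
    print(temperature, [round(p, 3) for p in probs.tolist()])
```

The top token's probability shrinks steadily as the temperature rises, which is why 0.7 gives more focused text than 1.0.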
Step 8: Save and reload
```python
model.save_pretrained("gpt2-ntp")
tokenizer.save_pretrained("gpt2-ntp")

reloaded = GPT2LMHeadModel.from_pretrained("gpt2-ntp").to(device)
```
The saved model can later serve as the starting point for supervised fine-tuning or parameter-efficient fine-tuning.
References
- Vaswani et al., Attention Is All You Need
- Radford et al., Language Models are Unsupervised Multitask Learners
- Devlin et al., BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

