Learn AI Series (#71) - Text Generation Techniques

What will I learn
- You will learn how language models actually generate text, one token at a time;
- greedy decoding, beam search, and sampling strategies;
- temperature, top-k, and top-p (nucleus sampling) -- what each parameter controls and why;
- repetition penalty and frequency penalty for avoiding loops;
- speculative decoding: making generation faster without losing quality;
- constrained generation and grammar-guided output for structured responses.
Requirements
- A working modern computer running macOS, Windows or Ubuntu;
- An installed Python 3(.11+) distribution;
- The ambition to learn AI and machine learning.
Difficulty
- Beginner
Curriculum (of the Learn AI Series):
- Learn AI Series (#1) - What Machine Learning Actually Is
- Learn AI Series (#2) - Setting Up Your AI Workbench - Python and NumPy
- Learn AI Series (#3) - Your Data Is Just Numbers - How Machines See the World
- Learn AI Series (#4) - Your First Prediction - No Math, Just Intuition
- Learn AI Series (#5) - Patterns in Data - What "Learning" Actually Looks Like
- Learn AI Series (#6) - From Intuition to Math - Why We Need Formulas
- Learn AI Series (#7) - The Training Loop - See It Work Step by Step
- Learn AI Series (#8) - The Math You Actually Need (Part 1) - Linear Algebra
- Learn AI Series (#9) - The Math You Actually Need (Part 2) - Calculus and Probability
- Learn AI Series (#10) - Your First ML Model - Linear Regression From Scratch
- Learn AI Series (#11) - Making Linear Regression Real
- Learn AI Series (#12) - Classification - Logistic Regression From Scratch
- Learn AI Series (#13) - Evaluation - How to Know If Your Model Actually Works
- Learn AI Series (#14) - Data Preparation - The 80% Nobody Talks About
- Learn AI Series (#15) - Feature Engineering and Selection
- Learn AI Series (#16) - Scikit-Learn - The Standard Library of ML
- Learn AI Series (#17) - Decision Trees - How Machines Make Decisions
- Learn AI Series (#18) - Random Forests - Wisdom of Crowds
- Learn AI Series (#19) - Gradient Boosting - The Kaggle Champion
- Learn AI Series (#20) - Support Vector Machines - Drawing the Perfect Boundary
- Learn AI Series (#21) - Mini Project - Predicting Crypto Market Regimes
- Learn AI Series (#22) - K-Means Clustering - Finding Groups
- Learn AI Series (#23) - Advanced Clustering - Beyond K-Means
- Learn AI Series (#24) - Dimensionality Reduction - PCA
- Learn AI Series (#25) - Advanced Dimensionality Reduction - t-SNE and UMAP
- Learn AI Series (#26) - Anomaly Detection - Finding What Doesn't Belong
- Learn AI Series (#27) - Recommendation Systems - "Users Like You Also Liked..."
- Learn AI Series (#28) - Time Series Fundamentals - When Order Matters
- Learn AI Series (#29) - Time Series Forecasting - Predicting What Comes Next
- Learn AI Series (#30) - Natural Language Processing - Text as Data
- Learn AI Series (#31) - Word Embeddings - Meaning in Numbers
- Learn AI Series (#32) - Bayesian Methods - Thinking in Probabilities
- Learn AI Series (#33) - Ensemble Methods Deep Dive - Stacking and Blending
- Learn AI Series (#34) - ML Engineering - From Notebook to Production
- Learn AI Series (#35) - Data Ethics and Bias in ML
- Learn AI Series (#36) - Mini Project - Complete ML Pipeline
- Learn AI Series (#37) - The Perceptron - Where It All Started
- Learn AI Series (#38) - Neural Networks From Scratch - Forward Pass
- Learn AI Series (#39) - Neural Networks From Scratch - Backpropagation
- Learn AI Series (#40) - Training Neural Networks - Practical Challenges
- Learn AI Series (#41) - Optimization Algorithms - SGD, Momentum, Adam
- Learn AI Series (#42) - PyTorch Fundamentals - Tensors and Autograd
- Learn AI Series (#43) - PyTorch Data and Training
- Learn AI Series (#44) - PyTorch nn.Module - Building Real Networks
- Learn AI Series (#45) - Convolutional Neural Networks - Theory
- Learn AI Series (#46) - CNNs in Practice - Classic to Modern Architectures
- Learn AI Series (#47) - CNN Applications - Detection, Segmentation, Style Transfer
- Learn AI Series (#48) - Recurrent Neural Networks - Sequences
- Learn AI Series (#49) - LSTM and GRU - Solving the Memory Problem
- Learn AI Series (#50) - Sequence-to-Sequence Models
- Learn AI Series (#51) - Attention Mechanisms
- Learn AI Series (#52) - The Transformer Architecture (Part 1)
- Learn AI Series (#53) - The Transformer Architecture (Part 2)
- Learn AI Series (#54) - Vision Transformers
- Learn AI Series (#55) - Generative Adversarial Networks
- Learn AI Series (#56) - Mini Project - Building a Transformer From Scratch
- Learn AI Series (#57) - Language Modeling - Predicting the Next Word
- Learn AI Series (#58) - GPT Architecture - Decoder-Only Transformers
- Learn AI Series (#59) - BERT and Encoder Models
- Learn AI Series (#60) - Training Large Language Models
- Learn AI Series (#61) - Instruction Tuning and Alignment
- Learn AI Series (#62) - Prompt Engineering - Getting the Most from LLMs
- Learn AI Series (#63) - Embeddings and Vector Search
- Learn AI Series (#64) - Retrieval-Augmented Generation (RAG) - Basics
- Learn AI Series (#65) - RAG - Advanced Techniques
- Learn AI Series (#66) - Working with LLM APIs
- Learn AI Series (#67) - Building AI Agents (Part 1) - Foundations
- Learn AI Series (#68) - Building AI Agents (Part 2) - Advanced Patterns
- Learn AI Series (#69) - Fine-Tuning Language Models
- Learn AI Series (#70) - Running Local Models
- Learn AI Series (#71) - Text Generation Techniques (this post)
Solutions to Episode #70 Exercises
Exercise 1: Model memory calculator.
import math
# Standard architecture ratios
ARCH_TABLE = {
    7: {"layers": 32, "dim": 4096},
    13: {"layers": 40, "dim": 5120},
    70: {"layers": 80, "dim": 8192},
}
def estimate_memory(num_params_b, quant_bits, context_length,
                    batch_size):
    """Estimate total GPU memory for inference."""
    # Model weights
    weight_bytes = num_params_b * 1e9 * quant_bits / 8
    weight_gb = weight_bytes / 1e9
    # KV cache: pick closest architecture
    closest = min(ARCH_TABLE.keys(),
                  key=lambda k: abs(k - num_params_b))
    arch = ARCH_TABLE[closest]
    n_layers = arch["layers"]
    d_model = arch["dim"]
    # KV cache: 2 (K+V) * layers * dim * seq_len * batch * 2 bytes (FP16)
    kv_bytes = (2 * n_layers * d_model * context_length
                * batch_size * 2)
    kv_gb = kv_bytes / 1e9
    # Activation overhead (~10% of model weights)
    act_gb = weight_gb * 0.10
    total_gb = weight_gb + kv_gb + act_gb
    print(f"Memory estimate: {num_params_b}B at {quant_bits}-bit, "
          f"ctx={context_length}, batch={batch_size}")
    print(f" Model weights: {weight_gb:>7.2f} GB")
    print(f" KV cache: {kv_gb:>7.2f} GB "
          f"({n_layers} layers, dim={d_model})")
    print(f" Activations: {act_gb:>7.2f} GB")
    print(f" TOTAL: {total_gb:>7.2f} GB")
    return total_gb
def fits_in_gpu(total_gb, vram_gb):
    """Check if the config fits in GPU memory."""
    if total_gb <= vram_gb:
        print(f" -> Fits in {vram_gb} GB VRAM "
              f"({vram_gb - total_gb:.1f} GB headroom)")
        return True
    else:
        overshoot = total_gb - vram_gb
        print(f" -> Does NOT fit in {vram_gb} GB VRAM "
              f"(need {overshoot:.1f} GB more)")
        if overshoot < total_gb * 0.2:
            print(" Try: shorter context or lower quantization")
        else:
            print(" Try: smaller model or aggressive quantization")
        return False
# Test cases
print("=" * 55)
configs = [
    (7, 4, 4096, 1),
    (13, 4, 2048, 1),
    (70, 4, 4096, 1),
]
for params, bits, ctx, batch in configs:
    total = estimate_memory(params, bits, ctx, batch)
    fits_in_gpu(total, 24)
    print()
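The key thing to notice is how the KV cache scales with context length. Under this calculator's formula (full multi-head attention, FP16 cache), a 70B model needs about 2.6 MB of KV cache per token, so going from 4096 to 8192 context adds roughly 10.7 GB just for the key-value pairs. That's why "out of memory" errors often appear when you increase the context window, not when you change the model. The model weights stay constant; it's the KV cache that grows linearly with sequence length. (Models that use grouped-query attention store fewer key/value heads, so their cache is proportionally smaller.)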
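To see where the per-token cost comes from, here's the KV-cache term of the calculator pulled out as its own function. It's a sketch that assumes the same formula as `estimate_memory` (K and V for every layer at full model width, 2 bytes per value) and the 70B row of `ARCH_TABLE`; `kv_cache_gb` is a name introduced here, not a library function.

```python
def kv_cache_gb(n_layers, d_model, context_length, batch_size=1,
                bytes_per_value=2):
    """KV cache size in GB using the calculator's formula:
    2 (K and V) * layers * d_model * tokens * batch * bytes per value."""
    return (2 * n_layers * d_model * context_length * batch_size
            * bytes_per_value) / 1e9

# 70B-class row from ARCH_TABLE: 80 layers, d_model = 8192
per_token_mb = kv_cache_gb(80, 8192, 1) * 1000
delta_gb = kv_cache_gb(80, 8192, 8192) - kv_cache_gb(80, 8192, 4096)
print(f"KV cache per token:        {per_token_mb:.2f} MB")
print(f"4096 -> 8192 context adds: {delta_gb:.1f} GB")
```

Because the cost is a fixed amount per token, doubling the context doubles the cache, no matter how aggressively the weights themselves were quantized.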
Exercise 2: Ollama model manager (simulated).
import random
class ModelManager:
    """Simulated local model manager."""
    def __init__(self):
        self.models = {}
        self.usage_log = []
    def pull(self, name, size_gb, quant, capabilities):
        self.models[name] = {
            "size_gb": size_gb,
            "quant": quant,
            "capabilities": capabilities,
        }
        print(f"Pulled {name} ({size_gb:.1f} GB, {quant})")
    def remove(self, name):
        if name in self.models:
            freed = self.models[name]["size_gb"]
            del self.models[name]
            print(f"Removed {name} (freed {freed:.1f} GB)")
        else:
            print(f"{name} not found")
    def list_models(self):
        total = sum(m["size_gb"] for m in self.models.values())
        print(f"\nInstalled models ({total:.1f} GB total):")
        for name, info in sorted(self.models.items()):
            caps = ", ".join(info["capabilities"])
            print(f" {name:<22} {info['size_gb']:>5.1f} GB "
                  f"[{caps}]")
    def find_best(self, task):
        """Pick smallest model that supports the task."""
        candidates = [
            (name, info) for name, info in self.models.items()
            if task in info["capabilities"]
        ]
        if not candidates:
            return None
        candidates.sort(key=lambda x: x[1]["size_gb"])
        best_name = candidates[0][0]
        self.usage_log.append({"model": best_name, "task": task})
        return best_name
    def usage_report(self):
        counts = {}
        for entry in self.usage_log:
            name = entry["model"]
            counts[name] = counts.get(name, 0) + 1
        print(f"\nUsage report ({len(self.usage_log)} queries):")
        for name, count in sorted(counts.items(),
                                  key=lambda x: -x[1]):
            pct = count / len(self.usage_log) * 100
            print(f" {name:<22} {count:>4} queries ({pct:.0f}%)")
# Set up
mgr = ModelManager()
mgr.pull("llama3.1:8b", 4.5, "Q4_K_M",
         ["chat", "reasoning", "summarization"])
mgr.pull("phi3:3.8b", 2.2, "Q4_K_M",
         ["chat", "classification"])
mgr.pull("codellama:7b", 3.8, "Q4_K_M",
         ["code", "chat"])
mgr.pull("nomic-embed", 0.3, "FP16",
         ["embedding"])
mgr.pull("mistral:7b", 4.1, "Q4_K_M",
         ["chat", "reasoning", "summarization", "classification"])
mgr.list_models()
# Simulate 50 queries
random.seed(42)
tasks = (["chat"] * 15 + ["code"] * 10 + ["classification"] * 10
         + ["embedding"] * 8 + ["summarization"] * 4
         + ["reasoning"] * 3)
random.shuffle(tasks)
for task in tasks:
    model = mgr.find_best(task)
mgr.usage_report()
Notice how find_best always picks the smallest model that can handle the task. That's the right default -- you save VRAM, get faster inference, and the quality difference between a 3.8B and a 7B model on straightforward classification is usually negligible. Only move to a bigger model when the smaller one fails your quality bar.
Exercise 3: Quantization quality simulator.
import numpy as np
def simulate_quantization(weights, bits):
    """Uniform quantization: map to 2^bits levels, dequantize back."""
    n_levels = 2 ** bits
    w_min = weights.min()
    w_max = weights.max()
    w_range = w_max - w_min
    if w_range == 0:
        return weights.copy(), 0.0, 0.0, float('inf')
    # Quantize: scale to [0, n_levels-1], round, scale back
    scaled = (weights - w_min) / w_range * (n_levels - 1)
    quantized_int = np.round(scaled).astype(np.int32)
    quantized_int = np.clip(quantized_int, 0, n_levels - 1)
    dequantized = quantized_int / (n_levels - 1) * w_range + w_min
    # Error metrics
    error = weights - dequantized
    mse = np.mean(error ** 2)
    max_err = np.max(np.abs(error))
    signal_power = np.mean(weights ** 2)
    snr_db = 10 * np.log10(signal_power / mse) if mse > 0 else float('inf')
    return dequantized, mse, max_err, snr_db
def awq_simulate(weights, bits, importance_scores):
    """AWQ-style: keep top 1% important weights at fp32."""
    threshold = np.percentile(importance_scores, 99)
    important_mask = importance_scores >= threshold
    result = np.zeros_like(weights)
    # Important weights stay at full precision
    result[important_mask] = weights[important_mask]
    # Rest gets quantized
    rest_mask = ~important_mask
    if rest_mask.any():
        dequant, _, _, _ = simulate_quantization(
            weights[rest_mask], bits)
        result[rest_mask] = dequant
    error = weights - result
    mse = np.mean(error ** 2)
    max_err = np.max(np.abs(error))
    signal_power = np.mean(weights ** 2)
    snr_db = (10 * np.log10(signal_power / mse)
              if mse > 0 else float('inf'))
    return result, mse, max_err, snr_db
# Generate realistic weights
np.random.seed(42)
weights = np.random.normal(0, 0.02, size=10_000).astype(np.float32)
# Uniform quantization comparison
print("Uniform Quantization Comparison")
print(f"{'Bits':>5} {'MSE':>12} {'Max Err':>10} {'SNR (dB)':>10}")
print("-" * 40)
for bits in [2, 3, 4, 5, 6, 8]:
    _, mse, max_err, snr = simulate_quantization(weights, bits)
    print(f"{bits:>5} {mse:>12.2e} {max_err:>10.6f} {snr:>10.1f}")
# AWQ comparison at 4-bit
print("\n4-bit: Uniform vs AWQ-simulated")
importance = np.abs(weights)  # simple importance heuristic
_, uni_mse, _, uni_snr = simulate_quantization(weights, 4)
_, awq_mse, _, awq_snr = awq_simulate(weights, 4, importance)
improvement = (uni_mse - awq_mse) / uni_mse * 100
print(f" Uniform: MSE={uni_mse:.2e}, SNR={uni_snr:.1f} dB")
print(f" AWQ-sim: MSE={awq_mse:.2e}, SNR={awq_snr:.1f} dB")
print(f" MSE improvement: {improvement:.1f}%")
The SNR column is the one to watch. Every extra bit of quantization precision adds roughly 6 dB of signal-to-noise ratio -- that's a well-known result from signal processing theory, and it holds almost exactly for neural network weights too. The AWQ trick of protecting the top 1% most important weights gives you a meaningful MSE reduction for essentially free (1% of weights at full precision is a negligible storage increase). It's clever engineering: don't treat all weights equally when they clearly aren't equally important.
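Where does "6 dB per bit" come from? Here's a back-of-the-envelope check for the min/max uniform quantizer used in `simulate_quantization`. It rests on the usual textbook approximation that the rounding error is spread evenly over one quantization step, which gets shakier at very low bit widths, so treat the numbers as a sketch rather than a measurement:

```python
import math

# Uniform quantizer over a fixed range R with 2**b levels: the step is
# R / (2**b - 1), and the rounding error is assumed roughly uniform over
# one step, so MSE ~ step**2 / 12. Adding one bit shrinks the step by the
# ratio below, which raises SNR by 20*log10 of that ratio.
gains = {}
for b in range(2, 8):
    ratio = (2 ** (b + 1) - 1) / (2 ** b - 1)
    gains[b] = 20 * math.log10(ratio)
    print(f"{b} -> {b + 1} bits: +{gains[b]:.2f} dB")
print(f"Large-b limit: 20*log10(2) = {20 * math.log10(2):.2f} dB")
```

The gain per bit is a little above 7 dB going from 2 to 3 bits and settles toward 6.02 dB as the bit width grows, which you can compare against the SNR column the simulation prints.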
On to today's episode
Here we go! In episode #57 we explored how language models predict the next token. In episode #58 we built the GPT architecture that powers those predictions. And last episode (#70) we learned how to run these models locally on our own hardware. But there's a critical piece we've been hand-waving over this entire time: once the model outputs a probability distribution over its vocabulary, how does it actually pick the next token?
This is NOT a trivial question. The same model, the same prompt, with different decoding settings can produce dry factual text, creative fiction, repetitive loops, or complete gibberish. When someone complains "this model is broken" -- half the time, the model is fine, it's the generation parameters that are wrong. Understanding what temperature, top-k, top-p, and repetition penalty actually do (mathematically, not just "turn it up for creativity") is essential knowledge for working with language models in any capacity.
So let's build every major decoding strategy from scratch ;-)
Greedy decoding: always pick the winner
The absolute simplest strategy. At each step, look at the probability distribution and pick the token with the highest probability. Done.
import torch
import torch.nn.functional as F
def greedy_decode(model, tokenizer, prompt, max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # last position
        # Always pick the highest probability token
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(input_ids[0])
Greedy decoding is fast and deterministic -- the same prompt always produces the same output. But it has a serious flaw: it tends to get stuck in loops. Once the model starts a phrase like "the most important thing is," the subsequent tokens keep having high individual probability given the pattern, so the model repeats similar structures over and over. The output reads like a broken record.
There's a deeper problem too. The locally optimal choice at each step doesn't necessarily lead to the globally best sequence. Imagine a sentence where picking the second-most-likely token at position 5 would lead to a much better overall sentence. Greedy decoding will never find it because it never looks beyond the immediate next token.
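Here's that failure on a tiny scale. The two-step "model" below is a hand-written, hypothetical probability table (`FIRST`, `SECOND` and `seq_prob` are made up for illustration): greedy commits to the more likely first word and ends up with a lower-probability sentence than exhaustive search finds.

```python
# Hypothetical two-step "model": probabilities for the first token, and
# for the second token given the first.
FIRST = {"The": 0.6, "A": 0.4}
SECOND = {"The": {"cat": 0.3, "dog": 0.3, "end": 0.4},
          "A": {"cat": 0.9, "dog": 0.1}}

def seq_prob(first, second):
    """Probability of a full two-token sequence (chain rule)."""
    return FIRST[first] * SECOND[first][second]

# Greedy: commit to the best first token, then the best follow-up
g1 = max(FIRST, key=FIRST.get)
g2 = max(SECOND[g1], key=SECOND[g1].get)

# Exhaustive: score every complete sequence and take the best
best = max(((a, b) for a in FIRST for b in SECOND[a]),
           key=lambda ab: seq_prob(*ab))

print("greedy:", g1, g2, round(seq_prob(g1, g2), 2))
print("best:  ", best[0], best[1], round(seq_prob(*best), 2))
```

Greedy gets "The end" at 0.24, while "A cat" scores 0.36: the second-ranked first token led to the better sentence. Exhaustive search is hopeless for real vocabularies and lengths, which is exactly the gap beam search tries to fill.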
This is closely related to the local-optima issue we saw with gradient descent back in episode #41 -- myopic decisions can trap you. The difference is that in gradient descent we had tools like momentum and learning rate schedules to escape local minima. For text generation, we need different tools entirely.
Beam search: exploring multiple paths at once
Beam search addresses the myopia problem by maintaining multiple candidate sequences (called beams) simultaneously. Instead of committing to one path, you explore several and keep the overall best ones:
def beam_search(model, tokenizer, prompt, num_beams=4,
                max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Each beam: (token_ids, cumulative_log_probability)
    beams = [(input_ids, 0.0)]
    for _ in range(max_tokens):
        all_candidates = []
        for seq, score in beams:
            # A beam that already ended in EOS is carried over unchanged
            if seq[0, -1].item() == tokenizer.eos_token_id:
                all_candidates.append((seq, score))
                continue
            with torch.no_grad():
                logits = model(seq).logits[:, -1, :]
            log_probs = F.log_softmax(logits, dim=-1)
            # For each beam, consider its top num_beams expansions
            top_log_probs, top_ids = torch.topk(
                log_probs, num_beams)
            for i in range(num_beams):
                new_seq = torch.cat(
                    [seq, top_ids[:, i:i+1]], dim=-1)
                new_score = score + top_log_probs[0, i].item()
                all_candidates.append((new_seq, new_score))
        # Keep only the best beams
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:num_beams]
        # Stop once every surviving beam has finished
        if all(seq[0, -1].item() == tokenizer.eos_token_id
               for seq, _ in beams):
            break
    best_seq = beams[0][0]
    return tokenizer.decode(best_seq[0])
With 4 beams, you're tracking 4 parallel hypotheses at every step. Each hypothesis gets expanded by its top 4 next tokens, giving 16 candidates, and you prune back down to 4. This usually finds higher-probability overall sequences than greedy decoding, though it's still a heuristic: a good path can be pruned early, so the single most probable sequence isn't guaranteed.
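Here's the same expand-then-prune loop on a hypothetical hand-written bigram table, small enough to trace by hand (`NEXT` and `toy_beam_search` are made up for illustration). With one beam it behaves exactly like greedy decoding; with two it keeps the lower-ranked first token alive long enough to discover the better continuation.

```python
import math

# Hypothetical toy "language model": next-token probabilities that depend
# only on the previous token (a bigram table).
NEXT = {
    "<s>": {"The": 0.6, "A": 0.4},
    "The": {"cat": 0.3, "dog": 0.3, "</s>": 0.4},
    "A": {"cat": 0.9, "dog": 0.1},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}

def toy_beam_search(num_beams, steps):
    beams = [(["<s>"], 0.0)]                  # (tokens, cumulative log prob)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            if seq[-1] == "</s>":             # finished beams carry over
                candidates.append((seq, score))
                continue
            # expand each beam by its num_beams most likely next tokens
            options = sorted(NEXT[seq[-1]].items(), key=lambda kv: -kv[1])
            for tok, p in options[:num_beams]:
                candidates.append((seq + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]        # prune back to num_beams
    return beams[0]

for k in (1, 2):
    seq, score = toy_beam_search(k, steps=3)
    print(k, " ".join(seq[1:]), round(math.exp(score), 2))
```

One beam returns "The </s>" with probability 0.24; two beams return "A cat </s>" with 0.36.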
Beam search is the standard for machine translation and text summarization -- tasks where accuracy matters more than creativity. Having said that, beam search output tends to be bland and generic. It optimizes for probability, and the most probable text is often the most boring. "The weather is nice today" has high probability in almost any context. For creative applications, you want something else entirely: controlled randomness.
Sampling with temperature: the creativity knob
Instead of always picking the most likely token, what if we sample from the probability distribution? Higher-probability tokens are still more likely to be selected, but lower-probability tokens get a chance too. This introduces variation -- run the same prompt twice and you get different outputs.
def sample_decode(model, tokenizer, prompt, temperature=1.0,
                  max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
        # Apply temperature scaling
        scaled_logits = logits / temperature
        # Convert to probabilities and sample
        probs = F.softmax(scaled_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(input_ids[0])
Temperature is the single most important generation parameter, and it's often misunderstood. Here's what it does mathematically: it divides the raw logits before the softmax. That one operation changes everything about the output distribution:
- Temperature = 1.0: the original distribution. Nothing changes.
- Temperature < 1.0 (say, 0.3): the logits get divided by a small number, making them larger in magnitude. After softmax, the peaks get even peakier. High-probability tokens dominate even more. Output is focused and predictable.
- Temperature > 1.0 (say, 1.5): the logits get divided by a number greater than 1, shrinking them toward zero. After softmax, the distribution becomes more uniform. Rare tokens get a bigger chance. Output is creative but potentially incoherent.
- Temperature -> 0: approaches greedy decoding. The highest-probability token takes over completely.
Let me make this concrete with actual numbers:
# Temperature effects on an example distribution
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
for temp in [0.1, 0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {probs.numpy().round(3)}")
# T=0.1: [1. 0. 0. 0.]              <- almost greedy
# T=0.5: [0.842 0.114 0.042 0.002]  <- focused
# T=1.0: [0.609 0.224 0.136 0.03 ]  <- balanced
# T=2.0: [0.434 0.263 0.205 0.097]  <- spread out
See the pattern? At T=0.1, the first token has essentially 100% of the probability mass -- you'll (almost) always pick it. At T=2.0, even the least likely token (0.097, about 9.7%) has a reasonable shot, up from about 3% at T=1.0. The range between 0.5 and 1.0 is where most practical applications live.
If you remember softmax from episode #38 (neural networks forward pass) and episode #51 (attention mechanisms), you already understand the math here. Temperature scaling is just a pre-processing step that modifies how "sharp" or "flat" the softmax output is. Same function, different input scale, dramatically different behavior ;-)
Top-k sampling: cutting the tail
Pure sampling has a problem even with reasonable temperature: extremely unlikely tokens still have a non-zero chance of being picked. A token with 0.001% probability will eventually appear if you generate enough text, and that one bad token can derail the entire sequence into nonsense.
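A quick way to see why "eventually" bites: if some set of unwanted tokens holds a fixed share of the probability at every step, the chance of hitting it at least once grows fast with length. This sketch assumes independent steps with the same tail mass each time, which real generation doesn't satisfy exactly, and the 2% tail is a made-up figure:

```python
def p_at_least_one(tail_mass, n_tokens):
    """Chance that at least one of n_tokens independent draws lands in a
    tail holding `tail_mass` of the probability at every step."""
    return 1 - (1 - tail_mass) ** n_tokens

print(f"{p_at_least_one(1e-5, 100_000):.2f}")  # one 0.001% token, 100k steps
print(f"{p_at_least_one(0.02, 200):.2f}")      # hypothetical 2% tail, 200 steps
```

A single 0.001% token shows up with about 63% probability over 100,000 tokens, and a tail that collectively holds 2% is hit at least once in a 200-token reply about 98% of the time.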
Top-k sampling restricts the candidate pool to only the k most probable tokens:
def top_k_sample(logits, k=50, temperature=1.0):
    scaled = logits / temperature
    # Find the top-k values and their indices
    top_k_values, top_k_indices = torch.topk(scaled, k)
    # Set everything else to -infinity (zero probability after softmax)
    filtered = torch.full_like(scaled, float('-inf'))
    filtered.scatter_(1, top_k_indices, top_k_values)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)
With k=50, only the 50 most likely next tokens are considered. Everything else is masked out completely -- zero chance of selection. This eliminates the low-probability noise while keeping diversity among the reasonable candidates.
But k=50 is a fixed number, and that's the problem. Consider two situations:
Situation A: The model is very confident. One token has 95% probability, and the remaining 49 tokens share the other 5%. Here k=50 allows 49 tokens the model thinks are almost certainly wrong.
Situation B: The model is uncertain. 200 tokens each have roughly 0.5% probability. Here k=50 cuts off 150 perfectly reasonable choices, artificially restricting the output.
A fixed k can't handle both situations well. We need the candidate pool to adapt to the model's confidence.
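The two situations can be put in numbers. This sketch builds both distributions over a hypothetical 1000-token vocabulary and reports how much probability mass a k=50 cut keeps (`top_k_report` is a helper introduced here, not a library function):

```python
import numpy as np

def top_k_report(probs, k=50):
    """Return (mass kept by top-k, kept mass outside the single best token)."""
    top = np.sort(probs)[::-1][:k]
    return float(top.sum()), float(top[1:].sum())

# Situation A: one dominant token, 49 long shots sharing the other 5%
situation_a = np.zeros(1000)
situation_a[0] = 0.95
situation_a[1:50] = 0.05 / 49
# Situation B: 200 equally plausible tokens at 0.5% each
situation_b = np.zeros(1000)
situation_b[:200] = 0.005

kept_a, long_shots_a = top_k_report(situation_a)
kept_b, _ = top_k_report(situation_b)
print(f"A: keeps {kept_a:.0%} of the mass, {long_shots_a:.0%} on long shots")
print(f"B: keeps {kept_b:.0%} of the mass, discards {1 - kept_b:.0%}")
```

In A, sampling after the k=50 cut picks one of the long shots about 1 step in 20. In B, 150 reasonable candidates (75% of the mass) are thrown away and the remaining 50 are renormalized.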
Top-p (nucleus) sampling: the adaptive solution
Top-p sampling (Holtzman et al., 2019) elegantly solves the fixed-k problem. Instead of picking a fixed number of candidates, it includes the smallest set of tokens whose cumulative probability reaches p:
def top_p_sample(logits, p=0.9, temperature=1.0):
    scaled = logits / temperature
    probs = F.softmax(scaled, dim=-1)
    # Sort by probability (highest first)
    sorted_probs, sorted_indices = torch.sort(
        probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Find where cumulative probability exceeds p
    # Keep the first token that pushes us over the threshold
    sorted_mask = cumulative_probs - sorted_probs > p
    sorted_probs[sorted_mask] = 0.0
    # Renormalize so they sum to 1 again
    sorted_probs = sorted_probs / sorted_probs.sum()
    # Sample from the filtered distribution
    token_idx = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices.gather(-1, token_idx)
With p=0.9: if one token has 95% probability, only that token is in the nucleus (the name comes from thinking of the high-probability core of the distribution). If 100 tokens each have ~1% probability, roughly the top 90 are included. The candidate pool automatically adapts to how confident the model is at each step.
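To make the adaptivity concrete, here's a small NumPy helper (a sketch, not from any library) that counts how many tokens survive the same keep rule `top_p_sample` uses, for a peaked distribution and a flat one. The flat case uses 64 tokens only because 1/64 is exact in binary floating point, which avoids rounding noise right at the 0.9 boundary.

```python
import numpy as np

def nucleus_size(probs, p=0.9):
    """How many tokens top-p keeps, using the same rule as top_p_sample:
    drop a token only if the mass ranked above it already exceeds p."""
    sorted_probs = np.sort(probs)[::-1]
    mass_before = np.cumsum(sorted_probs) - sorted_probs
    return int((mass_before <= p).sum())

confident = np.zeros(1000)      # hypothetical 1000-token vocabulary
confident[0] = 0.95
confident[1:] = 0.05 / 999
flat = np.full(64, 1 / 64)      # 64 equally likely tokens
print(nucleus_size(confident), nucleus_size(flat))
```

With the same p=0.9, the peaked distribution keeps a single candidate while the flat one keeps 58 of its 64 tokens.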
This is a really elegant idea when you think about it. The nucleus grows and shrinks based on what the model actually believes, not on an arbitrary fixed count. In practice, most LLM APIs expose both temperature and top-p, and a common starting point (temperature 0.7, top_p 0.9) produces coherent, slightly creative text. For factual tasks, lower both. For creative writing, raise temperature a bit.
def combined_sampling(logits, temperature=0.7, top_k=50,
                      top_p=0.9):
    """Full generation pipeline: temperature -> top-k -> top-p."""
    # Step 1: temperature scaling
    scaled = logits / temperature
    # Step 2: top-k filtering (optional, set k=0 to skip)
    if top_k > 0:
        top_k_vals, top_k_idx = torch.topk(scaled, top_k)
        filtered = torch.full_like(scaled, float('-inf'))
        filtered.scatter_(1, top_k_idx, top_k_vals)
    else:
        filtered = scaled
    # Step 3: top-p filtering
    probs = F.softmax(filtered, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = cumsum - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    # Step 4: sample
    token_idx = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, token_idx)
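Note the order: temperature first (scales the logits), then top-k (removes improbable candidates), then top-p (adapts the remaining pool). Top-k keeps the same tokens whether it runs before or after temperature, because dividing every logit by the same positive number doesn't change their ranking. Top-p is different: temperature reshapes the probabilities, so a hotter distribution gives a bigger nucleus. Temperature -> top-k -> top-p is a common order (it's the one Hugging Face transformers uses), but not a universal one, so it's worth checking how your inference engine orders its samplers.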
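A small check of why temperature's position matters for top-p but not for top-k. The logits are hypothetical, and `softmax`, `top_k_ids` and `top_p_size` are helpers written here in NumPy (the nucleus rule is the same one `top_p_sample` uses):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def top_k_ids(logits, k):
    """Indices of the k largest logits, as a set."""
    return set(np.argsort(logits)[::-1][:k].tolist())

def top_p_size(probs, p=0.9):
    """How many tokens the nucleus keeps (same rule as top_p_sample)."""
    sorted_probs = np.sort(probs)[::-1]
    mass_before = np.cumsum(sorted_probs) - sorted_probs
    return int((mass_before <= p).sum())

logits = np.array([3.0, 2.0, 1.0, 0.5, 0.0, -1.0])  # hypothetical logits
for temp in (0.5, 2.0):
    scaled = logits / temp
    print(f"T={temp}: top-3 ids {sorted(top_k_ids(scaled, 3))}, "
          f"nucleus size {top_p_size(softmax(scaled))}")
```

The top-3 set is [0, 1, 2] at both temperatures, but the nucleus holds 2 tokens at T=0.5 and 5 at T=2.0.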
Repetition penalty: breaking the loop
Even with good sampling settings, models tend to fall into repetitive patterns. You've probably seen this: the model starts saying "the key advantage is" and then three paragraphs later it says "the key advantage is" again and again with slight variations. Repetition penalty (Keskar et al., 2019) discourages tokens that have already appeared in the generated text:
def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Penalize tokens that already appeared in the output.
    generated_ids: list of int token ids generated so far."""
    for token_id in set(generated_ids):
        if logits[0, token_id] > 0:
            logits[0, token_id] /= penalty  # reduce positive logit
        else:
            logits[0, token_id] *= penalty  # make negative more negative
    return logits
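A penalty of 1.0 means no change. With a penalty of 1.2, a positive logit for an already-seen token is divided by 1.2 (about a 17% reduction) and a negative one is multiplied by 1.2 (20% more negative). The asymmetric handling is important: dividing a negative logit would move it toward zero and make the token more likely, so each sign gets the operation that lowers it, and in both cases the token becomes less likely after softmax.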
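Here's that rule applied to one positive and one negative logit, with the resulting probabilities. The logits are hypothetical, and `repetition_penalty` is a NumPy rewrite of `apply_repetition_penalty` for a 1-D array that returns a copy instead of editing in place:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def repetition_penalty(logits, seen_ids, penalty=1.2):
    """CTRL-style penalty on a 1-D array (returns a copy): divide positive
    logits of seen tokens, multiply negative ones."""
    out = logits.copy()
    for t in set(seen_ids):
        out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
    return out

logits = np.array([3.0, 2.5, -1.0, 0.5])   # hypothetical logits
results = {}
for t in (0, 2):                            # penalize one seen token at a time
    new = repetition_penalty(logits, [t])
    results[t] = (softmax(logits)[t], softmax(new)[t])
    print(f"token {t}: logit {logits[t]:+.2f} -> {new[t]:+.2f}, "
          f"prob {results[t][0]:.3f} -> {results[t][1]:.3f}")
```

Token 0 goes from 3.0 to 2.5 (probability 0.586 to 0.462), and token 2 from -1.0 to -1.2 (0.011 to 0.009). The same penalty factor has a much larger effect on a token the model strongly favors.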
Values above 1.5 can make the output aggressively avoid any repetition, sometimes to the point of incoherence as the model runs out of natural word choices. Common words like "the", "is", "and" get penalized even though they should repeat -- they're function words, not content repetition.
Frequency penalty is a more graded variant. Instead of applying a flat penalty to any repeated token, it subtracts a penalty that scales with how many times the token has appeared:
def apply_frequency_penalty(logits, token_counts, penalty=0.5):
    """Penalty scales with how often each token has appeared."""
    for token_id, count in token_counts.items():
        logits[0, token_id] -= penalty * count
    return logits
def generate_with_frequency_penalty(model, tokenizer, prompt,
                                    penalty=0.5, max_tokens=100):
    """Full generation loop with frequency tracking."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    token_counts = {}
    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
        # Apply frequency-based penalty
        logits = apply_frequency_penalty(logits, token_counts,
                                         penalty)
        # Sample with temperature
        probs = F.softmax(logits / 0.7, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
        # Update frequency count
        tid = next_token.item()
        token_counts[tid] = token_counts.get(tid, 0) + 1
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return tokenizer.decode(input_ids[0])
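A token used once gets a small penalty; a token used ten times gets a penalty ten times larger. Compared to a flat repetition penalty, occasional reuse is barely affected while tokens that keep reappearing are pushed down progressively harder. The catch is that common function words are also the most frequent tokens, so they accumulate penalty fastest, which is why frequency penalties are usually kept small.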
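To see how the counts accumulate, here's the arithmetic of `apply_frequency_penalty` on a hypothetical word-level passage (real models count subword token ids, not words, so this is only an illustration):

```python
from collections import Counter

# Hypothetical word-level "tokens" from a generated passage
generated = "the cat sat on the mat and the cat slept".split()
token_counts = Counter(generated)

penalty = 0.5
# Same rule as apply_frequency_penalty: subtract penalty * count
penalties = {tok: penalty * n for tok, n in token_counts.items()}
for tok, amount in sorted(penalties.items(), key=lambda kv: -kv[1]):
    print(f"{tok:<6} seen {token_counts[tok]}x -> logit -{amount:.1f}")
```

"cat" (seen twice) loses 1.0 from its logit, but the function word "the" (seen three times) loses the most, 1.5.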
Speculative decoding: speed without sacrifice
Here's a different kind of problem. Standard autoregressive generation is inherently sequential -- each token requires a full forward pass through the model, and each pass depends on the previous token. For a 70B parameter model, generating 100 tokens means 100 sequential forward passes. That's slow.
Speculative decoding (Leviathan et al., 2022) is a genuinely clever trick. Use a smaller, faster draft model to cheaply propose several tokens ahead, then verify all of them with the large target model in a single forward pass:
The algorithm:
1. Draft model (fast, ~1B params) generates 5 tokens: [A, B, C, D, E]
2. Target model (slow, ~70B params) processes all 5 at once
3. Target model accepts [A, B, C] but rejects D
4. Keep [A, B, C] and sample position 4 from the target model (using a corrected distribution that removes the draft model's bias)
5. Repeat, drafting again from position 5
The key insight: for a small number of tokens N, verifying N tokens in parallel takes roughly the same time as generating 1 token. A transformer processes all positions simultaneously in one forward pass (remember the parallelism advantage from episode #52?). So if the draft model's guesses match the target model often enough -- and for a well-chosen draft model, they often match 60-80% of the time -- you can get a 2-3x speedup with an output distribution that is mathematically identical to the target model's. Not "similar" -- identical. The acceptance/rejection scheme guarantees this.
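The 2-3x figure can be sanity-checked with the expected-tokens formula from the Leviathan et al. paper. It assumes each drafted token is accepted independently with the same probability alpha, and it counts tokens per target forward pass only, ignoring the draft model's own cost, so real wall-clock speedups come out lower:

```python
def expected_tokens_per_pass(alpha, gamma):
    """Expected tokens gained per target forward pass with gamma drafted
    tokens, each accepted independently with probability alpha: the
    accepted prefix plus one corrected or bonus token from the target."""
    if alpha == 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.6, 0.7, 0.8):
    print(f"alpha={alpha}: {expected_tokens_per_pass(alpha, 5):.2f} "
          "tokens per target pass")
```

With 5 drafted tokens, acceptance rates of 0.6, 0.7 and 0.8 give about 2.38, 2.94 and 3.69 tokens per target pass.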
def speculative_decode(target_model, draft_model, tokenizer,
                       prompt, n_speculative=5, max_tokens=100):
    """Speculative decoding: draft proposes, target verifies."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = 0
    while generated < max_tokens:
        # Phase 1: draft model generates n_speculative tokens
        draft_ids = input_ids.clone()
        draft_probs_list = []
        for _ in range(n_speculative):
            with torch.no_grad():
                logits = draft_model(draft_ids).logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            draft_probs_list.append(probs)
            draft_ids = torch.cat([draft_ids, next_tok], dim=-1)
        # Phase 2: target model scores all positions at once
        with torch.no_grad():
            target_logits = target_model(draft_ids).logits
        # Phase 3: accept/reject each drafted token
        n_input = input_ids.shape[1]
        n_accepted = 0
        next_dist = None  # set to the residual distribution on rejection
        for i in range(n_speculative):
            pos = n_input + i
            # Logits at pos - 1 predict the token at pos
            target_probs = F.softmax(
                target_logits[:, pos - 1, :], dim=-1)
            draft_probs = draft_probs_list[i]
            draft_token = draft_ids[0, pos].item()
            draft_p = draft_probs[0, draft_token].item()
            target_p = target_probs[0, draft_token].item()
            # Accept with probability min(1, target_p / draft_p)
            if target_p >= draft_p:
                n_accepted += 1
            elif torch.rand(1).item() < target_p / draft_p:
                n_accepted += 1
            else:
                # Rejected: resample from max(0, target - draft),
                # renormalized, which keeps the output exact
                residual = torch.clamp(target_probs - draft_probs, min=0)
                next_dist = residual / residual.sum()
                break
        # Keep the accepted prefix
        input_ids = draft_ids[:, :n_input + n_accepted]
        if next_dist is None:
            # Every draft accepted: bonus token from the target's
            # prediction after the last drafted position
            next_dist = F.softmax(target_logits[:, -1, :], dim=-1)
        correction = torch.multinomial(next_dist, num_samples=1)
        input_ids = torch.cat([input_ids, correction], dim=-1)
        generated += n_accepted + 1
        if tokenizer.eos_token_id in input_ids[0, n_input:]:
            break
    return tokenizer.decode(input_ids[0])
The probabilistic acceptance (target_p / draft_p) is what makes this mathematically exact. If the draft model says token X has probability 0.3 and the target says 0.2, we accept with probability 0.2/0.3 = 67%. On a rejection, the replacement token is sampled from the residual distribution max(0, target - draft), renormalized, rather than from the raw target distribution. Together, these two rules guarantee that the final distribution matches the target model exactly, regardless of how bad the draft model is. A bad draft model just means more rejections (lower speedup), not worse output quality.
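The exactness claim is easy to check numerically. This sketch runs the single-position accept/resample rule on a made-up three-token vocabulary, with a draft distribution `q` deliberately different from the target `p`, and compares the sampled frequencies with `p`. It also checks the algebra directly: acceptances contribute min(p, q), rejections contribute max(0, p - q), and the two add back up to p.

```python
import numpy as np

def speculative_step(p, q, rng):
    """One draft-and-verify step for a single position."""
    x = rng.choice(len(q), p=q)                    # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):       # accept w.p. min(1, p/q)
        return x
    residual = np.maximum(p - q, 0.0)              # rejected: resample from
    return rng.choice(len(p), p=residual / residual.sum())  # max(0, p - q)

p = np.array([0.2, 0.5, 0.3])   # hypothetical target distribution
q = np.array([0.3, 0.3, 0.4])   # hypothetical draft distribution
rng = np.random.default_rng(0)
samples = [speculative_step(p, q, rng) for _ in range(50_000)]
empirical = np.bincount(samples, minlength=3) / len(samples)
# Exact output probability: min(p, q) from acceptances, plus max(0, p - q)
# from rejections (the rejection probability equals the residual's mass)
exact = np.minimum(p, q) + np.maximum(p - q, 0.0)
print("target:   ", p)
print("empirical:", empirical.round(3))
print("exact:    ", exact)
```

The draft here overweights tokens 0 and 2, yet the output frequencies land on the target's [0.2, 0.5, 0.3] up to sampling noise.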
Constrained generation: following the rules
Sometimes you don't just want plausible text -- you need the output to follow a specific format. Valid JSON. A number between 1 and 10. One of three predefined categories. A SQL query. Standard sampling makes no format guarantees whatsoever.
Grammar-guided generation solves this by restricting the model at each step to only produce tokens that are valid according to a formal grammar:
def grammar_constrained_generate(model, tokenizer, prompt,
                                 grammar, max_tokens=200):
    """Generate text that conforms to a formal grammar.
    `grammar` is any object providing initial_state(), allowed_tokens(),
    advance() and is_complete()."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    parser_state = grammar.initial_state()
    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
        # Ask the grammar which tokens are legal right now
        allowed_tokens = grammar.allowed_tokens(parser_state)
        # Mask out everything else
        mask = torch.full_like(logits, float('-inf'))
        mask[0, allowed_tokens] = 0
        constrained_logits = logits + mask
        # Sample from allowed tokens only
        probs = F.softmax(constrained_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        # Advance the grammar parser
        parser_state = grammar.advance(
            parser_state, next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if grammar.is_complete(parser_state):
            break
    return tokenizer.decode(input_ids[0])
The model still has creative freedom within the grammar's constraints -- it decides what values to use, how to structure the content, what words to pick. But the structure is guaranteed. Every opening brace gets a closing brace. Every key is a quoted string. Every number is actually a number.
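The "freedom within constraints" has a precise meaning. Setting the disallowed logits to -inf before the softmax gives exactly the model's original distribution, restricted to the legal tokens and renormalized. The model's relative preferences among legal tokens are untouched. Here is a quick plain-Python check with made-up logits for a hypothetical 5-token vocabulary:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def masked_softmax(logits, allowed):
    """Softmax after setting every disallowed logit to -inf."""
    masked = [x if i in allowed else float("-inf")
              for i, x in enumerate(logits)]
    return softmax(masked)

logits = [2.0, 1.0, 0.5, -1.0, 3.0]  # made-up scores
allowed = {1, 2, 3}                  # tokens the grammar permits at this step
constrained = masked_softmax(logits, allowed)

# Same result as renormalizing the unconstrained distribution over `allowed`
full = softmax(logits)
mass = sum(full[i] for i in allowed)
renormalized = [full[i] / mass if i in allowed else 0.0
                for i in range(len(logits))]
print(all(math.isclose(a, b, abs_tol=1e-12)
          for a, b in zip(constrained, renormalized)))  # True
print(constrained[0], constrained[4])                   # 0.0 0.0
```

Token 4 had the highest raw score, but it gets probability 0 because the grammar forbids it. Among tokens 1 to 3, the model's ranking is preserved.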
Tools like llama.cpp's GBNF grammars and Outlines (a Python library) implement this efficiently. You define a grammar -- JSON schema, regex pattern, or context-free grammar -- and the engine guarantees conformance.
Back in episode #62 we discussed structured outputs from LLM APIs. Constrained generation is the mechanism that typically makes that possible under the hood. When you tell an API "respond in JSON matching this schema," the server is most likely running something like the code above (or a more optimized version of it).
Let's implement a simple JSON grammar constraint to show how it works in practice:
class SimpleJsonGrammar:
    """Minimal JSON grammar for demonstration."""
    def __init__(self, required_keys):
        self.required_keys = required_keys
        self.states = [
            "start",        # expecting {
            "key",          # expecting "key_name"
            "colon",        # expecting :
            "value",        # expecting a value
            "comma_or_end"  # expecting , or }
        ]
    def initial_state(self):
        return {
            "state": "start",
            "keys_seen": [],
            "depth": 0,
        }
    def allowed_tokens(self, state):
        """Return the terminals that are valid in the current state."""
        # This is simplified -- it returns strings, whereas real
        # implementations map the tokenizer's full vocabulary to
        # grammar terminals and return token IDs
        s = state["state"]
        if s == "start":
            return ["{"]
        elif s == "key":
            remaining = [k for k in self.required_keys
                         if k not in state["keys_seen"]]
            return [f'"{k}"' for k in remaining]
        elif s == "colon":
            return [":"]
        elif s == "value":
            return ["string", "number", "true", "false", "null"]
        elif s == "comma_or_end":
            if len(state["keys_seen"]) < len(self.required_keys):
                return [","]
            return ["}"]
        return []
    def advance(self, state, token):
        """Move the grammar state machine forward."""
        new_state = dict(state)
        new_state["keys_seen"] = list(state["keys_seen"])
        s = state["state"]
        if s == "start":
            new_state["state"] = "key"
            new_state["depth"] = 1
        elif s == "key":
            # Store the bare key name (without quotes) so it
            # matches the entries in required_keys
            new_state["keys_seen"].append(token.strip('"'))
            new_state["state"] = "colon"
        elif s == "colon":
            new_state["state"] = "value"
        elif s == "value":
            new_state["state"] = "comma_or_end"
        elif s == "comma_or_end":
            if token == ",":
                new_state["state"] = "key"
            else:
                new_state["depth"] = 0
        return new_state
    def is_complete(self, state):
        return state["depth"] == 0 and len(state["keys_seen"]) > 0
# Demo
grammar = SimpleJsonGrammar(
    required_keys=["name", "score", "passed"])
state = grammar.initial_state()
print(f"Start -> allowed: {grammar.allowed_tokens(state)}")
state = grammar.advance(state, "{")
print(f'After {{ -> allowed: {grammar.allowed_tokens(state)}')
state = grammar.advance(state, '"name"')
print(f'After "name" -> allowed: {grammar.allowed_tokens(state)}')
# Finish the first key/value pair; "name" is no longer offered
for tok in [":", "string", ","]:
    state = grammar.advance(state, tok)
print(f"After first pair -> allowed: {grammar.allowed_tokens(state)}")
This is a simplified illustration. Real grammar engines operate at the token level (not the string level), handling the messy reality that a single word might span multiple tokens and a single token might be part of multiple grammar productions. But the principle is identical: maintain a parser state, filter the vocabulary at each step, let the model fill in the content.
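Here is a concrete token-level example for a simple case: forcing the output to be one of a few fixed strings. The sketch below follows one common approach. A token is legal if appending it keeps the text a prefix of at least one permitted choice. The vocabulary and the scores (standing in for logits) are hypothetical toy values, not a real tokenizer or model.

```python
def allowed_token_ids(generated, vocab, choices):
    """IDs of tokens that keep `generated` a prefix of at least one choice."""
    return [tid for tid, tok in enumerate(vocab)
            if any(c.startswith(generated + tok) for c in choices)]

def constrained_greedy(scores, vocab, choices, max_steps=10):
    """Greedy decoding restricted to legal tokens (scores stand in for logits)."""
    generated = ""
    for _ in range(max_steps):
        if generated in choices:
            break
        legal = allowed_token_ids(generated, vocab, choices)
        if not legal:  # dead end: no token can continue any choice
            break
        best = max(legal, key=lambda tid: scores[tid])
        generated += vocab[best]
    return generated

# Hypothetical toy vocabulary: "negative" can be spelled "neg"+"ative" or
# "ne"+"gative", and the token "ne" is shared by "negative" and "neutral"
vocab = ["pos", "neg", "ne", "itive", "ative", "utral", "u", "tral",
         "gative", "hello"]
choices = ["positive", "negative", "neutral"]
# "hello" has the highest score but is never legal
scores = [0.1, 0.2, 0.9, 0.0, 0.3, 0.4, 0.5, 0.6, 0.2, 5.0]

print(allowed_token_ids("", vocab, choices))       # [0, 1, 2]
print(allowed_token_ids("ne", vocab, choices))     # [5, 6, 8]
print(constrained_greedy(scores, vocab, choices))  # neutral
```

After the shared token "ne", the parser still keeps both "negative" and "neutral" open. This is the "one token, several productions" situation described above.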
Putting it all together: a complete generation pipeline
Let me show you a class that combines everything we've covered into a single reusable pipeline. This is the kind of code you'd actually use in a project:
import torch
import torch.nn.functional as F
from collections import Counter
class TextGenerator:
    """Complete text generation pipeline with all strategies."""
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    def generate(self, prompt, max_tokens=200, strategy="sample",
                 temperature=0.7, top_k=50, top_p=0.9,
                 repetition_penalty=1.0, frequency_penalty=0.0,
                 num_beams=4):
        """Generate text with the specified strategy."""
        input_ids = self.tokenizer.encode(
            prompt, return_tensors="pt")
        if strategy == "greedy":
            return self._greedy(input_ids, max_tokens)
        elif strategy == "beam":
            return self._beam(input_ids, max_tokens, num_beams)
        elif strategy == "sample":
            return self._sample(
                input_ids, max_tokens, temperature, top_k,
                top_p, repetition_penalty, frequency_penalty)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")
    def _greedy(self, input_ids, max_tokens):
        for _ in range(max_tokens):
            logits = self._get_logits(input_ids)
            next_token = torch.argmax(logits, dim=-1,
                                      keepdim=True)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token],
                                  dim=-1)
        return self.tokenizer.decode(input_ids[0])
    def _beam(self, input_ids, max_tokens, num_beams):
        eos_id = self.tokenizer.eos_token_id
        beams = [(input_ids, 0.0)]  # (sequence, summed log-prob)
        for _ in range(max_tokens):
            candidates = []
            for seq, score in beams:
                if seq[0, -1].item() == eos_id:
                    # Finished beam: carry it over unchanged
                    candidates.append((seq, score))
                    continue
                logits = self._get_logits(seq)
                log_probs = F.log_softmax(logits, dim=-1)
                top_lp, top_ids = torch.topk(
                    log_probs, num_beams)
                for i in range(num_beams):
                    new_seq = torch.cat(
                        [seq, top_ids[:, i:i+1]], dim=-1)
                    new_score = score + top_lp[0, i].item()
                    candidates.append((new_seq, new_score))
            candidates.sort(key=lambda x: x[1], reverse=True)
            beams = candidates[:num_beams]
            # Stop once every surviving beam has emitted EOS
            if all(seq[0, -1].item() == eos_id for seq, _ in beams):
                break
        return self.tokenizer.decode(beams[0][0][0])
    def _sample(self, input_ids, max_tokens, temperature,
                top_k, top_p, rep_penalty, freq_penalty):
        token_counts = Counter()
        generated_ids = []
        for _ in range(max_tokens):
            logits = self._get_logits(input_ids)
            # Repetition penalty: same factor for every seen token
            if rep_penalty != 1.0 and generated_ids:
                for tid in set(generated_ids):
                    if logits[0, tid] > 0:
                        logits[0, tid] /= rep_penalty
                    else:
                        logits[0, tid] *= rep_penalty
            # Frequency penalty: grows with each token's count
            if freq_penalty > 0:
                for tid, count in token_counts.items():
                    logits[0, tid] -= freq_penalty * count
            # Temperature
            scaled = logits / temperature
            # Top-k: drop everything below the k-th largest logit
            if top_k > 0:
                kth_val = torch.topk(scaled, top_k).values[
                    0, -1]
                scaled[scaled < kth_val] = float('-inf')
            # Top-p: keep the smallest prefix of sorted tokens
            # whose cumulative probability reaches top_p
            probs = F.softmax(scaled, dim=-1)
            sorted_p, sorted_i = torch.sort(
                probs, descending=True)
            cumsum = torch.cumsum(sorted_p, dim=-1)
            mask = cumsum - sorted_p > top_p
            sorted_p[mask] = 0.0
            sorted_p = sorted_p / sorted_p.sum()
            idx = torch.multinomial(sorted_p, num_samples=1)
            next_token = sorted_i.gather(-1, idx)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            tid = next_token.item()
            generated_ids.append(tid)
            token_counts[tid] += 1
            input_ids = torch.cat([input_ids, next_token],
                                  dim=-1)
        return self.tokenizer.decode(input_ids[0])
    def _get_logits(self, input_ids):
        with torch.no_grad():
            return self.model(input_ids).logits[:, -1, :]
# Usage demonstration (requires a loaded model)
# gen = TextGenerator(model, tokenizer)
#
# # Factual output
# gen.generate("Explain gravity:", strategy="greedy")
#
# # Creative output
# gen.generate("Once upon a time,", strategy="sample",
# temperature=0.9, top_p=0.95)
#
# # Structured output (beam search)
# gen.generate("Translate to French:", strategy="beam",
# num_beams=5)
#
# # Anti-repetition
# gen.generate("Write a paragraph about AI:",
# strategy="sample", repetition_penalty=1.3,
# frequency_penalty=0.5)
# Parameter guide
params = [
    ("temperature",  "0.1-0.3", "0.7-0.9",   "1.0-1.5"),
    ("top_k",        "10-20",   "40-60",     "80-200"),
    ("top_p",        "0.5-0.7", "0.85-0.95", "0.95-1.0"),
    ("rep_penalty",  "1.0",     "1.1-1.2",   "1.3-1.5"),
    ("freq_penalty", "0.0",     "0.3-0.5",   "0.8-1.0"),
]
print(f"{'Parameter':<15} {'Factual':>12} {'Balanced':>12} "
      f"{'Creative':>12}")
print("-" * 55)
for name, low, mid, high in params:
    print(f"{name:<15} {low:>12} {mid:>12} {high:>12}")
Key takeaways
- Greedy decoding always picks the highest-probability token. It is fast and deterministic, but it often produces repetitive, bland output because it never explores alternatives;
- beam search maintains multiple candidate sequences and usually finds higher-probability sequences than greedy decoding (though it is not guaranteed to find the single best one) -- the standard for translation and summarization, but too boring for creative tasks;
- temperature is the master creativity control. Low values (0.1-0.3) for factual precision, high values (0.9-1.5) for creative variation. It scales the logits before softmax, sharpening or flattening the distribution;
- top-k limits candidates to the k most likely tokens, cutting long-tail noise. But k is fixed, which is suboptimal when model confidence varies;
- top-p (nucleus sampling) is the adaptive answer: it includes the smallest set of tokens whose cumulative probability reaches p. Generally preferred over top-k because it adapts to the distribution shape;
- repetition and frequency penalties prevent the model from looping. Repetition penalty pushes down the logit of every token that has already appeared by the same fixed factor, no matter how often it appeared; frequency penalty subtracts an amount proportional to each token's count, so it bites harder the more a token repeats -- the latter is usually better behaved;
- speculative decoding uses a small draft model to propose tokens that the larger model verifies in a single forward pass, often achieving a 2-3x speedup while sampling from exactly the same distribution as the larger model -- the verification is exact, not approximate;
- constrained generation with formal grammars guarantees structured output (valid JSON, specific formats) while preserving the model's freedom within those constraints. This is what powers the "structured output" features in modern LLM APIs.
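The plain-Python sketch below illustrates two of the points above using made-up logits for a hypothetical 6-token vocabulary. First, a lower temperature raises the top token's probability. Second, the top-p nucleus shrinks or grows with the model's confidence, while top-k would keep the same number of tokens either way.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax of the logits divided by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def nucleus_size(probs, top_p):
    """Size of the smallest set of top tokens whose cumulative probability reaches top_p."""
    cumulative = 0.0
    for n, p in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += p
        if cumulative >= top_p:
            return n
    return len(probs)  # guards against floating-point shortfall

confident = [8.0, 2.0, 1.0, 0.5, 0.0, -1.0]  # one clear winner
uncertain = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]   # nearly flat

# Lower temperature -> sharper distribution -> larger top-token probability
for t in (0.5, 1.0, 1.5):
    print(t, round(max(softmax(uncertain, t)), 3))

# The nucleus adapts to the distribution; top_k=3 would keep 3 tokens in both cases
print(nucleus_size(softmax(confident), 0.9),
      nucleus_size(softmax(uncertain), 0.9))  # 1 6
```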
Exercises
Exercise 1: Build a temperature visualizer. Create a function visualize_temperature(logits, temperatures) that takes a 1D tensor of logits (at least 10 values) and a list of temperatures (e.g., [0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0]). For each temperature: apply the scaling, compute softmax probabilities, and calculate three metrics: (a) entropy of the distribution (higher = more random), (b) probability of the top token (how dominant the best choice is), (c) effective vocabulary size (number of tokens with probability > 1% of the total). Print a table with all three metrics for each temperature. Then add a recommend_temperature(task_type) function that returns the recommended temperature range for "factual", "balanced", and "creative" tasks based on entropy thresholds.
Exercise 2: Build a repetition detector and auto-penalizer. Create a class RepetitionGuard that monitors generated text for repetitive patterns. It should track: (a) individual token frequencies, (b) bigram (2-token) frequencies, and (c) the longest repeated substring in the last N tokens. Implement check(generated_ids) that returns a dictionary with {"token_repeats": [...], "bigram_repeats": [...], "longest_repeat": int, "severity": float} where severity is 0.0 (no repetition) to 1.0 (severe). Implement adaptive_penalty(severity) that maps severity to a repetition penalty value (1.0 at severity 0.0, scaling up to 1.5 at severity 1.0). Test with a simulated token sequence that has intentional repetition: generate 100 token IDs where tokens 50-70 are a repeating pattern of [10, 20, 30, 10, 20, 30, ...] and the rest are random. Show that the guard detects the repetition and recommends an appropriate penalty.
Exercise 3: Build a decoding strategy comparator. Create a class DecodingComparator that takes a probability distribution (a tensor of probabilities over a vocabulary) and simulates samples from it using different strategies: greedy (always argmax), pure sampling, top-k (k=10, 20, 50), and top-p (p=0.5, 0.9, 0.95). For each strategy, run 1000 sampling trials and measure: (a) how many unique tokens were selected across all trials, (b) the Shannon entropy of the empirical sample distribution, and (c) whether the top-1 token from the original distribution was selected at least once. Use a synthetic distribution: 1000 tokens where the probabilities follow a Zipf distribution (token i has probability proportional to 1/i). Print a comparison table. Also compute the KL divergence KL(empirical || original) for each strategy to quantify how much it distorts the model's beliefs. This direction stays finite because the Zipf distribution gives every token nonzero probability; the reverse direction is infinite whenever a strategy never samples some token, which is guaranteed for greedy, top-k and top-p.