Learn AI Series (#71) - Text Generation Techniques

in StemSocial, 5 hours ago

What will I learn

  • You will learn how language models actually generate text, one token at a time;
  • greedy decoding, beam search, and sampling strategies;
  • temperature, top-k, and top-p (nucleus sampling) -- what each parameter controls and why;
  • repetition penalty and frequency penalty for avoiding loops;
  • speculative decoding: making generation faster without losing quality;
  • constrained generation and grammar-guided output for structured responses.

Requirements

  • A working modern computer running macOS, Windows or Ubuntu;
  • An installed Python 3(.11+) distribution;
  • The ambition to learn AI and machine learning.

Difficulty

  • Beginner

Curriculum (of the Learn AI Series):

Learn AI Series (#71) - Text Generation Techniques

Solutions to Episode #70 Exercises

Exercise 1: Model memory calculator.

import math

# Standard architecture ratios
ARCH_TABLE = {
    7:  {"layers": 32, "dim": 4096},
    13: {"layers": 40, "dim": 5120},
    70: {"layers": 80, "dim": 8192},
}


def estimate_memory(num_params_b, quant_bits, context_length,
                    batch_size):
    """Estimate total GPU memory for inference."""
    # Model weights
    weight_bytes = num_params_b * 1e9 * quant_bits / 8
    weight_gb = weight_bytes / 1e9

    # KV cache: pick closest architecture
    closest = min(ARCH_TABLE.keys(),
                  key=lambda k: abs(k - num_params_b))
    arch = ARCH_TABLE[closest]
    n_layers = arch["layers"]
    d_model = arch["dim"]

    # KV cache: 2 (K+V) * layers * dim * seq_len * batch * 2 bytes
    kv_bytes = (2 * n_layers * d_model * context_length
                * batch_size * 2)
    kv_gb = kv_bytes / 1e9

    # Activation overhead (~10% of model weights)
    act_gb = weight_gb * 0.10

    total_gb = weight_gb + kv_gb + act_gb

    print(f"Memory estimate: {num_params_b}B at {quant_bits}-bit, "
          f"ctx={context_length}, batch={batch_size}")
    print(f"  Model weights:  {weight_gb:>7.2f} GB")
    print(f"  KV cache:       {kv_gb:>7.2f} GB "
          f"({n_layers} layers, dim={d_model})")
    print(f"  Activations:    {act_gb:>7.2f} GB")
    print(f"  TOTAL:          {total_gb:>7.2f} GB")
    return total_gb


def fits_in_gpu(total_gb, vram_gb):
    """Check if the config fits in GPU memory."""
    if total_gb <= vram_gb:
        print(f"  -> Fits in {vram_gb} GB VRAM "
              f"({vram_gb - total_gb:.1f} GB headroom)")
        return True
    else:
        overshoot = total_gb - vram_gb
        print(f"  -> Does NOT fit in {vram_gb} GB VRAM "
              f"(need {overshoot:.1f} GB more)")
        if overshoot < total_gb * 0.2:
            print("     Try: shorter context or lower quantization")
        else:
            print("     Try: smaller model or aggressive quantization")
        return False


# Test cases
print("=" * 55)
configs = [
    (7,  4, 4096, 1),
    (13, 4, 2048, 1),
    (70, 4, 4096, 1),
]
for params, bits, ctx, batch in configs:
    total = estimate_memory(params, bits, ctx, batch)
    fits_in_gpu(total, 24)
    print()

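The key thing to notice is how the KV cache scales with context length. For the 70B configuration, the formula above gives about 10.7 GB of KV cache at 4096 tokens, so going from 4096 to 8192 context adds roughly another 10.7 GB just for the key-value pairs -- that's why "out of memory" errors often appear when you increase the context window, not the model itself. The model weights stay constant; it's the KV cache that grows linearly with sequence length. (The formula assumes every layer caches full-width keys and values; models that use grouped-query attention store fewer KV heads and need proportionally less.)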
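To see the linear growth in isolation, here is a minimal sketch that reuses the same KV cache formula as `estimate_memory`. The helper name `kv_cache_gb` is made up for this illustration, and the 80-layer, 8192-wide shape is the 70B row of `ARCH_TABLE`:

```python
def kv_cache_gb(n_layers, d_model, context_length, batch_size=1,
                bytes_per_value=2):
    """KV cache size in GB: K and V per layer, 2 bytes per value (fp16)."""
    kv_bytes = (2 * n_layers * d_model * context_length
                * batch_size * bytes_per_value)
    return kv_bytes / 1e9


# 70B-class shape from ARCH_TABLE: 80 layers, hidden size 8192
for ctx in (2048, 4096, 8192, 16384):
    print(f"ctx={ctx:>6}: {kv_cache_gb(80, 8192, ctx):6.2f} GB")
```

Each doubling of the context doubles the cache, while the weight term in the estimate does not move at all.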

Exercise 2: Ollama model manager (simulated).

import random

class ModelManager:
    """Simulated local model manager."""

    def __init__(self):
        self.models = {}
        self.usage_log = []

    def pull(self, name, size_gb, quant, capabilities):
        self.models[name] = {
            "size_gb": size_gb,
            "quant": quant,
            "capabilities": capabilities,
        }
        print(f"Pulled {name} ({size_gb:.1f} GB, {quant})")

    def remove(self, name):
        if name in self.models:
            freed = self.models[name]["size_gb"]
            del self.models[name]
            print(f"Removed {name} (freed {freed:.1f} GB)")
        else:
            print(f"{name} not found")

    def list_models(self):
        total = sum(m["size_gb"] for m in self.models.values())
        print(f"\nInstalled models ({total:.1f} GB total):")
        for name, info in sorted(self.models.items()):
            caps = ", ".join(info["capabilities"])
            print(f"  {name:<22} {info['size_gb']:>5.1f} GB  "
                  f"[{caps}]")

    def find_best(self, task):
        """Pick smallest model that supports the task."""
        candidates = [
            (name, info) for name, info in self.models.items()
            if task in info["capabilities"]
        ]
        if not candidates:
            return None
        candidates.sort(key=lambda x: x[1]["size_gb"])
        best_name = candidates[0][0]
        self.usage_log.append({"model": best_name, "task": task})
        return best_name

    def usage_report(self):
        counts = {}
        for entry in self.usage_log:
            name = entry["model"]
            counts[name] = counts.get(name, 0) + 1
        print(f"\nUsage report ({len(self.usage_log)} queries):")
        for name, count in sorted(counts.items(),
                                  key=lambda x: -x[1]):
            pct = count / len(self.usage_log) * 100
            print(f"  {name:<22} {count:>4} queries ({pct:.0f}%)")


# Set up
mgr = ModelManager()
mgr.pull("llama3.1:8b", 4.5, "Q4_K_M",
         ["chat", "reasoning", "summarization"])
mgr.pull("phi3:3.8b", 2.2, "Q4_K_M",
         ["chat", "classification"])
mgr.pull("codellama:7b", 3.8, "Q4_K_M",
         ["code", "chat"])
mgr.pull("nomic-embed", 0.3, "FP16",
         ["embedding"])
mgr.pull("mistral:7b", 4.1, "Q4_K_M",
         ["chat", "reasoning", "summarization", "classification"])

mgr.list_models()

# Simulate 50 queries
random.seed(42)
tasks = (["chat"] * 15 + ["code"] * 10 + ["classification"] * 10
         + ["embedding"] * 8 + ["summarization"] * 4
         + ["reasoning"] * 3)
random.shuffle(tasks)

for task in tasks:
    model = mgr.find_best(task)

mgr.usage_report()

Notice how find_best always picks the smallest model that can handle the task. That's the right default -- you save VRAM, get faster inference, and the quality difference between a 3.8B and a 7B model on straightforward classification is usually negligible. Only move to a bigger model when the smaller one fails your quality bar.

Exercise 3: Quantization quality simulator.

import numpy as np

def simulate_quantization(weights, bits):
    """Uniform quantization: map to 2^bits levels, dequantize back."""
    n_levels = 2 ** bits
    w_min = weights.min()
    w_max = weights.max()
    w_range = w_max - w_min

    if w_range == 0:
        return weights.copy(), 0.0, 0.0, float('inf')

    # Quantize: scale to [0, n_levels-1], round, scale back
    scaled = (weights - w_min) / w_range * (n_levels - 1)
    quantized_int = np.round(scaled).astype(np.int32)
    quantized_int = np.clip(quantized_int, 0, n_levels - 1)
    dequantized = quantized_int / (n_levels - 1) * w_range + w_min

    # Error metrics
    error = weights - dequantized
    mse = np.mean(error ** 2)
    max_err = np.max(np.abs(error))
    signal_power = np.mean(weights ** 2)
    snr_db = 10 * np.log10(signal_power / mse) if mse > 0 else float('inf')

    return dequantized, mse, max_err, snr_db


def awq_simulate(weights, bits, importance_scores):
    """AWQ-style: keep top 1% important weights at fp32."""
    threshold = np.percentile(importance_scores, 99)
    important_mask = importance_scores >= threshold

    result = np.zeros_like(weights)
    # Important weights stay at full precision
    result[important_mask] = weights[important_mask]

    # Rest gets quantized
    rest_mask = ~important_mask
    if rest_mask.any():
        dequant, _, _, _ = simulate_quantization(
            weights[rest_mask], bits)
        result[rest_mask] = dequant

    error = weights - result
    mse = np.mean(error ** 2)
    max_err = np.max(np.abs(error))
    signal_power = np.mean(weights ** 2)
    snr_db = (10 * np.log10(signal_power / mse)
              if mse > 0 else float('inf'))
    return result, mse, max_err, snr_db


# Generate realistic weights
np.random.seed(42)
weights = np.random.normal(0, 0.02, size=10_000).astype(np.float32)

# Uniform quantization comparison
print("Uniform Quantization Comparison")
print(f"{'Bits':>5} {'MSE':>12} {'Max Err':>10} {'SNR (dB)':>10}")
print("-" * 40)
for bits in [2, 3, 4, 5, 6, 8]:
    _, mse, max_err, snr = simulate_quantization(weights, bits)
    print(f"{bits:>5} {mse:>12.2e} {max_err:>10.6f} {snr:>10.1f}")

# AWQ comparison at 4-bit
print("\n4-bit: Uniform vs AWQ-simulated")
importance = np.abs(weights)  # simple importance heuristic
_, uni_mse, _, uni_snr = simulate_quantization(weights, 4)
_, awq_mse, _, awq_snr = awq_simulate(weights, 4, importance)
improvement = (uni_mse - awq_mse) / uni_mse * 100

print(f"  Uniform:  MSE={uni_mse:.2e}, SNR={uni_snr:.1f} dB")
print(f"  AWQ-sim:  MSE={awq_mse:.2e}, SNR={awq_snr:.1f} dB")
print(f"  MSE improvement: {improvement:.1f}%")

The SNR column is the one to watch. Every extra bit of quantization precision adds roughly 6 dB of signal-to-noise ratio -- that's a well-known result from signal processing theory, and it holds almost exactly for neural network weights too. The AWQ trick of protecting the top 1% most important weights gives you a meaningful MSE reduction for essentially free (1% of weights at full precision is a negligible storage increase). It's clever engineering: don't treat all weights equally when they clearly aren't equally important.
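The "6 dB per bit" rule can be checked with one line of arithmetic. This sketch assumes the usual model of uniform quantization, where each extra bit halves the step size and therefore quarters the noise power:

```python
import math

# One extra bit halves the quantization step. Noise power scales with
# the square of the step, so it drops by a factor of 4 per bit.
db_per_bit = 10 * math.log10(4)
print(f"{db_per_bit:.2f} dB per bit")
```

So the SNR differences between adjacent rows of the table should sit close to 6 dB, as long as the min/max range of the weights stays the same.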

On to today's episode

Here we go! In episode #57 we explored how language models predict the next token. In episode #58 we built the GPT architecture that powers those predictions. And last episode (#70) we learned how to run these models locally on our own hardware. But there's a critical piece we've been hand-waving over this entire time: once the model outputs a probability distribution over its vocabulary, how does it actually pick the next token?

This is NOT a trivial question. The same model, the same prompt, with different decoding settings can produce dry factual text, creative fiction, repetitive loops, or complete gibberish. When someone complains "this model is broken" -- half the time, the model is fine, it's the generation parameters that are wrong. Understanding what temperature, top-k, top-p, and repetition penalty actually do (mathematically, not just "turn it up for creativity") is essential knowledge for working with language models in any capacity.

So let's build every major decoding strategy from scratch ;-)

Greedy decoding: always pick the winner

The absolute simplest strategy. At each step, look at the probability distribution and pick the token with the highest probability. Done.

import torch
import torch.nn.functional as F

def greedy_decode(model, tokenizer, prompt, max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    for _ in range(max_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
            logits = outputs.logits[:, -1, :]  # last position

        # Always pick the highest probability token
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

        if next_token.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0])

Greedy decoding is fast and deterministic -- the same prompt always produces the same output. But it has a fatal flaw: it gets stuck in loops. Once the model starts a phrase like "the most important thing is," the subsequent tokens keep having high individual probability given the pattern, so the model repeats similar structures over and over. The output reads like a broken record.

There's a deeper problem too. The locally optimal choice at each step doesn't necessarily lead to the globally best sequence. Imagine a sentence where picking the second-most-likely token at position 5 would lead to a much better overall sentence. Greedy decoding will never find it because it never looks beyond the immediate next token.

This is the exact same issue we saw with gradient descent back in episode #41 -- local optima can trap you when you're making myopic decisions. The difference is that in gradient descent we had tools like momentum and learning rate schedules to escape local minima. For text generation, we need different tools entirely.

Beam search: exploring multiple paths at once

Beam search addresses the myopia problem by maintaining multiple candidate sequences (called beams) simultaneously. Instead of committing to one path, you explore several and keep the overall best ones:

def beam_search(model, tokenizer, prompt, num_beams=4,
                max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Each beam: (token_ids, cumulative_log_probability)
    beams = [(input_ids, 0.0)]

    for _ in range(max_tokens):
        all_candidates = []

        for seq, score in beams:
            with torch.no_grad():
                logits = model(seq).logits[:, -1, :]
                log_probs = F.log_softmax(logits, dim=-1)

            # For each beam, consider top-k expansions
            top_log_probs, top_ids = torch.topk(
                log_probs, num_beams)

            for i in range(num_beams):
                new_seq = torch.cat(
                    [seq, top_ids[:, i:i+1]], dim=-1)
                new_score = score + top_log_probs[0, i].item()
                all_candidates.append((new_seq, new_score))

        # Keep only the best beams
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:num_beams]

    best_seq = beams[0][0]
    return tokenizer.decode(best_seq[0])

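With 4 beams, you're tracking 4 parallel hypotheses at every step. Each hypothesis gets expanded by the top 4 next tokens, giving 16 candidates, and you prune back down to 4. Because a beam can survive a locally weaker token, this usually finds sequences with higher overall probability than greedy decoding does. Note that this simplified version does not stop beams at the end-of-sequence token or normalize scores by length, both of which production implementations do.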
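A toy example shows the difference between one beam (which is just greedy decoding) and two beams. The two-step "model" below is a made-up probability table, not a real language model, chosen so that the locally best first token leads to a worse overall sequence:

```python
import math


def next_probs(prefix):
    """Hypothetical toy model: next-token probabilities for a prefix."""
    table = {
        (): {"A": 0.6, "B": 0.4},
        ("A",): {"x": 0.55, "y": 0.45},
        ("B",): {"x": 0.9, "y": 0.1},
    }
    return table[prefix]


def toy_beam_search(num_beams, steps=2):
    beams = [((), 0.0)]  # (prefix, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for prefix, score in beams:
            for tok, p in next_probs(prefix).items():
                candidates.append((prefix + (tok,), score + math.log(p)))
        # Keep only the num_beams highest-scoring prefixes
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams[0]


for width in (1, 2):
    seq, logp = toy_beam_search(width)
    print(f"num_beams={width}: {seq}, p={math.exp(logp):.2f}")
# num_beams=1: ('A', 'x'), p=0.33
# num_beams=2: ('B', 'x'), p=0.36
```

With one beam the search commits to "A" and can never recover. With two beams the "B" prefix survives the first pruning step and wins overall.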

Beam search is the standard for machine translation and text summarization -- tasks where accuracy matters more than creativity. Having said that, beam search output tends to be bland and generic. It optimizes for probability, and the most probable text is often the most boring. "The weather is nice today" has high probability in almost any context. For creative applications, you want something else entirely: controlled randomness.

Sampling with temperature: the creativity knob

Instead of always picking the most likely token, what if we sample from the probability distribution? Higher-probability tokens are still more likely to be selected, but lower-probability tokens get a chance too. This introduces variation -- run the same prompt twice and you get different outputs.

def sample_decode(model, tokenizer, prompt, temperature=1.0,
                  max_tokens=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]

        # Apply temperature scaling
        scaled_logits = logits / temperature

        # Convert to probabilities and sample
        probs = F.softmax(scaled_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0])

Temperature is the single most important generation parameter, and it's often misunderstood. Here's what it does mathematically: it divides the raw logits before the softmax. That one operation changes everything about the output distribution:

  • Temperature = 1.0: the original distribution. Nothing changes.
  • Temperature < 1.0 (say, 0.3): the logits get divided by a small number, making them larger in magnitude. After softmax, the peaks get even peakier. High-probability tokens dominate even more. Output is focused and predictable.
  • Temperature > 1.0 (say, 1.5): the logits get divided by a large number, flattening them toward zero. After softmax, the distribution becomes more uniform. Rare tokens get a bigger chance. Output is creative but potentially incoherent.
  • Temperature -> 0: approaches greedy decoding. The highest-probability token takes over completely.

Let me make this concrete with actual numbers:

# Temperature effects on a real distribution
logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

for temp in [0.1, 0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {probs.numpy().round(3)}")
# T=0.1: [1.    0.    0.    0.   ]   <- almost greedy
# T=0.5: [0.842 0.114 0.042 0.002]   <- focused
# T=1.0: [0.609 0.224 0.136 0.03 ]   <- balanced
# T=2.0: [0.434 0.263 0.205 0.097]   <- spread out

See the pattern? At T=0.1, the first token has essentially 100% of the probability mass -- you'll always pick it. At T=2.0, even the least likely token (0.097 = 9.7%) has a reasonable shot. The range between 0.5 and 1.0 is where most practical applications live.

If you remember softmax from episode #38 (neural networks forward pass) and episode #51 (attention mechanisms), you already understand the math here. Temperature scaling is just a pre-processing step that modifies how "sharp" or "flat" the softmax output is. Same function, different input scale, dramatically different behavior ;-)

Top-k sampling: cutting the tail

Pure sampling has a problem even with reasonable temperature: extremely unlikely tokens still have a non-zero chance of being picked. A token with 0.001% probability will eventually appear if you generate enough text, and that one bad token can derail the entire sequence into nonsense.

Top-k sampling restricts the candidate pool to only the k most probable tokens:

def top_k_sample(logits, k=50, temperature=1.0):
    scaled = logits / temperature

    # Find the top-k values and their indices
    top_k_values, top_k_indices = torch.topk(scaled, k)

    # Set everything else to -infinity (zero probability after softmax)
    filtered = torch.full_like(scaled, float('-inf'))
    filtered.scatter_(1, top_k_indices, top_k_values)

    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)

With k=50, only the 50 most likely next tokens are considered. Everything else is masked out completely -- zero chance of selection. This eliminates the low-probability noise while keeping diversity among the reasonable candidates.

But k=50 is a fixed number, and that's the problem. Consider two situations:

Situation A: The model is very confident. One token has 95% probability, and the rest of the vocabulary shares the other 5%. Here k=50 still admits 49 tokens the model thinks are almost certainly wrong.

Situation B: The model is uncertain. 200 tokens each have roughly 0.5% probability. Here k=50 cuts off 150 perfectly reasonable choices, artificially restricting the output.

A fixed k can't handle both situations well. We need the candidate pool to adapt to the model's confidence.
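The two situations are easy to put into numbers. In this sketch the distributions are invented for illustration (the 1,000-token vocabulary in situation A is an arbitrary choice), and `top_k_mass` is a helper made up for this example:

```python
def top_k_mass(probs, k):
    """Probability mass kept when only the k most likely tokens survive."""
    return sum(sorted(probs, reverse=True)[:k])


# Situation A: one dominant token, 999 others share the remaining 5%
confident = [0.95] + [0.05 / 999] * 999

# Situation B: 200 equally plausible tokens at 0.5% each
uncertain = [0.005] * 200

print(f"A: top-50 keeps {top_k_mass(confident, 50):.3f} of the mass, "
      f"including 49 tokens at p={0.05 / 999:.5f} each")
print(f"B: top-50 keeps only {top_k_mass(uncertain, 50):.3f} of the mass")
```

In A, the cut lets in dozens of near-impossible tokens. In B, it throws away three quarters of the plausible ones.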

Top-p (nucleus) sampling: the adaptive solution

Top-p sampling (Holtzman et al., 2019) elegantly solves the fixed-k problem. Instead of picking a fixed number of candidates, it includes the smallest set of tokens whose cumulative probability reaches p:

def top_p_sample(logits, p=0.9, temperature=1.0):
    scaled = logits / temperature
    probs = F.softmax(scaled, dim=-1)

    # Sort by probability (highest first)
    sorted_probs, sorted_indices = torch.sort(
        probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find where cumulative probability exceeds p
    # Keep the first token that pushes us over the threshold
    sorted_mask = cumulative_probs - sorted_probs > p
    sorted_probs[sorted_mask] = 0.0

    # Renormalize so they sum to 1 again
    sorted_probs = sorted_probs / sorted_probs.sum()

    # Sample from the filtered distribution
    token_idx = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices.gather(-1, token_idx)

With p=0.9: if one token has 95% probability, only that token is in the nucleus (the name refers to the high-probability core of the distribution). If 100 tokens each have ~1% probability, about 90 of them are needed to reach the threshold, so the nucleus is roughly 90 tokens wide. The candidate pool automatically adapts to how confident the model is at each step.
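To watch the nucleus grow and shrink, the sketch below counts how many tokens are needed to reach p on two invented distributions. `nucleus_size` is a hypothetical helper, and the flat case uses 256 tokens at 1/256 each so the floating-point sums are exact:

```python
def nucleus_size(probs, p=0.9):
    """Number of tokens in the smallest set whose cumulative mass reaches p."""
    cumulative = 0.0
    for count, prob in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += prob
        if cumulative >= p:
            return count
    return len(probs)


confident = [0.95] + [0.05 / 999] * 999   # one dominant token
uncertain = [1 / 256] * 256               # completely flat

print(nucleus_size(confident))   # 1
print(nucleus_size(uncertain))   # 231
```

The same p=0.9 yields a single candidate in one case and 231 in the other, which is exactly the adaptivity a fixed k lacks.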

This is a really elegant idea when you think about it. The nucleus grows and shrinks based on what the model actually believes, not on an arbitrary fixed threshold. In practice, most LLM APIs expose both temperature and top-p -- commonly used settings such as temperature 0.7 with top_p 0.9 produce coherent, slightly creative text. For factual tasks, lower both. For creative writing, raise temperature a bit.

def combined_sampling(logits, temperature=0.7, top_k=50,
                      top_p=0.9):
    """Full generation pipeline: temperature -> top-k -> top-p."""
    # Step 1: temperature scaling
    scaled = logits / temperature

    # Step 2: top-k filtering (optional, set k=0 to skip)
    if top_k > 0:
        top_k_vals, top_k_idx = torch.topk(scaled, top_k)
        filtered = torch.full_like(scaled, float('-inf'))
        filtered.scatter_(1, top_k_idx, top_k_vals)
    else:
        filtered = scaled

    # Step 3: top-p filtering
    probs = F.softmax(filtered, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = cumsum - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()

    # Step 4: sample
    token_idx = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, token_idx)

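Note that the order matters: temperature first (scales the logits), then top-k (removes improbable candidates), then top-p (adapts the remaining pool). This is a common ordering (Hugging Face Transformers applies its samplers in this sequence), but some inference engines order them differently, so check the documentation of the one you use.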

Repetition penalty: breaking the loop

Even with good sampling settings, models tend to fall into repetitive patterns. You've probably seen this: the model starts saying "the key advantage is" and then three paragraphs later it says "the key advantage is" again and again with slight variations. Repetition penalty (Keskar et al., 2019) discourages tokens that have already appeared in the generated text:

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Penalize tokens that already appeared in the output."""
    for token_id in set(generated_ids):
        if logits[0, token_id] > 0:
            logits[0, token_id] /= penalty  # reduce positive logit
        else:
            logits[0, token_id] *= penalty  # make negative more negative
    return logits

A penalty of 1.0 means no change. A penalty of 1.2 divides a positive logit by 1.2 (about 17% smaller) and multiplies a negative logit by 1.2 (20% more negative). The asymmetric handling is important -- in both cases the logit moves down, so the token becomes less likely after softmax.
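Why the sign check matters is easiest to see on a tiny example. The sketch below uses three made-up logits and plain-Python helpers (`softmax`, `penalize`, `prob_after` are names invented for this illustration) to compare the probability of a token before and after the penalty:

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def penalize(logit, penalty=1.2):
    """Divide positive logits, multiply negative ones (both move down)."""
    return logit / penalty if logit > 0 else logit * penalty


def prob_after(logits, idx, new_logit):
    """Probability of token idx once its logit is replaced by new_logit."""
    adjusted = list(logits)
    adjusted[idx] = new_logit
    return softmax(adjusted)[idx]


logits = [2.0, 1.0, -1.0]   # toy logits for three tokens
before = softmax(logits)

for idx in (0, 2):          # one positive logit, one negative logit
    after = prob_after(logits, idx, penalize(logits[idx]))
    print(f"token {idx}: p {before[idx]:.3f} -> {after:.3f}")

# Dividing the negative logit instead would raise its probability
naive = prob_after(logits, 2, logits[2] / 1.2)
print(f"token 2, naive division: p {before[2]:.3f} -> {naive:.3f}")
```

Both penalized tokens lose probability, while the naive "always divide" variant would reward a repeated token whose logit happens to be negative.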

Values above 1.5 can make the output aggressively avoid any repetition, sometimes to the point of incoherence as the model runs out of natural word choices. Common words like "the", "is", "and" get penalized even though they should repeat -- they're function words, not content repetition.

Frequency penalty is a smarter variant. Instead of applying a flat penalty to any repeated token, the penalty scales with how many times the token has appeared:

def apply_frequency_penalty(logits, token_counts, penalty=0.5):
    """Penalty scales with how often each token has appeared."""
    for token_id, count in token_counts.items():
        logits[0, token_id] -= penalty * count
    return logits


def generate_with_frequency_penalty(model, tokenizer, prompt,
                                    penalty=0.5, max_tokens=100):
    """Full generation loop with frequency tracking."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    token_counts = {}

    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]

        # Apply frequency-based penalty
        logits = apply_frequency_penalty(logits, token_counts,
                                         penalty)

        # Sample with temperature
        probs = F.softmax(logits / 0.7, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

        # Update frequency count
        tid = next_token.item()
        token_counts[tid] = token_counts.get(tid, 0) + 1

        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0])

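A word used once gets a small penalty. A word used ten times gets a penalty ten times larger. With a small coefficient, occasional repetition stays cheap while tokens that keep showing up are progressively discouraged. Keep in mind that the most frequent tokens, including function words like "the", accumulate the largest penalty, so large coefficients hurt fluency.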

Speculative decoding: speed without sacrifice

Here's a different kind of problem. Standard autoregressive generation is inherently sequential -- each token requires a full forward pass through the model, and each pass depends on the previous token. For a 70B parameter model, generating 100 tokens means 100 sequential forward passes. That's slow.

Speculative decoding (Leviathan et al., 2022) is a genuinely clever trick. Use a smaller, faster draft model to propose multiple tokens at once, then verify them with the large target model in a single forward pass:

The algorithm:
1. Draft model (fast, ~1B params) generates 5 tokens: [A, B, C, D, E]
2. Target model (slow, ~70B params) processes all 5 at once
3. Target model accepts [A, B, C] but rejects D
4. Keep [A, B, C], resample position 4 from the target model (adjusted for the rejected draft)
5. Repeat from position 4

The key insight: verifying N tokens in parallel takes roughly the same time as generating 1 token. A transformer processes all positions simultaneously in one forward pass (remember the parallelism advantage from episode #52?). So if the draft model's guesses match the target model often enough -- and for a well-chosen draft model, they match 60-80% of the time -- you get 2-3x speedup with mathematically identical output distribution. Not "similar" output -- identical. The acceptance/rejection scheme guarantees this.
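How much you gain depends on the acceptance rate. As a back-of-the-envelope sketch, assume every drafted token is accepted independently with probability `alpha` and `gamma` tokens are drafted per round (both names are chosen here for illustration). The expected number of tokens produced per target forward pass is then a geometric sum:

```python
def expected_tokens_per_pass(alpha, gamma):
    """Expected tokens per target pass: accepted drafts plus one extra.

    Assumes each drafted token is accepted independently with
    probability alpha, with gamma tokens drafted per round.
    """
    if alpha == 1.0:
        return gamma + 1
    # 1 + alpha + alpha^2 + ... + alpha^gamma
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


for alpha in (0.6, 0.7, 0.8):
    print(f"alpha={alpha}: "
          f"{expected_tokens_per_pass(alpha, 5):.2f} tokens per pass")
# alpha=0.6: 2.38 tokens per pass
# alpha=0.7: 2.94 tokens per pass
# alpha=0.8: 3.69 tokens per pass
```

This counts tokens per target pass only. The wall-clock speedup is somewhat lower because the draft model's own forward passes are not free.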

def speculative_decode(target_model, draft_model, tokenizer,
                       prompt, n_speculative=5, max_tokens=100):
    """Speculative decoding: draft proposes, target verifies."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = 0

    while generated < max_tokens:
        # Phase 1: draft model generates n_speculative tokens
        draft_ids = input_ids.clone()
        draft_probs_list = []

        for _ in range(n_speculative):
            with torch.no_grad():
                logits = draft_model(draft_ids).logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            draft_probs_list.append(probs)
            draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

        # Phase 2: target model scores all positions at once
        with torch.no_grad():
            target_logits = target_model(draft_ids).logits

        # Phase 3: accept/reject each drafted token
        n_input = input_ids.shape[1]
        n_accepted = 0

        for i in range(n_speculative):
            pos = n_input + i
            target_probs = F.softmax(
                target_logits[:, pos - 1, :], dim=-1)
            draft_token = draft_ids[0, pos].item()
            draft_p = draft_probs_list[i][0, draft_token].item()
            target_p = target_probs[0, draft_token].item()

            # Accept with probability min(1, target_p / draft_p)
            if target_p >= draft_p:
                n_accepted += 1
            elif torch.rand(1).item() < target_p / draft_p:
                n_accepted += 1
            else:
                # Rejected: the replacement must come from the
                # residual max(0, target - draft), renormalized
                residual = torch.clamp(
                    target_probs - draft_probs_list[i], min=0.0)
                residual = residual / residual.sum(
                    dim=-1, keepdim=True)
                break

        # Keep accepted tokens + sample one more token
        input_ids = draft_ids[:, :n_input + n_accepted]

        if n_accepted < n_speculative:
            # A draft was rejected: sample from the residual
            next_probs = residual
        else:
            # Every draft accepted: bonus token from the target
            next_probs = F.softmax(target_logits[:, -1, :], dim=-1)
        correction = torch.multinomial(next_probs, num_samples=1)
        input_ids = torch.cat([input_ids, correction], dim=-1)
        generated += n_accepted + 1

        if tokenizer.eos_token_id in input_ids[0, n_input:]:
            break

    return tokenizer.decode(input_ids[0])

The probabilistic acceptance (target_p / draft_p) is what makes this mathematically exact. If the draft model says token X has probability 0.3 and the target says 0.2, we accept with probability 0.2/0.3 = 67%. In the full algorithm, the replacement for a rejected token is sampled from the renormalized residual max(0, target - draft) rather than from the target distribution directly. Together, these two rules form a rejection sampling scheme that guarantees the final distribution matches the target model exactly, regardless of how bad the draft model is. A bad draft model just means more rejections (lower speedup), not worse output quality.
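The exactness claim can be checked empirically on a single position. This is a minimal simulation rather than a proof: the three-token `target` and `draft` distributions are invented, the draft is deliberately bad, and a rejected token is replaced by a sample from the renormalized residual max(0, target - draft):

```python
import random
from collections import Counter

random.seed(0)

target = [0.5, 0.3, 0.2]   # hypothetical target distribution
draft = [0.2, 0.3, 0.5]    # deliberately poor draft distribution


def sample(weights):
    # random.choices normalizes the weights internally
    return random.choices(range(len(weights)), weights=weights)[0]


def speculative_sample_one():
    tok = sample(draft)
    # Accept the drafted token with probability min(1, target / draft)
    if random.random() < min(1.0, target[tok] / draft[tok]):
        return tok
    # Rejected: resample from the residual max(0, target - draft)
    residual = [max(0.0, t - d) for t, d in zip(target, draft)]
    return sample(residual)


n = 200_000
counts = Counter(speculative_sample_one() for _ in range(n))
empirical = [counts[i] / n for i in range(len(target))]
print("target:   ", target)
print("empirical:", [round(x, 3) for x in empirical])
```

Even though the draft puts half its mass on the token the target likes least, the empirical frequencies land on the target distribution. Only the rejection rate suffers.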

Constrained generation: following the rules

Sometimes you don't just want plausible text -- you need the output to follow a specific format. Valid JSON. A number between 1 and 10. One of three predefined categories. A SQL query. Standard sampling makes no format guarantees whatsoever.

Grammar-guided generation solves this by restricting the model at each step to only produce tokens that are valid according to a formal grammar:

def grammar_constrained_generate(model, tokenizer, prompt,
                                 grammar, max_tokens=200):
    """Generate text that conforms to a formal grammar."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    parser_state = grammar.initial_state()

    for _ in range(max_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]

        # Ask the grammar which tokens are legal right now
        allowed_tokens = grammar.allowed_tokens(parser_state)

        # Mask out everything else
        mask = torch.full_like(logits, float('-inf'))
        mask[0, allowed_tokens] = 0
        constrained_logits = logits + mask

        # Sample from allowed tokens only
        probs = F.softmax(constrained_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Advance the grammar parser
        parser_state = grammar.advance(
            parser_state, next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        if grammar.is_complete(parser_state):
            break

    return tokenizer.decode(input_ids[0])

The model still has creative freedom within the grammar's constraints -- it decides what values to use, how to structure the content, what words to pick. But the structure is guaranteed. Every opening brace gets a closing brace. Every key is a quoted string. Every number is actually a number.
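The masking step itself is tiny. Here is a stripped-down sketch with a hypothetical five-token vocabulary and made-up logits: the model clearly prefers "hello", but the grammar only allows an opening brace at this position, so that is the only token that can be sampled:

```python
import math
import random


def constrained_probs(logits, allowed):
    """Softmax over the allowed token IDs only; everything else gets 0."""
    m = max(logits[i] for i in allowed)
    exps = {i: math.exp(logits[i] - m) for i in allowed}
    total = sum(exps.values())
    return [exps.get(i, 0.0) / total for i in range(len(logits))]


vocab = ["{", "}", "hello", ":", "42"]   # hypothetical tiny vocabulary
logits = [0.1, 0.3, 2.5, 0.2, 1.0]       # the model prefers "hello"
allowed = [0]                            # the grammar only permits "{"

probs = constrained_probs(logits, allowed)
token = random.choices(range(len(vocab)), weights=probs)[0]
print(vocab[token])   # {
```

When several tokens are allowed, the model's own preferences still decide among them. The grammar only removes the illegal options.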

Tools like llama.cpp's GBNF grammars and Outlines (a Python library) implement this efficiently. You define a grammar -- JSON schema, regex pattern, or context-free grammar -- and the engine guarantees conformance.

Back in episode #62 we discussed structured outputs from LLM APIs. Constrained generation is the mechanism that makes that possible under the hood. When you tell an API "respond in JSON matching this schema," the server is running something like the code above (or a more optimized version of it).

Let's implement a simple JSON grammar constraint to show how it works in practice:

class SimpleJsonGrammar:
    """Minimal JSON grammar for demonstration."""

    def __init__(self, required_keys):
        self.required_keys = required_keys
        self.states = [
            "start",       # expecting {
            "key",         # expecting "key_name"
            "colon",       # expecting :
            "value",       # expecting a value
            "comma_or_end" # expecting , or }
        ]

    def initial_state(self):
        return {
            "state": "start",
            "keys_seen": [],
            "depth": 0,
        }

    def allowed_tokens(self, state):
        """Return the grammar terminals valid in the current state."""
        # This is simplified: terminals are plain strings here.
        # Real implementations map the tokenizer's full vocabulary
        # to grammar terminals and return token IDs
        s = state["state"]
        if s == "start":
            return ["{"]
        elif s == "key":
            remaining = [k for k in self.required_keys
                         if k not in state["keys_seen"]]
            return [f'"{k}"' for k in remaining]
        elif s == "colon":
            return [":"]
        elif s == "value":
            return ["string", "number", "true", "false", "null"]
        elif s == "comma_or_end":
            if len(state["keys_seen"]) < len(self.required_keys):
                return [","]
            return ["}"]
        return []

    def advance(self, state, token):
        """Move the grammar state machine forward."""
        new_state = dict(state)
        new_state["keys_seen"] = list(state["keys_seen"])
        s = state["state"]

        if s == "start":
            new_state["state"] = "key"
            new_state["depth"] = 1
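        elif s == "key":
            # Store the bare key name (without quotes) so the
            # "remaining" filter in allowed_tokens() can match it
            new_state["keys_seen"].append(token.strip('"'))
            new_state["state"] = "colon"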
        elif s == "colon":
            new_state["state"] = "value"
        elif s == "value":
            new_state["state"] = "comma_or_end"
        elif s == "comma_or_end":
            if token == ",":
                new_state["state"] = "key"
            else:
                new_state["depth"] = 0
        return new_state

    def is_complete(self, state):
        return state["depth"] == 0 and len(state["keys_seen"]) > 0


# Demo
grammar = SimpleJsonGrammar(
    required_keys=["name", "score", "passed"])
state = grammar.initial_state()
print(f"Start -> allowed: {grammar.allowed_tokens(state)}")
state = grammar.advance(state, "{")
print(f'After {{ -> allowed: {grammar.allowed_tokens(state)}')
state = grammar.advance(state, '"name"')
print(f'After "name" -> allowed: {grammar.allowed_tokens(state)}')

This is a simplified illustration. Real grammar engines operate at the token level (not the string level), handling the messy reality that a single word might span multiple tokens and a single token might be part of multiple grammar productions. But the principle is identical: maintain a parser state, filter the vocabulary at each step, let the model fill in the content.
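
The jump from string level to token level can be sketched with a toy vocabulary. This is a minimal illustration only, not how llama.cpp or Outlines work internally: `TOY_VOCAB` and `allowed_token_ids` are hypothetical names and the vocabulary is invented. The idea is that a token is allowed only if appending its text keeps the output a prefix of at least one legal terminal.

```python
# Hypothetical toy vocabulary: token ID -> token text. Real tokenizers
# have tens of thousands of entries; these pieces are invented.
TOY_VOCAB = {
    0: '"', 1: 'na', 2: 'me', 3: 'name', 4: 'sc', 5: 'ore',
    6: '"name"', 7: ':', 8: '{', 9: 'x',
}


def allowed_token_ids(emitted, terminals, vocab=TOY_VOCAB):
    """Return IDs of tokens that keep `emitted` a prefix of a terminal.

    emitted   -- text already generated for the current terminal
    terminals -- full strings the grammar accepts at this position
    """
    allowed = []
    for token_id, text in vocab.items():
        candidate = emitted + text
        if any(t.startswith(candidate) for t in terminals):
            allowed.append(token_id)
    return sorted(allowed)


terminals = ['"name"', '"score"']

# Nothing emitted yet: the opening quote, or the whole key as one token
print(allowed_token_ids("", terminals))
# After the opening quote: several multi-token paths remain possible
print(allowed_token_ids('"', terminals))
# After '"na': only one piece still leads to a legal terminal
print(allowed_token_ids('"na', terminals))
```

Note how the same key can be reached through different token paths (one token for the whole key, or quote plus pieces). That is the bookkeeping a real grammar engine has to do over the full vocabulary at every step.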

Putting it all together: a complete generation pipeline

Let me show you a class that combines the decoding strategies we've covered (greedy, beam search, and sampling with temperature, top-k, top-p and both penalties) into a single reusable pipeline. This is the kind of code you'd actually use in a project:

import torch
import torch.nn.functional as F
from collections import Counter


class TextGenerator:
    """Complete text generation pipeline with all strategies."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, prompt, max_tokens=200, strategy="sample",
                 temperature=0.7, top_k=50, top_p=0.9,
                 repetition_penalty=1.0, frequency_penalty=0.0,
                 num_beams=4):
        """Generate text with the specified strategy."""
        input_ids = self.tokenizer.encode(
            prompt, return_tensors="pt")

        if strategy == "greedy":
            return self._greedy(input_ids, max_tokens)
        elif strategy == "beam":
            return self._beam(input_ids, max_tokens, num_beams)
        elif strategy == "sample":
            return self._sample(
                input_ids, max_tokens, temperature, top_k,
                top_p, repetition_penalty, frequency_penalty)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def _greedy(self, input_ids, max_tokens):
        for _ in range(max_tokens):
            logits = self._get_logits(input_ids)
            next_token = torch.argmax(logits, dim=-1,
                                      keepdim=True)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token],
                                  dim=-1)
        return self.tokenizer.decode(input_ids[0])

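    def _beam(self, input_ids, max_tokens, num_beams):
        eos_id = self.tokenizer.eos_token_id
        prompt_len = input_ids.size(1)

        def finished(seq):
            # True once a beam has generated the EOS token
            return (seq.size(1) > prompt_len
                    and seq[0, -1].item() == eos_id)

        beams = [(input_ids, 0.0)]
        for _ in range(max_tokens):
            candidates = []
            for seq, score in beams:
                if finished(seq):
                    # Carry finished beams forward unchanged
                    candidates.append((seq, score))
                    continue
                logits = self._get_logits(seq)
                log_probs = F.log_softmax(logits, dim=-1)
                top_lp, top_ids = torch.topk(
                    log_probs, num_beams)
                for i in range(num_beams):
                    new_seq = torch.cat(
                        [seq, top_ids[:, i:i+1]], dim=-1)
                    new_score = score + top_lp[0, i].item()
                    candidates.append((new_seq, new_score))
            candidates.sort(key=lambda x: x[1], reverse=True)
            beams = candidates[:num_beams]
            if all(finished(seq) for seq, _ in beams):
                break
        return self.tokenizer.decode(beams[0][0][0])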

    def _sample(self, input_ids, max_tokens, temperature,
                top_k, top_p, rep_penalty, freq_penalty):
        token_counts = Counter()
        generated_ids = []

        for _ in range(max_tokens):
            logits = self._get_logits(input_ids)

            # Repetition penalty
            if rep_penalty != 1.0 and generated_ids:
                for tid in set(generated_ids):
                    if logits[0, tid] > 0:
                        logits[0, tid] /= rep_penalty
                    else:
                        logits[0, tid] *= rep_penalty

            # Frequency penalty
            if freq_penalty > 0:
                for tid, count in token_counts.items():
                    logits[0, tid] -= freq_penalty * count

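            # Temperature (clamped so temperature=0 does not
            # divide by zero; it then acts almost like greedy)
            scaled = logits / max(temperature, 1e-5)

            # Top-k (capped at the vocabulary size)
            if top_k > 0:
                k = min(top_k, scaled.size(-1))
                kth_val = torch.topk(scaled, k).values[
                    0, -1]
                scaled[scaled < kth_val] = float('-inf')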

            # Top-p
            probs = F.softmax(scaled, dim=-1)
            sorted_p, sorted_i = torch.sort(
                probs, descending=True)
            cumsum = torch.cumsum(sorted_p, dim=-1)
            mask = cumsum - sorted_p > top_p
            sorted_p[mask] = 0.0
            sorted_p = sorted_p / sorted_p.sum()

            idx = torch.multinomial(sorted_p, num_samples=1)
            next_token = sorted_i.gather(-1, idx)

            if next_token.item() == self.tokenizer.eos_token_id:
                break

            tid = next_token.item()
            generated_ids.append(tid)
            token_counts[tid] += 1
            input_ids = torch.cat([input_ids, next_token],
                                  dim=-1)

        return self.tokenizer.decode(input_ids[0])

    def _get_logits(self, input_ids):
        with torch.no_grad():
            return self.model(input_ids).logits[:, -1, :]


# Usage demonstration (requires a loaded model)
# gen = TextGenerator(model, tokenizer)
#
# # Factual output
# gen.generate("Explain gravity:", strategy="greedy")
#
# # Creative output
# gen.generate("Once upon a time,", strategy="sample",
#              temperature=0.9, top_p=0.95)
#
# # Structured output (beam search)
# gen.generate("Translate to French:", strategy="beam",
#              num_beams=5)
#
# # Anti-repetition
# gen.generate("Write a paragraph about AI:",
#              strategy="sample", repetition_penalty=1.3,
#              frequency_penalty=0.5)

# Parameter guide
params = [
    ("temperature", "0.1-0.3", "0.7-0.9", "1.0-1.5"),
    ("top_k",       "10-20",   "40-60",   "80-200"),
    ("top_p",       "0.5-0.7", "0.85-0.95","0.95-1.0"),
    ("rep_penalty", "1.0",     "1.1-1.2",  "1.3-1.5"),
    ("freq_penalty","0.0",     "0.3-0.5",  "0.8-1.0"),
]

print(f"{'Parameter':<15} {'Factual':>12} {'Balanced':>12} "
      f"{'Creative':>12}")
print("-" * 55)
for name, low, mid, high in params:
    print(f"{name:<15} {low:>12} {mid:>12} {high:>12}")
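
To make the temperature and top_p rows of that guide concrete, here is a small pure-Python sketch that needs no model. The logits are made up for illustration, and `softmax_with_temperature` and `nucleus_size` are hypothetical helper names. It shows how a lower temperature sharpens the distribution and shrinks the nucleus that top-p keeps.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Scale logits by 1/temperature, then apply a stable softmax."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus_size(probs, top_p):
    """Count tokens in the smallest set whose cumulative mass >= top_p."""
    cumulative = 0.0
    for count, p in enumerate(sorted(probs, reverse=True), start=1):
        cumulative += p
        if cumulative >= top_p:
            return count
    return len(probs)


# Made-up logits for a 6-token vocabulary
logits = [4.0, 3.0, 2.0, 1.0, 0.0, -1.0]

for temperature in (0.2, 0.7, 1.5):
    probs = softmax_with_temperature(logits, temperature)
    print(f"T={temperature}: top prob={probs[0]:.3f}, "
          f"nucleus size at top_p=0.9: {nucleus_size(probs, 0.9)}")
```

With these logits the nucleus grows as the temperature rises, which is why temperature and top_p are usually tuned together rather than in isolation.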

Key takeaways

  • Greedy decoding always picks the highest-probability token. Fast and deterministic, but produces repetitive, bland output because it never explores alternatives;
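  • beam search maintains multiple candidate sequences and usually finds higher-probability results than greedy decoding, although it is still a heuristic with no guarantee of finding the global optimum -- the standard for translation and summarization, but too boring for creative tasks;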
  • temperature is the master creativity control. Low values (0.1-0.3) for factual precision, high values (0.9-1.5) for creative variation. It scales the logits before softmax, sharpening or flattening the distribution;
  • top-k limits candidates to the k most likely tokens, cutting long-tail noise. But k is fixed, which is suboptimal when model confidence varies;
  • top-p (nucleus sampling) is the adaptive answer: it includes the smallest set of tokens whose cumulative probability reaches p. Generally preferred over top-k because it adapts to the distribution shape;
  • repetition and frequency penalties prevent the model from looping. Repetition penalty applies a flat discount to seen tokens; frequency penalty scales with how often each token appeared -- the latter is usually better behaved;
  • speculative decoding uses a small draft model to propose tokens that a larger model then verifies, typically achieving a 2-3x speedup while the output distribution stays mathematically identical to the larger model's -- the verification is exact, not approximate;
  • constrained generation with formal grammars guarantees structured output (valid JSON, specific formats) while preserving the model's freedom within those constraints. This is what powers the "structured output" features in modern LLM APIs.
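
The "exact, not approximate" point about speculative decoding can be checked by hand on a toy vocabulary. The sketch below assumes the standard speculative sampling rule: a drafted token x is accepted with probability min(1, p(x)/q(x)), and on rejection a token is drawn from the renormalized positive part of p - q. The distributions `p_target` and `q_draft` are invented and `output_distribution` is a hypothetical helper. It computes the resulting distribution in closed form instead of sampling.

```python
def output_distribution(p_target, q_draft):
    """Exact distribution of one token under speculative sampling.

    The draft proposes x ~ q. It is accepted with probability
    min(1, p(x) / q(x)); on rejection a token is drawn from the
    residual distribution max(0, p - q), renormalized.
    Assumes q(x) > 0 for every token.
    """
    accept = [q * min(1.0, p / q) for p, q in zip(p_target, q_draft)]
    reject_mass = 1.0 - sum(accept)

    residual = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    residual_total = sum(residual)
    if residual_total == 0.0:
        # Draft and target agree exactly: nothing is ever rejected
        return accept

    return [a + reject_mass * r / residual_total
            for a, r in zip(accept, residual)]


# Invented distributions over a 4-token vocabulary
p_target = [0.50, 0.30, 0.15, 0.05]   # large model
q_draft = [0.25, 0.25, 0.25, 0.25]    # small draft model

result = output_distribution(p_target, q_draft)
acceptance_rate = sum(min(p, q) for p, q in zip(p_target, q_draft))

for token_id, (p, r) in enumerate(zip(p_target, result)):
    print(f"token {token_id}: target={p:.3f} speculative={r:.3f}")
print(f"acceptance rate: {acceptance_rate:.2f}")
```

The two columns match even though the draft is a poor guess. What a poor draft costs is speed, not quality: the acceptance rate determines how many drafted tokens survive verification.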

Exercises

Exercise 1: Build a temperature visualizer. Create a function visualize_temperature(logits, temperatures) that takes a 1D tensor of logits (at least 10 values) and a list of temperatures (e.g., [0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0]). For each temperature: apply the scaling, compute softmax probabilities, and calculate three metrics: (a) entropy of the distribution (higher = more random), (b) probability of the top token (how dominant the best choice is), (c) effective vocabulary size (number of tokens whose probability is above 0.01, i.e. more than 1% of the probability mass). Print a table with all three metrics for each temperature. Then add a recommend_temperature(task_type) function that returns the recommended temperature range for "factual", "balanced", and "creative" tasks based on entropy thresholds.

Exercise 2: Build a repetition detector and auto-penalizer. Create a class RepetitionGuard that monitors generated text for repetitive patterns. It should track: (a) individual token frequencies, (b) bigram (2-token) frequencies, and (c) the longest repeated substring in the last N tokens. Implement check(generated_ids) that returns a dictionary with {"token_repeats": [...], "bigram_repeats": [...], "longest_repeat": int, "severity": float} where severity is 0.0 (no repetition) to 1.0 (severe). Implement adaptive_penalty(severity) that maps severity to a repetition penalty value (1.0 at severity 0.0, scaling up to 1.5 at severity 1.0). Test with a simulated token sequence that has intentional repetition: generate 100 token IDs where tokens 50-70 are a repeating pattern of [10, 20, 30, 10, 20, 30, ...] and the rest are random. Show that the guard detects the repetition and recommends an appropriate penalty.

Exercise 3: Build a decoding strategy comparator. Create a class DecodingComparator that takes a probability distribution (a tensor of probabilities over a vocabulary) and simulates N samples from it using different strategies: greedy (1 sample, always argmax), pure sampling, top-k (k=10, 20, 50), and top-p (p=0.5, 0.9, 0.95). For each strategy, run 1000 sampling trials and measure: (a) how many unique tokens were selected across all trials, (b) the Shannon entropy of the empirical sample distribution, and (c) whether the top-1 token from the original distribution was selected at least once. Use a synthetic distribution: 1000 tokens where the probabilities follow a Zipf distribution (token i has probability proportional to 1/i). Print a comparison table. Also compute the KL divergence between the original distribution and each strategy's empirical distribution to quantify how much each strategy distorts the model's beliefs.

Thanks, and until next time!

@scipio