ML Bible

ML Bible · Chapter 5

Transformers

Attention Is All You Need: from RNNs and LSTMs through self-attention, the full transformer, modern upgrades (RoPE, GQA, MoE, FlashAttention), training, LoRA, and GANs/VAEs.

In 2017, eight researchers at Google published a paper called "Attention Is All You Need," and machine learning was never the same. The architecture they introduced — the Transformer — is the foundation of GPT, Claude, BERT, Gemini, LLaMA, and basically every modern foundation model. If you want to understand modern ML, this is the thing you have to understand.

So here's the plan. We're going to take this from first principles. We'll start high-level — asking why the transformer had to exist at all — then go deep, building the architecture up one component at a time, with diagrams to make the math visual. Along the way you'll meet every idea that lives inside the transformer: memory, gates, attention, parallelism. None of those ideas were invented in 2017. They were each invented to fix a specific problem with the model that came before, and the transformer is what you get when you keep all the good parts and drop the parts that slowed everything down.

Take it slow. By the end you'll be able to read any modern LLM paper and know exactly which piece it's poking at.

A quick note on what I'm assuming: you already know what a neural network is and roughly how a GPU works. Everything else, I'll explain as we go.

Let's get into it.

History

For a long time, the models we used for language were good at recognizing fixed patterns but fell apart the moment you handed them a sequence — and sequences are exactly what separate human language from a pile of features. These early models had no memory of order or context. Ring a bell? It should, because every model we're about to walk through is one more attempt to fix that single problem, and the transformer is where it finally gets fixed properly.

Here's the lineage we're going to trace, in order:

Feedforward → RNN → LSTM/GRU → Seq2Seq → Attention → Transformer.

Every arrow in that chain is a specific limitation that the next model was invented to solve. Hold that framing in your head the whole way through — it's the entire story, and it's what makes the transformer feel inevitable rather than magical.

1. Feedforward Neural Networks (1950s–1980s)

These are the familiar networks — the kind you already know. We'll use them as our starting line.

A feedforward network (FFN) is the most basic neural network there is. Information flows in exactly one direction: input → hidden layers → output. No loops, no memory, no notion of time. Each layer is a matrix multiplication followed by a nonlinearity:

$$h = \sigma(Wx + b)$$

Let's read every symbol in that:

  • $x$ — the input vector you feed in.
  • $W$ — the weight matrix for the layer. This is what the network learns.
  • $b$ — the bias vector, a learnable offset.
  • $\sigma$ — a nonlinear activation function (something like ReLU or sigmoid) applied element by element.
  • $h$ — the resulting hidden representation that gets passed to the next layer.

You feed in a fixed-size input vector $x$, the network runs it through a series of these transformations, and out comes a fixed-size output vector. With enough hidden units, it can approximate basically any continuous function from inputs to outputs. The universal approximation theorem guarantees that suitable weights exist, though not that training will find them.
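To make the layer equation concrete, here is a minimal pure-Python sketch of two stacked layers computing $h = \sigma(Wx + b)$, with ReLU as $\sigma$. The weights are made-up numbers, and the helper names (`dense_layer`, `relu`, `identity`) are illustrative assumptions. A real network would learn $W$ and $b$ and would normally use a library such as NumPy or PyTorch.

```python
def dense_layer(x, W, b, activation):
    """Compute activation(W x + b) for one layer.

    x: list of n_in numbers; W: list of n_out rows, each n_in long; b: list of n_out numbers.
    """
    # The input width is baked into W: a differently sized x simply does not fit.
    assert len(x) == len(W[0]), "input size is fixed by the weight matrix"
    return [activation(sum(w * xj for w, xj in zip(row, x)) + bi)
            for row, bi in zip(W, b)]


def relu(z):
    return max(0.0, z)


def identity(z):
    return z


# A tiny 4 -> 3 -> 2 network with made-up weights (purely illustrative).
W1 = [[0.5, -0.2, 0.1, 0.0],
      [0.3, 0.8, -0.5, 0.2],
      [-0.6, 0.1, 0.4, 0.9]]
b1 = [0.0, 0.1, -0.1]
W2 = [[1.0, -1.0, 0.5],
      [0.2, 0.3, -0.4]]
b2 = [0.0, 0.0]

x = [1.0, 2.0, 3.0, 4.0]               # fixed-size input: exactly 4 numbers
h1 = dense_layer(x, W1, b1, relu)       # hidden layer with a ReLU nonlinearity
y = dense_layer(h1, W2, b2, identity)   # linear output layer
print(y)                                # approximately [1.25, -1.25]
```

Passing a 5-element list to `dense_layer(..., W1, ...)` trips the size check. That is the fixed-input-size limitation in miniature.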

For images, you'd flatten the pixels into one long vector. For tabular data, each column becomes an input dimension. For text, you'd... well, that's where the trouble started.

Why they were a breakthrough. In the 1980s, feedforward networks — trained with backpropagation, popularized by Rumelhart, Hinton, and Williams in 1986 — showed that neural networks could learn complicated functions automatically, straight from data, with nobody hand-coding the rules. They worked great on problems with fixed-size inputs and outputs: digit recognition, simple classification, regression.

How they handled sequences. Badly. And here's the root of it: a feedforward network has a fixed input size. To process "the cat sat," you'd have to pick a window size up front — say five words — and feed in five word embeddings glued together.

That choice immediately boxes you in three different ways:

First, fixed length. A five-word window can't handle a six-word sentence, let alone a paragraph. You're stuck truncating or padding everything to fit.

Second, nothing is shared across positions. The network treats "word at position 1" and "word at position 2" as completely separate features, with no shared understanding that they're the same kind of thing showing up in different spots. Learning grammar this way is painfully inefficient: the network has to relearn what a verb is at every position separately.

Third, bag-of-words behavior. In a lot of setups, researchers used simpler representations like averaging the word vectors together. Do that and the network literally can't tell "dog bites man" from "man bites dog." Order just evaporates.
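Averaging's order-blindness is easy to demonstrate. This sketch uses made-up 2-dimensional word vectors, an assumption purely for illustration. It compares averaging with concatenation, which gives each position its own slots.

```python
# Made-up 2-d word vectors, purely for illustration.
vectors = {
    "dog": [1.0, 0.0],
    "bites": [0.0, 1.0],
    "man": [0.5, 0.5],
}


def averaged(sentence):
    """Average the word vectors: every position is blended into one vector."""
    words = sentence.split()
    dim = len(vectors[words[0]])
    return [sum(vectors[w][k] for w in words) / len(words) for k in range(dim)]


def concatenated(sentence):
    """Glue the word vectors end to end: position k always occupies the same slots."""
    return [value for w in sentence.split() for value in vectors[w]]


print(averaged("dog bites man") == averaged("man bites dog"))          # True: order is gone
print(concatenated("dog bites man") == concatenated("man bites dog"))  # False: order survives
```

Concatenation keeps the order, but it brings back the fixed-window problem: its output length grows with the number of words.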

For anything sequential — language, audio, time series — feedforward networks were a dead end. You couldn't even decide what the right input format was supposed to be. This is what pushed the whole field toward architectures that had memory and some awareness of order.

Fig 5.1 — Feedforward Neural Network. Inputs x1 to x4 pass through hidden layers 1 and 2 to outputs y1 and y2, each layer computing h = σ(Wx + b). Information flows one way. No memory, fixed input size.

Check your understanding

Why can't a plain feedforward network handle a sentence of arbitrary length?


Because its input size is fixed at build time. You have to commit to a window (say five words) up front, so a six-word sentence won't fit and a three-word one has to be padded. On top of that, it treats each position as an unrelated feature, so it has no built-in notion that a word at position 2 is "the same kind of thing" as a word at position 1 — which makes learning order and grammar wildly inefficient.

2. Recurrent Neural Networks (1980s)

As the name suggests, RNNs are recurrent: they feed their own hidden state back in as input at the next step. The network keeps a hidden state $h$ that carries across timesteps. At each step $t$:

$$h_t = \tanh(W_x x_t + W_h h_{t-1} + b)$$
$$y_t = W_y h_t$$

Reading the symbols:

  • $x_t$ — the input at the current timestep (e.g., the current word).
  • $h_{t-1}$ — the hidden state from the previous step. This is the memory.
  • $h_t$ — the new hidden state, a running summary of everything seen so far.
  • $W_x$ — weights applied to the current input.
  • $W_h$ — weights applied to the previous hidden state.
  • $W_y$ — weights that turn the hidden state into an output $y_t$.
  • $b$ — a bias term; $\tanh$ is the squashing nonlinearity.

The hidden state is a "running summary" of everything the model has read up to now. And here's the most important part of the whole setup: the same weights are reused at every single timestep. There's only one $W_x$, one $W_h$, one $W_y$, no matter how long the sequence is. That weight-sharing is exactly what lets an RNN handle sequences of any length — you just keep applying the same transformation as new inputs roll in.

Let's walk an example. To process "the cat sat," you'd:

1. Feed in "the," get $h_1$.

2. Feed in "cat" along with $h_1$, get $h_2$.

3. Feed in "sat" along with $h_2$, get $h_3$.

By the end, $h_3$ is supposed to be a summary of the entire sequence.
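The three steps above can be sketched in plain Python. The 2-dimensional word vectors and all weight values are invented for illustration, and the bias is zero. The point is that a single set of $W_x$, $W_h$, $W_y$, $b$ is reused at every step, so the same function handles a sequence of any length.

```python
import math


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rnn_forward(inputs, W_x, W_h, W_y, b):
    """Apply h_t = tanh(W_x x_t + W_h h_{t-1} + b) and y_t = W_y h_t to every input in turn."""
    h = [0.0] * len(b)                      # h_0: start from an all-zero memory
    hidden_states, outputs = [], []
    for x_t in inputs:                      # strictly sequential: step t needs h_{t-1}
        pre = [a + c + bias for a, c, bias in zip(matvec(W_x, x_t), matvec(W_h, h), b)]
        h = [math.tanh(p) for p in pre]
        hidden_states.append(h)
        outputs.append(matvec(W_y, h))
    return hidden_states, outputs


# Invented 2-d "embeddings" and weights; the same weights serve every timestep.
embed = {"the": [1.0, 0.0], "cat": [0.0, 1.0], "sat": [1.0, 1.0]}
W_x = [[0.5, -0.3], [0.2, 0.4]]
W_h = [[0.1, 0.0], [0.0, 0.1]]
W_y = [[1.0, 1.0]]
b = [0.0, 0.0]

hs, ys = rnn_forward([embed[w] for w in "the cat sat".split()], W_x, W_h, W_y, b)
print(len(hs), hs[-1])   # 3 hidden states; hs[-1] is h_3, the summary of the whole sequence
```

Calling `rnn_forward` on a 10-word input needs no change to the weights. Only the loop runs longer.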

This was a genuine leap. RNNs gave neural networks memory for the first time. In principle the hidden state could carry information from arbitrarily far back. They handled variable-length sequences naturally. And they produced an output at every timestep, so they could do part-of-speech tagging, language modeling (predict the next word), or translation. For roughly two decades, RNNs were the standard architecture for any sequence problem in deep learning.

Here's the unrolled view — the same cell repeated across time:

Fig 5.2 — Same weights reused at every step — that's what handles any length. One RNN cell (shared Wx, Wh, Wy) is unrolled over t = 1, 2, 3 on "the," "cat," "sat," producing hidden states h1, h2, h3 (each a running summary of everything seen so far) and outputs y1, y2, y3.

So why did RNNs fade out? One word: gradients. In theory the hidden state could carry information arbitrarily far. In practice it couldn't, and there were two compounding reasons plus a third structural one.

Vanishing gradients. To train an RNN you backpropagate the gradient through every timestep; this is called backpropagation through time. The gradient at step 1, coming from a loss at step 50, gets multiplied by the recurrent weight matrix (together with the derivative of the nonlinearity) 49 times on the way back. If those multiplications shrink the signal even a little, the gradient shrinks exponentially toward zero. Shrinking is the default with sigmoid or tanh, because their derivatives are at most 0.25 and 1 respectively. By the time the gradient reaches the early steps it's a rounding error. The early timesteps simply can't learn from mistakes made later on.

Exploding gradients. The mirror image. If the recurrent weights are a touch too large, the gradient grows exponentially through the backward pass and the loss blows up to NaN (Not a Number). It's less common than vanishing but harder to ignore when it hits. The standard fix became gradient clipping — capping the gradient's norm so it can't run away.

Compressed state. Even if the gradient flowed perfectly, the entire history has to be crammed into one hidden-state vector — a few hundred numbers. There just isn't enough room to remember everything, so new inputs end up overwriting old information.
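The two gradient failure modes can be illustrated numerically. This sketch is a heavy simplification: it replaces the recurrent Jacobian with one scalar factor per step. `clip_by_norm` is a hand-written version of norm clipping, not any library's API.

```python
import math


def backprop_scale(factor, steps):
    """Gradient size after `steps` multiplications by the same scalar factor.

    A scalar stand-in for repeatedly multiplying by the recurrent Jacobian.
    """
    grad = 1.0
    for _ in range(steps):
        grad *= factor
    return grad


def clip_by_norm(grad, max_norm):
    """Rescale a gradient vector so its L2 norm is at most max_norm, keeping its direction."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    return [g * (max_norm / norm) for g in grad]


print(backprop_scale(0.9, 49))                   # about 0.0057: the signal vanishes
print(backprop_scale(1.1, 49))                   # about 107: the signal explodes
print(clip_by_norm([30.0, 40.0], max_norm=5.0))  # norm 50 rescaled to 5: [3.0, 4.0]
```

With a factor of 0.9 the gradient shrinks roughly 175-fold over 49 steps; with 1.1 it grows roughly 107-fold. Clipping caps the explosion without changing the gradient's direction, but it does nothing about the vanishing case.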

In practice, RNNs reliably learned dependencies of about 5–10 tokens. Past that, they degraded. For language that's brutal. Understanding a sentence like "the dog that the cat that the rat bit chased ran" means linking words that sit far apart, and a plain RNN simply couldn't hold the thread that long. This is what sent everyone searching for an architecture with better memory.

Check your understanding

The vanishing-gradient problem keeps coming back in this guide. In an RNN, what specifically causes it?


Backpropagation through time multiplies the gradient by the recurrent weight matrix once per timestep on the way back. With tanh/sigmoid nonlinearities those repeated multiplications tend to shrink the signal, so over many steps the gradient decays exponentially toward zero. The early timesteps therefore receive almost no learning signal from later losses. Keep this villain in mind — LSTMs, scaling in attention, residual connections, and LayerNorm are all partly about beating it.

3. LSTM (1997) and GRU (2014)

An LSTM (Long Short-Term Memory network) was designed to fix that exact vanishing-gradient problem. The core difference from a plain RNN is that it gives the network a separate memory cell that information can flow through with almost no interference, controlled by learnable gates.

An LSTM cell keeps two pieces of state at each step, not one:

  • The hidden state $h_t$ — like an RNN's.
  • The cell state $C_t$ — a separate "memory highway" running straight through time.

And it has three gates deciding what happens to that cell state at each step:

1. Forget gate $f_t$ — decides what to erase from the cell state. It's a sigmoid, so it outputs values between 0 (forget completely) and 1 (keep entirely).

2. Input gate $i_t$ — decides what new information to write into the cell state.

3. Output gate $o_t$ — decides what to read out of the cell state to form the hidden state.

Here are the equations:

$$f_t = \sigma(W_f [h_{t-1}, x_t])$$
$$i_t = \sigma(W_i [h_{t-1}, x_t])$$
$$\tilde{C}_t = \tanh(W_C [h_{t-1}, x_t])$$
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$
$$o_t = \sigma(W_o [h_{t-1}, x_t])$$
$$h_t = o_t \odot \tanh(C_t)$$

Let's name every symbol:

  • $[h_{t-1}, x_t]$ — the previous hidden state and current input, concatenated into one vector.
  • $W_f, W_i, W_C, W_o$ — learnable weight matrices for the forget gate, input gate, candidate, and output gate respectively.
  • $\sigma$ — the sigmoid function, squashing to $(0, 1)$ — perfect for a gate, since it acts like a soft on/off dial.
  • $\tilde{C}_t$ — the candidate new memory, the fresh information that might get written in.
  • $\odot$ — element-wise multiplication (the gates act like dimmer switches on each dimension).
  • $C_t$ — the updated cell state; $h_t$ — the updated hidden state.

The line that does all the work is the cell-state update:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

Look at what it's doing. The old cell state $C_{t-1}$ is multiplied by the forget gate (which can sit near 1, leaving memory basically untouched) and then added to. There's no repeated matrix multiplication grinding the signal down. That additive path is why gradients can flow back across many timesteps without vanishing. This is the famous "memory highway": when the forget gate is open and the input gate is closed, information from far in the past slides through unchanged.
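The equations translate almost line for line into code. This minimal sketch follows them exactly, so, like the text, it omits bias terms; real implementations usually include one per gate. The hand-picked weights are a hypothetical setting that drives the forget gate toward 1 and the input gate toward 0. With that setting, the memory highway holds a stored value across 50 steps.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_C, W_o):
    """One LSTM step: returns (h_t, C_t). Biases omitted to match the equations."""
    z = h_prev + x_t                                   # [h_{t-1}, x_t] via list concatenation
    f = [sigmoid(v) for v in matvec(W_f, z)]           # forget gate
    i = [sigmoid(v) for v in matvec(W_i, z)]           # input gate
    c_tilde = [math.tanh(v) for v in matvec(W_C, z)]   # candidate memory
    o = [sigmoid(v) for v in matvec(W_o, z)]           # output gate
    # The additive update: C_t = f_t * C_{t-1} + i_t * C~_t, element by element.
    c = [fk * ck + ik * ctk for fk, ck, ik, ctk in zip(f, c_prev, i, c_tilde)]
    h = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    return h, c


# Hidden size 1, input size 1; each weight row acts on [h_{t-1}, x_t].
W_f = [[0.0, 10.0]]    # with x_t = 1, f_t = sigmoid(10), about 0.99995
W_i = [[0.0, -10.0]]   # with x_t = 1, i_t = sigmoid(-10), about 0.00005
W_C = [[0.0, 1.0]]
W_o = [[0.0, 1.0]]

h, c = [0.0], [2.0]    # start with the value 2.0 stored in the cell
for _ in range(50):
    h, c = lstm_step([1.0], h, c, W_f, W_i, W_C, W_o)
print(c)               # still very close to [2.0] after 50 steps
```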

Here's the cell, gate by gate:

Fig 5.3 — The additive cell-state path is what lets gradients survive across many steps. The diagram shows the memory highway carrying C(t-1) to C_t, plus the forget gate f_t, the input gate i_t with candidate C~_t, and the output gate o_t, all computed from concat [h(t-1), x_t], with C_t = f_t · C(t-1) + i_t · C~_t.

The Gated Recurrent Unit (GRU)

The GRU is a simpler take on the same idea. It merges the cell state and hidden state into one, and uses two gates instead of three:

1. Update gate $z_t$ — combines the jobs of the input and forget gates. It decides how much to update versus how much to preserve.

2. Reset gate $r_t$ — controls how much of the past to use when computing the new candidate state.

$$z_t = \sigma(W_z [h_{t-1}, x_t])$$
$$r_t = \sigma(W_r [h_{t-1}, x_t])$$
$$\tilde{h}_t = \tanh(W [r_t \odot h_{t-1}, x_t])$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

The symbols echo the LSTM: $z_t$ and $r_t$ are the gates, $W_z$, $W_r$, $W$ are their learnable matrices, $\tilde{h}_t$ is the candidate hidden state, and $\odot$ is again element-wise multiplication. Notice the final line is another convex blend, $(1 - z_t)$ of the old state plus $z_t$ of the new candidate, which keeps that same gradient-friendly additive flavor with fewer moving parts. Fewer gates means fewer parameters and faster training, often at basically the same quality.
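A matching sketch for the GRU, again following the equations without biases. The weights are hypothetical hand-picked values. They show what the update gate does: near 0, the state is kept almost unchanged for 50 steps; near 1, the state is replaced by the candidate.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def gru_step(x_t, h_prev, W_z, W_r, W):
    """One GRU step following the equations above (biases omitted, as in the text)."""
    concat = h_prev + x_t                                  # [h_{t-1}, x_t]
    z = [sigmoid(v) for v in matvec(W_z, concat)]          # update gate
    r = [sigmoid(v) for v in matvec(W_r, concat)]          # reset gate
    reset_h = [rk * hk for rk, hk in zip(r, h_prev)]       # r_t * h_{t-1}, element-wise
    h_tilde = [math.tanh(v) for v in matvec(W, reset_h + x_t)]  # candidate state
    # Convex blend: (1 - z_t) of the old state plus z_t of the candidate.
    return [(1.0 - zk) * hk + zk * htk for zk, hk, htk in zip(z, h_prev, h_tilde)]


# Hidden size 1, input size 1; hypothetical weights pick the gate behaviour.
keep = [[0.0, -10.0]]   # with x_t = 1, z_t is about 0.00005: keep the old state
take = [[0.0, 10.0]]    # with x_t = 1, z_t is about 0.99995: take the candidate
W_r = [[0.0, 0.0]]      # reset gate fixed at sigmoid(0) = 0.5
W = [[1.0, -1.0]]       # candidate = tanh(r_t * h_{t-1} - x_t) with these weights

h_keep, h_take = [0.8], [0.8]
for _ in range(50):
    h_keep = gru_step([1.0], h_keep, keep, W_r, W)
    h_take = gru_step([1.0], h_take, take, W_r, W)
print(h_keep, h_take)   # h_keep stays near 0.8; h_take has swung to about -0.9
```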

Check your understanding

What single design choice lets an LSTM carry information across many timesteps where a plain RNN can't?


The separate cell state with its additive update, $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$. Because the old memory is gated (multiplied by something near 1) and then added to — rather than pushed through a fresh matrix multiply every step — there's a near-uninterrupted path for both information and gradients to travel down. That's the "memory highway."

4. The limitations that finally retired recurrence

LSTMs and GRUs were a huge step up. But even at their best, two problems stuck around — and a third one in translation setups turned out to be the spark for everything that followed.

Sequential training is slow. Computing $h_t$ requires $h_{t-1}$, which requires $h_{t-2}$, all the way back to the start. You cannot parallelize across time. A GPU can spread the work inside a single timestep across its cores (over the batch and the hidden units), but it can never start step $t$ before step $t-1$ has finished. For a 1,000-token sequence, that's 1,000 steps that have to happen one after another, each too small to keep thousands of cores busy. Modern GPUs scream through big dense matrix multiplications, but an LSTM can't feed them that kind of work. Training a large LSTM on long sequences could take days where a comparable transformer might finish in hours.

Long-range dependencies still degrade. LSTMs were far better than plain RNNs, but they still struggled past a few hundred tokens. The cell state has finite capacity, so over a long sequence the useful early information gets overwritten or muddied. Empirically, performance on long-range tasks just plateaued.

The information squeeze. This one was especially painful in encoder–decoder LSTM setups for translation. The encoder had to compress the entire source sentence into one fixed-size vector before the decoder could even start generating. For a 30-word sentence, maybe that vector could hold it all. For a 100-word paragraph, no chance. Everything had to pass through one narrow choke point, and detail got crushed.

That last problem is the one attention was invented to solve. Which brings us to Seq2Seq.

Check your understanding

Why can't you speed up LSTM training by throwing more GPU cores at a single sequence?


Because the computation is inherently sequential: step $t$ needs the hidden state from step $t-1$, which needs step $t-2$, and so on. There's no way to compute step 500 before step 499 exists, so the timesteps can't run in parallel. More cores don't help when the work forms a strict chain. The transformer's big structural win is removing this chain so the whole sequence can be processed at once.

5. Seq2Seq, and the birth of attention

The Seq2Seq model, built by Ilya Sutskever and colleagues, used two LSTMs: an encoder that reads the input sentence and produces a single final hidden state, and a decoder LSTM that gets initialized with that state and generates the output one token at a time.

The flow looked like this:

Fig 5.4 — Everything the decoder knows about the input has to fit in one vector. That's the squeeze. The encoder LSTM reads "I like cats" token by token and compresses it into one fixed vector h_enc, which the decoder LSTM starts from when generating the output. A short input like this one fits comfortably.

It worked well for tasks like translation. But the architecture had one glaring weak point: every drop of information from the input had to flow through that single fixed-size vector $h_{enc}$. For short sentences, fine. For paragraphs, the model had forgotten the beginning of the input by the time it was generating the end of the output.

This is where attention enters the story — first as an add-on to Seq2Seq, introduced by Bahdanau and colleagues in 2014. The idea was simple and, in hindsight, enormous: don't force the decoder to lean on one summary vector. Instead, let it look back at every encoder hidden state and decide, at each generation step, which ones are relevant right now.

Here's the algorithm:

1. The decoder's current hidden state $s_{i-1}$ acts as a query.

2. Each encoder hidden state $h_j$ acts as a key/value.

3. Compute a compatibility score: how relevant is encoder state $h_j$ to the current decoder state $s_{i-1}$?

4. Softmax those scores into weights $\alpha_{ij}$ that sum to 1.

5. Compute a context vector $c_i = \sum_j \alpha_{ij} \, h_j$, a weighted sum of the encoder states.

6. Use $c_i$ alongside the decoder hidden state to generate the next token.

Let's name those symbols, because they'll come back in a big way:

  • $s_{i-1}$ — the decoder's hidden state at the previous output step; the thing "asking the question."
  • $h_j$ — the $j$-th encoder hidden state; one per input word.
  • $\alpha_{ij}$ — the attention weight: how much output step $i$ should focus on input word $j$. The row sums to 1.
  • $c_i$ — the context vector for output step $i$; a custom-built summary tilted toward whatever's relevant right now.

Take a sec to let that sink in, because here's the punchline: this is the exact same mechanism that becomes the centerpiece of the transformer. Modern self-attention is just this idea, generalized — instead of the decoder attending to the encoder, every token attends to every other token. Same query/key/value skeleton, same softmax-weighted sum.

One historical detail. In Bahdanau attention, the compatibility score was computed by a small feedforward network — different from the dot product the transformer would later use. But conceptually, this was the birth of attention.
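Here is a minimal sketch of the six-step algorithm with a feedforward (additive) scorer. It assumes the score form $e_{ij} = v_a^\top \tanh(W_a s_{i-1} + U_a h_j)$, the additive score usually associated with Bahdanau attention. The parameter names $W_a$, $U_a$, $v_a$ and all the numbers are illustrative, not trained values.

```python
import math


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def softmax(scores):
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def additive_attention(s_prev, encoder_states, W_a, U_a, v_a):
    """Return (alphas, context) for one decoder step.

    score e_j = v_a . tanh(W_a s_prev + U_a h_j); alphas = softmax(e); context = sum_j alpha_j h_j
    """
    query_part = matvec(W_a, s_prev)         # computed once, shared by every h_j
    scores = []
    for h_j in encoder_states:
        hidden = [math.tanh(q + k) for q, k in zip(query_part, matvec(U_a, h_j))]
        scores.append(sum(v * u for v, u in zip(v_a, hidden)))
    alphas = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(a * h_j[d] for a, h_j in zip(alphas, encoder_states)) for d in range(dim)]
    return alphas, context


# Five made-up 2-d encoder states (one per input word) and a made-up decoder state.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 0.0], [0.0, -1.0]]
s_prev = [1.0, 0.0]
W_a = [[1.0, 0.0], [0.0, 1.0]]
U_a = [[1.0, 0.0], [0.0, 1.0]]
v_a = [1.0, 1.0]

alphas, context = additive_attention(s_prev, encoder_states, W_a, U_a, v_a)
print([round(a, 3) for a in alphas])   # one weight per input word, summing to 1
print(context)                         # the custom summary c_i for this step
```

Calling `additive_attention` again with the next decoder state produces a fresh set of weights and a fresh context. This is how each output step gets its own summary.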

Here's that original attention mechanism laid out:

Fig 5.5 — The decoder builds a custom summary each step instead of reusing one frozen vector. The decoder query s_(i-1) scores the encoder hidden states h1 to h5 ("the," "cat," "sat," "on," "mat"). In the pictured step the attention weights are 0.62, 0.18, 0.09, 0.06, and 0.05 (summing to 1), and the context c_i = Σ a_ij h_j, a weighted blend of h1 to h5, feeds the decoder.

What did attention buy us? It demolished the squeeze. Translation quality jumped, especially on long sentences. Suddenly the decoder could zero in on whichever input word mattered most at each output step — which is, not coincidentally, exactly how a human translator works.

And the lesson reached past translation. It showed that direct token-to-token interaction across the sequence, weighted by learned attention, was a more powerful idea than threading everything through a recurrent hidden state. The model was no longer limited by what it could squeeze into one vector — it could pull from anywhere it needed.

Check your understanding

In Bahdanau attention, what plays the role of the "query," and what do the attention weights $\alpha_{ij}$ actually represent?


The decoder's current hidden state $s_{i-1}$ is the query — it's what's asking "which input words matter for what I'm about to generate?" Each $\alpha_{ij}$ is the weight on input word $j$ for output step $i$; the weights for a given output step are softmaxed so they form a distribution summing to 1, and the context vector is the $\alpha$-weighted sum of encoder states. This query/key/value-and-softmax pattern is exactly what self-attention generalizes.

6. The leap: "Attention Is All You Need"

For a few years, the architecture of choice was "LSTM + attention." It worked — but the LSTM part was still slow and stubbornly sequential. Every step waited on the previous step. The attention was the good part; the recurrence was the part dragging everything down.

Think about what that meant. Even with attention bolted on, you still couldn't parallelize across timesteps. You still couldn't truly feed a GPU the dense work it loves. And you were now doing more total computation — both the recurrent step and the attention step — at every position.

So Vaswani and his coauthors asked the obvious question in 2017: what if we keep only the attention? What if we throw out the LSTMs entirely and let attention do all the work? If attention is the part actually solving the long-range problem, why are we still paying for recurrence at all?

The answer was "Attention Is All You Need," and the transformer was born. The very mechanism that started life as a helper for LSTMs became the entire architecture.

Check your understanding

By 2017, attention was already working well as an add-on to LSTMs. What was the key realization that produced the transformer?


That the recurrence was no longer pulling its weight. Attention was doing the heavy lifting on long-range dependencies, while the LSTM backbone was forcing sequential, un-parallelizable computation and extra work per step. The transformer's move was to drop recurrence entirely and let attention handle everything — which unlocked full GPU parallelism across the sequence.

The Transformer (Attention Is All You Need)

Let's start by looking at the transformer as a black box, then crack it open and study the system underneath.

At the highest level, a transformer is an architecture that takes a sequence in and produces a sequence out. The classic example is translation — an English sentence in, a French sentence out.

Why did this one architecture kick off the entire AI revolution? Because it fixed, all at once, everything the earlier models kept tripping over. The models we just walked through struggled to remember context over long paragraphs — when they tried, they either forgot too fast (small effective memory) or went unstable. They were hard to scale because they couldn't take advantage of GPUs, and all that compression crushed the detail out of long inputs. The transformer flips every one of those:

  • it lets the model attend to all the words at once, through self-attention;
  • it trains far faster by processing the whole sequence in parallel;
  • it preserves word order with positional encodings;
  • and it scales beautifully — stack more layers, add more data, and it keeps improving.

Internally, the transformer splits into two halves: an encoder that processes the input, and a decoder that generates the output. In the original paper, each half is a stack of 6 identical layers. Both halves are built from the same kit of components — they just use them a little differently.

A few things to lock in from the big picture. The encoder processes the entire input in parallel, producing a rich representation of every input token in context. The decoder generates the output one token at a time, and each step looks at (1) what the decoder has already produced and (2) the full encoder output, via cross-attention. That one-token-at-a-time pattern is autoregressive generation, and every modern LLM still works this way. (Quick definition: an autoregressive model predicts the next value in a sequence from the previous values in that same sequence.)

Also worth flagging early, since it reframes everything that follows: modern LLMs like GPT are technically decoder-only transformers — they keep the decoder stack and drop the encoder, because for pure text generation you don't need a separate "input" to encode. BERT is encoder-only for the opposite reason — it builds representations of text but never needs to generate. The original encoder–decoder design is most natural for input-to-output tasks like translation. We'll come back to all three.

Inside these blocks, five core mechanisms work together: attention (in four variants), feed-forward networks, layer normalization, positional encoding (added to the initial input embeddings), and residual connections. We'll cover every one.

Here's the whole thing, assembled — the full encoder–decoder transformer:

Fig 5.6 — Encoder understands the input in parallel; decoder generates the output one token at a time. Encoder (N = 6 layers): input embedding plus positional encoding, then multi-head self-attention, Add & Norm, feed-forward network, Add & Norm. Decoder (N = 6 layers): output embedding (shifted right) plus positional encoding, then masked multi-head self-attention, Add & Norm, cross-attention (queries from the decoder, keys and values from the encoder: softmax(Q_dec K_encᵀ/√d_k) V_enc), Add & Norm, feed-forward network, Add & Norm. Cross-attention is where the two stacks meet. A final linear layer and softmax produce the output probabilities.

Before we get into the attention mechanisms themselves, there's some housekeeping to do. We need to talk about how raw text even becomes something a transformer can chew on. That's tokenization and embeddings.

Housekeeping: tokenization and embeddings

Tokenization. Computers speak in numbers; humans speak in words. Tokenization is the bridge: we break raw text into smaller, manageable units called tokens. A token might be a whole word, a piece of a word ("token" + "ization"), or even a single character, depending on the scheme. Common approaches include word-level (split on spaces — simple but the vocabulary explodes and rare words break it), character-level (tiny vocabulary, but sequences get very long and meaning is thin per token), and the modern workhorse, subword tokenization like Byte-Pair Encoding (BPE) or WordPiece, which strikes a balance: frequent words stay whole, rare words split into reusable pieces, and you never hit a word you literally can't represent.
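To see how BPE builds its vocabulary, here is a toy version of the merge-learning loop: start from characters, then repeatedly merge the most frequent adjacent pair. It is a simplification. It has no end-of-word marker and no byte-level fallback, and it breaks ties by comparing the pairs themselves. Production tokenizers will differ in these details.

```python
from collections import Counter


def merge_pair(symbols, pair):
    """Replace every adjacent occurrence of `pair` in a symbol sequence with one merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)


def learn_bpe(words, num_merges):
    """Learn up to num_merges merge rules from a list of training words."""
    corpus = Counter(tuple(w) for w in words)      # each word starts as a tuple of characters
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break                                   # every word is already a single symbol
        # Most frequent pair; ties broken by comparing the pairs so the result is deterministic.
        best = max(pairs, key=lambda p: (pairs[p], p))
        merges.append(best)
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            new_corpus[merge_pair(symbols, best)] += freq
        corpus = new_corpus
    return merges


def apply_bpe(word, merges):
    """Tokenize a word by replaying the learned merges in the order they were learned."""
    symbols = tuple(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)


words = ["low"] * 5 + ["lower"] * 2 + ["lowest"] * 3
merges = learn_bpe(words, num_merges=4)
print(merges)                       # [('o', 'w'), ('l', 'ow'), ('low', 'e'), ('s', 't')]
print(apply_bpe("lowest", merges))  # ['lowe', 'st']
print(apply_bpe("slow", merges))    # unseen word, built from known pieces: ['s', 'low']
```

The frequent word "low" ends up as a single token, while the unseen "slow" still tokenizes cleanly into reusable pieces.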

Fig 5.7 — Subword tokenization balances vocabulary size against sequence length. For "Transformers are powerful!", subword (BPE) tokenization gives 6 tokens: Transform (9176), ers (1010), are (47052), power (22957), ful (2057), ! (133). Word-level tokenization gives 4 tokens; character-level gives 24.

Token embeddings. Once text is split into tokens, each token is mapped to a unique token ID from a predefined vocabulary. But transformers don't compute on raw integers — they work with vectors. This is where embeddings come in. An embedding is a numerical representation of an object (like a word) that turns high-dimensional, sparse data into a dense, lower-dimensional vector living in a continuous space — the embedding space — where semantically similar items sit close together.

The machinery is an embedding matrix: a big table with $V$ rows and $d$ columns, where:

  • $V$ — the vocabulary size, the total number of unique tokens the model knows.
  • $d$ — the embedding dimension, the length of the vector representing each token (you can think of it as the number of learned "features" per token). It's a hyperparameter you choose.

Each row of this matrix is a trainable vector for one token. For example: Transform → token ID 1231 → [0.1, 0.3, 0.4, 0.9, 0.8]. Once every input ID is swapped for its embedding, the entire input sentence becomes a 2D tensor of shape (number of tokens, $d$).
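The lookup itself is just row indexing. This sketch builds a small random $V \times d$ table and embeds a few token IDs. The sizes, the initialization scale (a Gaussian with standard deviation 0.02), and the IDs are assumptions for illustration; the IDs are 1231, reused from the example above, plus an arbitrary 17.

```python
import random


def make_embedding_matrix(vocab_size, dim, seed=0):
    """Build a V x d table of vectors, randomly initialized (training would adjust these)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(vocab_size)]


def embed(token_ids, table):
    """Swap each token ID for its row: the result has shape (number of tokens, d)."""
    return [table[token_id] for token_id in token_ids]


V, d = 2000, 5                      # small hypothetical sizes; real models use far larger ones
E = make_embedding_matrix(V, d)
token_ids = [1231, 17, 1231]        # hypothetical IDs; a real tokenizer assigns these
X = embed(token_ids, E)
print(len(X), len(X[0]))            # 3 5
```

Selecting row $i$ is the same as multiplying a one-hot vector for token $i$ by the matrix. That is why the table can be trained with ordinary backpropagation. Repeated tokens (1231 appears twice here) share the very same row.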

The takeaway: after tokenization and embedding, every token is a vector that carries semantic meaning, and similar concepts (bat, swung, hit, ball) land near each other in this high-dimensional space. These embeddings are what the transformer actually operates on.

Fig 5.8 — Each token becomes a trainable vector; similar meanings sit close together. The token "Transform" (ID 1231) selects row 1231 of the V × d embedding matrix, giving the vector [0.1, 0.3, 0.4, 0.9, 0.8]. A 2D projection of the embedding space plots bat, ball, swung, and hit, with philosophy shown for contrast.

Check your understanding

What do $V$ and $d$ stand for in the embedding matrix, and why can't the transformer just use the raw token IDs?


$V$ is the vocabulary size (how many distinct tokens exist) and $d$ is the embedding dimension (the length of each token's vector). Raw token IDs are just arbitrary labels — ID 1231 isn't "more" than ID 5, and nearby IDs aren't semantically related. Embeddings replace each ID with a learned vector so that distance and direction in the space carry meaning, which is the kind of input the transformer's matrix math can actually work with.

Self-Attention — the heart of the transformer

Time for the main event. Let's build the intuition first.

Consider the sentence: "I like this girl." The word like is ambiguous on its own — is it the "similar to" like, or the "fond of" like? How do you, or a model, know which one we mean? The answer is context. As humans, we see that girl is sitting right there in the same sentence, so we connect like to fondness rather than similarity. Self-attention is the mechanism that lets each word do exactly this — gather context from the other words around it. Instead of treating each word in isolation, every word looks at every other word and decides how much each one matters for its own meaning.

Now the single most important idea in transformers: Q, K, V.

Self-attention gives each token three different vectors, all derived from its embedding:

1. Query (Q) — what this token is looking for.

2. Key (K) — what this token offers to others.

3. Value (V) — the actual content this token will contribute.

The classic analogy is a search engine. You type a search into the bar (your query). The engine matches your query against page titles and keywords (the keys). Where it finds good matches, it hands back the actual page content (the values).

In self-attention, every token does this at the same time. Each token acts as a query searching across all the other tokens (including itself), finds its best matches based on key similarity, and pulls in their values — weighted by how good each match was.

Computing Q, K, V

Q, K, and V are computed from the input embeddings via three learned weight matrices: $W_Q$, $W_K$, $W_V$. For an input embedding $x$:

$$Q = x W_Q, \quad K = x W_K, \quad V = x W_V$$

Reading the symbols:

  • $x$ — the token's input embedding (its vector from the embedding step, with position added).
  • $W_Q, W_K, W_V$ — the three learned projection matrices. These are trained via backpropagation; the model figures out on its own what makes a good query, key, and value for the task at hand.
  • $Q, K, V$ — the resulting query, key, and value vectors for that token.

For a sequence of $n$ tokens you do this for the whole sequence at once. Stack the token embeddings into an $(n, d)$ matrix $X$ and compute $XW_Q$, $XW_K$, and $XW_V$, one matrix multiplication per projection. That produces three matrices of shape $(n, d_k)$, $(n, d_k)$, and $(n, d_v)$, where $d_k$ is the dimension of the query/key vectors and $d_v$ the dimension of the value vectors.
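Below is a minimal plain-Python sketch of the three projections for a whole sequence. It assumes toy sizes ($n = 4$, $d = 8$, $d_k = 4$, $d_v = 6$), and random matrices stand in for the learned $W_Q$, $W_K$, $W_V$. The helper `matmul` is our own and not a library function; in practice you would use a tensor library.

```python
import random

random.seed(0)


def matmul(a, b):
    """(n, k) @ (k, m) -> (n, m) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def random_matrix(rows, cols):
    """Small random values standing in for learned weights."""
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


n, d, d_k, d_v = 4, 8, 4, 6      # toy sizes
X = random_matrix(n, d)          # one row per token (embedding + position)
W_Q = random_matrix(d, d_k)
W_K = random_matrix(d, d_k)
W_V = random_matrix(d, d_v)

# One matrix multiplication per projection covers the whole sequence.
Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
print(len(Q), len(Q[0]), len(K[0]), len(V[0]))   # → 4 4 4 6
```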

The compatibility function: dot product

Once each token has its query and every token has its key, we need to measure how well each query matches each key. The transformer uses the dot product — a simple, fast operation that measures how aligned two vectors are. A large positive dot product means strong alignment (relevant); near zero means orthogonal (irrelevant); negative means anti-aligned.

For each query $Q_i$ and key $K_j$, compute $Q_i \cdot K_j$. Arrange all of these into a compatibility matrix of shape $(n, n)$, where entry $(i, j)$ answers "how much should token $i$ pay attention to token $j$?"

But there's a wrinkle. When you're working with high-dimensional vectors, those dot products between $Q$ and $K$ can get very large. Large values feed into the softmax and push it into a region where its gradients become tiny — and you'll recognize that immediately as our old enemy, the vanishing-gradient problem, showing up in a brand-new place. On top of that, raw dot products aren't probabilities; we want each query's attention to spread across the keys and sum to 1.

So we fix both issues in two steps.

Step 1 — scale. Divide the scores by $\sqrt{d_k}$:

$$\text{scores} = \frac{Q K^\top}{\sqrt{d_k}}$$

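Here $d_k$ is the dimensionality of the key vectors (a hyperparameter that sets the size of the subspace the embeddings get projected into), and $\sqrt{d_k}$ is the scaling factor. Why the square root? Suppose the components of a query and a key are independent, each with mean 0 and variance 1. Their dot product is then a sum of $d_k$ such products, so it has variance $d_k$ and its typical magnitude grows like $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ brings the variance back to 1 no matter how large the dimension gets, which keeps the softmax in its healthy, well-gradiented zone.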
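This is a quick numerical check of why the raw scores grow. If query and key components are independent with mean 0 and variance 1, the dot product's variance equals $d_k$. The sketch draws such random vectors and measures the variance of $q \cdot k$ with and without dividing by $\sqrt{d_k}$. The i.i.d. standard-normal vectors are a simplifying assumption; trained queries and keys don't look like this.

```python
import math
import random

random.seed(0)


def dot_product_variance(d, trials=1000, scale=False):
    """Empirical variance of q . k when q and k have i.i.d. N(0, 1) components."""
    values = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(d)]
        k = [random.gauss(0.0, 1.0) for _ in range(d)]
        s = sum(a * b for a, b in zip(q, k))
        values.append(s / math.sqrt(d) if scale else s)
    mean = sum(values) / trials
    return sum((v - mean) ** 2 for v in values) / trials


for d_k in (4, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k}: raw variance ~{raw:.0f}, scaled variance ~{scaled:.2f}")
```

The raw variance tracks $d_k$, while the scaled variance stays near 1.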

Step 2 — softmax. Apply softmax along each row, turning the scores into a probability distribution. Each row of the resulting attention-weight matrix sums to 1, telling us "of all the tokens, here's the fraction of attention this token should pay to each."

Finally, multiply those attention weights by the value matrix $V$. The values carry the actual content each token contributes, so the weighted sum produces a new representation for each token that blends in exactly the context it found relevant.

Put it all together and you get the whole mechanism in one line:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

That's it — the entire self-attention mechanism in one equation. Read it from the inside out: take queries and keys, measure their alignment with a dot product, and scale it down by $\sqrt{d_k}$. Softmax then turns the scores into probabilities, and those probabilities take a weighted blend of the values. For its time, this was a massive breakthrough — and it's still the beating heart of every model we'll discuss.
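Here is the equation as a minimal plain-Python sketch. It uses lists of rows instead of tensors, and `matmul`, `transpose`, and `softmax` are small helpers defined here, not a library API. The toy inputs are chosen so that token 0's query points the same way as token 1's key. Token 0 should therefore put nearly all its weight on token 1 and copy token 1's value.

```python
import math


def matmul(a, b):
    """(n, k) @ (k, m) -> (n, m) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    """Subtracting the max first avoids overflow and doesn't change the result."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V; returns the outputs and the attention weights."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))                                         # (n, n) compatibility matrix
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]  # each row sums to 1
    return matmul(weights, V), weights                                       # (n, d_v) outputs


# Token 0's query matches token 1's key (and vice versa).
Q = [[4.0, 0.0], [0.0, 4.0]]
K = [[0.0, 4.0], [4.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
out, w = attention(Q, K, V)
print([round(v, 3) for v in out[0]])   # → [0.0, 1.0]: token 0 has pulled in token 1's value
```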

Here's the full mechanism, visualized:

Fig 5.9 — Scaled dot-product self-attention on "I like this girl": each token is projected to Q, K, V, and the scores $QK^\top$ are divided by $\sqrt{d_k}$ (8 in the figure). A row-by-row softmax makes each row of weights sum to 1, and multiplying by $V$ gives the outputs $z_1, \dots, z_4$. Every token queries every token, scales, softmaxes, then pulls a weighted blend of values.

Check your understanding

Why do we divide the attention scores by $\sqrt{d_k}$ before the softmax?

Because in high dimensions the raw dot products $Q \cdot K$ can grow large, and large inputs push softmax into a flat region where gradients shrink toward zero — the vanishing-gradient problem again. Dividing by $\sqrt{d_k}$ (the square root of the key dimension) rescales the scores back into a range where softmax stays sensitive and trainable, regardless of how big the embedding dimension is.

Multi-Head Attention

Self-attention lets a model work out which words matter to each other. But there's more nuance in language than a single attention pattern can capture. Take the sentence "He swung the bat with incredible force." One relationship worth tracking is swung → bat; a totally different one is incredible → force. Multi-head attention lets us look at all of these relationships in parallel. Each head learns its own slightly different way of paying attention — one might focus on grammar, another on meaning, another on something like emphasis — and when you combine them, you get a far richer understanding of the sentence, much closer to how we read it.

Instead of one Q, K, V projection, you create $h$ different sets of them (the original paper used $h = 8$ heads). Each head gets its own learned $W_Q$, $W_K$, $W_V$, but each works on a smaller slice of the embedding space — typically $d/h$ dimensions per head. With $d = 512$ and 8 heads, each head gets 64 dimensions.

Each head independently runs the full scaled-dot-product attention we just built, producing its own output. Then the $h$ outputs are concatenated back together and passed through one final projection matrix $W_O$ that mixes the information from all heads:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O$$
$$\text{where } \text{head}_i = \text{Attention}(Q W_Q^{(i)}, K W_K^{(i)}, V W_V^{(i)})$$

The symbols: $h$ is the number of heads; $\text{head}_i$ is the output of the $i$-th head; $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}$ are that head's own projection matrices; $\text{Concat}$ glues the head outputs back into one vector; and $W_O$ is the output projection that blends them. The shape comes out the same as if you'd run a single attention over the full dimension — you've just done it in parallel, specialized slices.
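Here is a minimal plain-Python sketch of multi-head attention. It assumes toy sizes ($n = 3$ tokens, $d = 16$, $h = 4$, so 4 dimensions per head), and random matrices stand in for the learned weights. The same single-head `attention` helper runs for every head; the head outputs are glued together per token and then mixed by $W_O$. All helper names are our own.

```python
import math
import random

random.seed(0)


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Single-head scaled dot-product attention."""
    d_k = len(K[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])
    return matmul([softmax([s / math.sqrt(d_k) for s in row]) for row in scores], V)


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def multi_head_attention(X, head_weights, W_O):
    """Run each head on its own projections, concatenate per token, then mix with W_O."""
    head_outputs = [attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V))
                    for W_Q, W_K, W_V in head_weights]
    concat = [[v for out in head_outputs for v in out[i]] for i in range(len(X))]
    return matmul(concat, W_O)


n, d, h = 3, 16, 4                  # toy sizes (the original paper: d = 512, h = 8)
d_head = d // h                     # each head works in a d/h-dimensional slice
X = random_matrix(n, d)
head_weights = [tuple(random_matrix(d, d_head) for _ in range(3)) for _ in range(h)]
W_O = random_matrix(h * d_head, d)
Y = multi_head_attention(X, head_weights, W_O)
print(len(Y), len(Y[0]))            # → 3 16: same shape as the input
```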

Later research gave us a nice interpretation: different heads specialize. Some learn grammatical patterns (subject–verb agreement), some learn semantic links (which words refer to the same entity), some learn positional habits ("look at the previous token"). Nobody tells them what to learn — the model sorts it out through training.

Why 8 heads? Eight isn't magic — it's a hyperparameter chosen to balance two failure modes. Too few heads and each one has to learn too many relationships at once, losing its specialization. Too many heads and each head's slice of dimensions gets so small it can't represent anything meaningful, and you pay more in compute for the privilege. With $d = 512$ and $h = 8$, each head gets 64 dimensions — diverse enough to learn several patterns, large enough to stay useful. Modern large models often use far more heads (32, 64, even 128), with proportionally smaller per-head dimensions.

Fig 5.10 — Multi-head attention with $h = 8$ parallel heads, $d_k = d_v = 512/8 = 64$: each head has its own $W_Q$, $W_K$, $W_V$ (in the figure, head 1 attends from swung to bat in "He swung the bat with incredible force"). The 8 × 64 = 512 concatenated outputs are projected by $W_O$ back to $d = 512$. Each head attends differently and in parallel; concat + W_O mixes them back together.

Check your understanding

With $d = 512$ and $h = 8$ heads, how many dimensions does each head get, and what's the danger of using too many heads?

Each head gets $d/h = 512/8 = 64$ dimensions. If you push the head count too high, each head's slice of the embedding shrinks until it's too small to represent meaningful relationships — and you also pay more compute. Too few heads has the opposite failure: each head is overloaded trying to learn many patterns at once and loses specialization. Eight is just a balance point.

Masked Self-Attention (decoder side)

In the decoder, the first attention layer gets one small but critical modification. In regular self-attention, every token can "see" every other token in the sequence — which is totally fine in the encoder, since the whole input is known up front. In masked self-attention, the difference is, almost literally, just a mask: the model is forbidden from looking at future tokens when predicting the next word.

Mechanically, a look-ahead mask is applied to the scaled dot-product score matrix, setting every entry above the diagonal to negative infinity before the softmax. That guarantees each token can only attend to itself and the tokens before it, which preserves the autoregressive property — the model writes strictly left to right.

Why negative infinity, specifically? Because softmax turns $-\infty$ into $0$. You're effectively telling the softmax that those future positions don't exist, so they get zero attention weight.

And why bother with all this? Here's the intuition. If we didn't mask, the model would already be able to peek at the correct future tokens while predicting the next word — so there'd be no real learning, just copying. It would hit a perfect training loss by cheating and never learn to generate text on its own at inference, when those future tokens genuinely aren't available yet.
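Here is a minimal sketch of the look-ahead mask in plain Python, using arbitrary toy scores. Note that `math.exp(-math.inf)` is exactly `0.0`, so masked positions get exactly zero weight. Each row still keeps its unmasked diagonal entry, which keeps the softmax well defined.

```python
import math


def softmax(row):
    m = max(row)                             # finite, because the diagonal is never masked
    exps = [math.exp(v - m) for v in row]    # math.exp(-math.inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def causal_mask(scores):
    """Set every entry above the diagonal (key j later than query i) to -inf."""
    n = len(scores)
    return [[scores[i][j] if j <= i else -math.inf for j in range(n)] for i in range(n)]


scores = [[0.5, 1.0, 2.0, 0.1] for _ in range(4)]   # arbitrary toy scores, already scaled
weights = [softmax(row) for row in causal_mask(scores)]
for row in weights:
    print([round(w, 2) for w in row])       # zeros above the diagonal; each row sums to 1

# Masking with 0 instead of -inf would leak attention to the future, since e^0 = 1.
leaky = softmax([0.5, 0.0, 0.0, 0.0])
print(round(leaky[1], 2))                    # → 0.22
```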

Fig 5.11 — Masked self-attention on "I like this girl": each query token may attend only to itself and earlier tokens. Entries above the diagonal (future tokens) are set to $-\infty$ and become 0 after the softmax. Block the future with a -inf mask so the model learns to predict, not copy.

Check your understanding

Why set the masked entries to negative infinity instead of, say, zero?

Because the mask is applied before the softmax. Softmax exponentiates its inputs, so $e^{-\infty} = 0$ — those positions end up with exactly zero attention weight and the remaining (allowed) positions still form a clean probability distribution that sums to 1. Setting the raw scores to 0 wouldn't work, since $e^{0} = 1$ would leave the future tokens with plenty of attention.

Cross-Attention (decoder side)

After the decoder applies masked self-attention to its own generated tokens, it still needs to actually look at the input. The encoder did all that work understanding "I like cats" — so how does the decoder get at it?

Cross-attention is the bridge. The mechanism is exactly the same scaled-dot-product attention as before, with one twist in where the vectors come from:

  • Query (Q) — from the decoder's own hidden state (what we've generated so far).
  • Key (K) — from the encoder output (the processed input).
  • Value (V) — from the encoder output (the actual input content).

So when the decoder is about to generate the next French word, it forms a query that essentially asks "given what I've written so far, which English words are relevant right now?" — and pulls the matching values out of the encoder's representation. This is how the decoder lines up what it's generating with what the encoder understood. When the decoder produces "chats," cross-attention is what makes it look back at the encoder's representation of "cats" to know what to say.
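Here is the same mechanism as a minimal sketch. Random matrices stand in for the learned projections, and the sizes are toy ones: 2 decoder positions so far, 3 source tokens for "I like cats", and $d = 8$. The only change from self-attention is where Q comes from versus K and V. That is why the weight matrix is $(m, n)$ rather than square.

```python
import math
import random

random.seed(0)


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def cross_attention(decoder_states, encoder_output, W_Q, W_K, W_V):
    Q = matmul(decoder_states, W_Q)          # (m, d_k): queries from the decoder
    K = matmul(encoder_output, W_K)          # (n, d_k): keys from the encoder
    V = matmul(encoder_output, W_V)          # (n, d_v): values from the encoder
    d_k = len(K[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])            # (m, n)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V), weights       # (m, d_v) outputs, (m, n) weights


m, n, d = 2, 3, 8                            # 2 target positions so far, 3 source tokens
decoder_states = random_matrix(m, d)
encoder_output = random_matrix(n, d)
W_Q, W_K, W_V = (random_matrix(d, d) for _ in range(3))
out, weights = cross_attention(decoder_states, encoder_output, W_Q, W_K, W_V)
print(len(out), len(out[0]), len(weights[0]))   # → 2 8 3
```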

One note: cross-attention only exists in encoder–decoder transformers (like the original). Decoder-only models like GPT don't have it — there's no separate encoder to attend to.

Fig 5.12 — Cross-attention: the encoder outputs for "I", "like", "cats" supply $(K_1, V_1)$, $(K_2, V_2)$, $(K_3, V_3)$. The decoder's query while producing "J'aime" (I like) puts the most weight (0.55) on "I". Q from the decoder, K and V from the encoder — the decoder looks back at the input.

Check your understanding

In cross-attention, where do Q, K, and V each come from, and which models lack cross-attention entirely?

The query comes from the decoder (what it's generated so far); the keys and values both come from the encoder output (the processed input). Decoder-only models like GPT have no cross-attention at all, because they have no separate encoder to attend to — they fold the input into the same sequence the decoder generates.

Feed-Forward Networks (FFN)

Attention has now gathered context for each token. But the model still needs to process that context — and that's the job of the feed-forward network inside every transformer layer.

If attention answers "what should I pay attention to?", the FFN answers "okay, now that I know what to focus on, what do I actually do with it?"

Each transformer layer has an FFN that processes each token independently — the same little network applied at every position. It's a simple two-layer MLP:

$$\text{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2$$

Symbol by symbol: $x$ is the token's representation coming out of attention; $W_1, b_1$ are the weights and bias of the first (expansion) layer; $\max(0, \cdot)$ is the ReLU nonlinearity; $W_2, b_2$ are the weights and bias of the second (contraction) layer. It runs in three steps:

1. Expansion — project from the embedding dimension $d_{model}$ up to a larger $d_{ff}$ (in the original paper, $d_{model} = 512 \to d_{ff} = 2048$, a 4× expansion). This gives the model room to detect complex features.

2. Nonlinear activation — apply ReLU (or GELU in more modern variants). This step is essential: without a nonlinearity, stacking layers would just collapse into one big linear transformation, and depth would buy you nothing.

3. Contraction — project back down from $d_{ff}$ to $d_{model}$, so the output matches the input shape and can flow into the next layer.
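Here is a minimal sketch of the three steps in plain Python. It assumes toy sizes with the same 4× ratio ($d_{model} = 8$, $d_{ff} = 32$) and random weights. Running it on a single token gives exactly the same result as that token's row in the full batch, which is what "position-wise" means.

```python
import random

random.seed(0)


def matmul(a, b):
    """(n, k) @ (k, m) -> (n, m) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def ffn(X, W1, b1, W2, b2):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every row (token) of X."""
    # Steps 1-2: expand to d_ff, then ReLU.
    hidden = [[max(0.0, v + b) for v, b in zip(row, b1)] for row in matmul(X, W1)]
    # Step 3: contract back to d_model.
    return [[v + b for v, b in zip(row, b2)] for row in matmul(hidden, W2)]


d_model, d_ff, n = 8, 32, 3          # toy sizes with the same 4x ratio as 512 -> 2048
X = random_matrix(n, d_model)
W1, W2 = random_matrix(d_model, d_ff), random_matrix(d_ff, d_model)
b1, b2 = [0.0] * d_ff, [0.1] * d_model
Y = ffn(X, W1, b1, W2, b2)
print(len(Y), len(Y[0]))             # → 3 8
```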

A practical fact that surprises people: the FFN holds most of the parameters in a transformer layer. With $d_{model} = 512$ and $d_{ff} = 2048$, the two FFN weight matrices $W_1$ and $W_2$ hold $2 \times 512 \times 2048 \approx 2.1$ million parameters. Multi-head attention with the same $d_{model}$ has only about $4 \times 512^2 \approx 1$ million (the $W_Q$, $W_K$, $W_V$, $W_O$ projections). So the FFN accounts for roughly two-thirds of each layer. In large LLMs, the FFN is widely thought to be where much of the model's stored knowledge lives.
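Counting the weight matrices directly, as a quick sketch (biases are left out; they add only a few thousand parameters):

```python
d_model, d_ff = 512, 2048

ffn_params = 2 * d_model * d_ff        # W1 is (d_model, d_ff), W2 is (d_ff, d_model)
attn_params = 4 * d_model * d_model    # W_Q, W_K, W_V, W_O, each (d_model, d_model) across all heads
ffn_share = ffn_params / (ffn_params + attn_params)

print(f"FFN: {ffn_params:,}  attention: {attn_params:,}  FFN share: {ffn_share:.0%}")
# → FFN: 2,097,152  attention: 1,048,576  FFN share: 67%
```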

That's exactly why an innovation like Mixture of Experts (MoE) targets the FFN specifically — it swaps the single dense FFN for many smaller "expert" FFNs and routes each token to just a few of them, letting you blow up the total parameter count without a matching blowup in compute. More on that later.

Fig 5.13 — Position-wise feed-forward network: $W_1$ expands each token's 512-dimensional vector to 2048, ReLU is applied, and $W_2$ contracts it back to 512. The same $W_1$ and $W_2$ are applied at every position. Attention decides what to mix; the FFN decides what to do with it. Most parameters live here.

Check your understanding

What breaks if you remove the nonlinearity (ReLU) from the FFN, and why does that matter for a deep transformer?

Without a nonlinearity, the FFN is just two linear layers back to back — and a composition of linear maps is itself a single linear map. Stack as many as you like and the whole thing collapses to one linear transformation, so depth gives you no extra expressive power. The ReLU (or GELU) is what lets stacked layers learn genuinely richer, non-linear functions.

Layer Normalization

This one is essential for keeping training stable — and, you guessed it, for keeping our gradients from exploding or vanishing.

The idea: we normalize the activations so they always have mean 0 and standard deviation 1, then let the model learn how to rescale them through two trainable parameters:

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

The symbols:

  • $x$ — the activation vector for a single token.
  • $\mu$ — the mean, computed across the features of that one token.
  • $\sigma^2$ — the variance, also across that token's features.
  • $\epsilon$ — a tiny constant added for numerical stability (so we never divide by zero).
  • $\gamma$ — a learnable scale parameter.
  • $\beta$ — a learnable shift parameter.

The $\gamma$ and $\beta$ are the clever bit: they let the model "un-normalize" when that's actually useful, so normalization never costs it expressive power.
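Here is a direct transcription of the formula for one token's activation vector, as a plain-Python sketch. `gamma` and `beta` are passed in rather than learned, and `eps=1e-5` is an assumed default.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over one token's features, then a learnable rescale and shift."""
    d = len(x)
    mu = sum(x) / d                                  # mean across this token's features
    var = sum((v - mu) ** 2 for v in x) / d          # variance across the same features
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


x = [1.2, -0.4, 0.8, 0.1]                            # one token's activations (toy values)
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)   # mean ~0, variance ~1
z = layer_norm(x, gamma=[2.0] * 4, beta=[0.5] * 4)   # gamma and beta rescale and shift
```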

Layer norm vs. batch norm. Batch norm computes its statistics across a whole batch of examples. Layer norm computes them across the features of a single example. For transformers, layer norm wins for three reasons: it doesn't depend on batch size (transformers get trained with wildly varying batch sizes), it works fine with variable-length sequences, and it's compatible with autoregressive generation, where at inference you process one token at a time and simply don't have a batch to normalize over.

Modern variants like RMSNorm drop the mean-subtraction step entirely and just normalize by the root-mean-square, which saves a little compute. LLaMA and many recent LLMs use it — we'll dig into exactly why it works later.
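RMSNorm in the same sketch style: drop the mean subtraction and $\beta$, and divide by the root-mean-square. Putting $\epsilon$ inside the square root follows a common formulation, but treat the exact details as an assumption, since implementations differ slightly.

```python
import math


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: no mean subtraction and no beta, just divide by the root-mean-square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gamma)]


y = rms_norm([1.0, 2.0, 3.0], gamma=[1.0, 1.0, 1.0])
print([round(v, 3) for v in y])   # values keep their sign and relative size; only the scale changes
```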

Fig 5.14 — Layer normalization vs. batch normalization on a batch × features grid: layer norm computes $\mu$ and $\sigma$ across a single row (one token's features), then applies $y = \gamma (x - \mu)/\sqrt{\sigma^2 + \epsilon} + \beta$. Layer norm normalizes across one token's features — batch-size independent, generation-friendly.

Check your understanding

Why is layer normalization preferred over batch normalization in transformers, especially at inference?

Layer norm computes its mean and variance across a single token's own features, so it doesn't depend on the batch at all. That matters at inference time during autoregressive generation, where you're producing one token at a time and there's effectively no batch to compute statistics over. It also handles variable-length sequences and wildly varying batch sizes gracefully — all situations where batch norm struggles.

Residual Connections

Residual connections are the reason deep transformers work at all. Here's the problem they solve. If you stack many layers, the early layers have a hard time learning, because their gradient has to travel all the way back down through every layer above them — and we've seen what long backward paths do to a gradient. Layer norm helps, but it doesn't fully fix it.

The fix came from ResNet (2015): residual connections, also called skip connections. The idea is to add a direct path that bypasses each sublayer:

$$\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$

where $x$ is the input to the sublayer and $\text{Sublayer}(x)$ is whatever that sublayer computes (attention or the FFN). The magic is the $+x$. Why does this help so much?

First, gradients flow freely. That $+x$ creates a direct highway during backpropagation. Even if $\text{Sublayer}(x)$ produces a tiny gradient, $x$'s gradient passes straight through unimpeded. This is what stops gradients from vanishing across dozens of stacked layers.

Second, it's easy to learn the identity. Without residuals, every layer has to learn its full transformation from scratch. With residuals, a layer only needs to learn the delta — what to add to the input. If a layer doesn't need to do anything useful, it can just output zero and the residual passes the input through untouched. Learning "add nothing" is far easier than learning "be the identity function."

Residual connections wrap every sublayer in the transformer — both the multi-head attention and the feed-forward network. Without them, transformers with 6, 12, or 96 layers simply wouldn't train.

One modern variation worth knowing: pre-norm vs. post-norm. The original transformer applied the norm after the addition (post-norm). Modern models usually use pre-norm — applying LayerNorm to the input before the sublayer, then adding the unchanged residual: $x + \text{Sublayer}(\text{LayerNorm}(x))$. Pre-norm is more stable for very deep networks and is now the default in most modern LLMs.
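Here are both layouts as a minimal sketch. The sublayer is passed in as a function, and LayerNorm's $\gamma$ and $\beta$ are fixed at 1 and 0 for brevity. The `do_nothing` sublayer is a hypothetical stand-in for a layer that has learned to output zero. In the pre-norm layout, the block then passes its input through unchanged, which is the "easy identity" point made above.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1 and beta = 0, for brevity."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    """Original 2017 layout: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    """Pre-norm layout: x + Sublayer(LayerNorm(x)); the residual path is never normalized."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def do_nothing(v):
    """Hypothetical sublayer that has learned to output zero."""
    return [0.0] * len(v)


x = [3.0, -1.0, 2.0, 0.5]
print(pre_norm_block(x, do_nothing))    # → [3.0, -1.0, 2.0, 0.5]: input passes through untouched
print(post_norm_block(x, do_nothing))   # a normalized copy of x, not x itself
```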

Fig 5.15 — Residual connection (Add & Norm) in the post-norm layout: the skip path carries $x$ around the sublayer (attention or FFN). The two are added, and LayerNorm is applied, giving $\text{out} = \text{LayerNorm}(x + \text{Sublayer}(x))$. The +x skip path gives gradients a clear road back, so deep stacks stay trainable.

Check your understanding

Two distinct benefits come from the $+x$ in a residual connection. What are they?

(1) Gradient flow: the skip path is a direct route for gradients during backprop, so even if a sublayer contributes almost nothing, the input's gradient passes straight through — preventing vanishing gradients across many layers. (2) Easy identity: each layer only has to learn the change to apply to its input (the delta), and can effectively "do nothing" by outputting zero, which is much easier to learn than reconstructing the identity function from scratch.

Positional Encoding

There's one big problem we haven't addressed yet. Self-attention treats the input as a set of tokens, not a sequence. Without extra information, "I like cats" and "cats like I" would look identical to the model — each token's attention output comes out the same regardless of the order the tokens arrive in.
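Here is a small sketch of why order is invisible. It drops the learned projections, so $Q = K = V = X$, which is a simplification. Reordering the tokens only reorders the outputs, and each token ends up with the same vector either way. The toy vectors are made up.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(X):
    """Scaled dot-product self-attention with Q = K = V = X (projections omitted)."""
    d = len(X[0])
    out = []
    for q in X:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X])
        out.append([sum(w * v[j] for w, v in zip(weights, X)) for j in range(d)])
    return out


X = [[1.0, 0.2], [0.1, 0.9], [0.7, 0.7]]   # toy vectors for "I", "like", "cats"
perm = [2, 1, 0]                            # reorder to "cats like I"
out = self_attention(X)
out_perm = self_attention([X[p] for p in perm])

# Row i of the shuffled output matches the original output of token perm[i].
for i, p in enumerate(perm):
    print(all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(out_perm[i], out[p])))
```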

But order is everything in language. Positional encodings fix this by injecting position information directly into the embeddings before they enter the first transformer layer. Each position in the sequence gets its own $d$-dimensional vector — the positional encoding for that position — and it's simply added to the token embedding sitting there:

$$\text{input}_i = \text{token\_embed}_i + \text{pos\_encode}_i$$

where $\text{token\_embed}_i$ is the embedding of the token at position $i$, and $\text{pos\_encode}_i$ is the positional vector for that position. Add them and the token now "knows" where it sits.

Sinusoidal positional encoding

The original transformer used a neat scheme built from sine and cosine waves at different frequencies:

$$\text{PE}_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right)$$
$$\text{PE}_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

The symbols: $pos$ is the position in the sequence (0, 1, 2, …); $i$ indexes the pair of dimensions $(2i, 2i+1)$ of the encoding vector; $d$ is the embedding dimension; and $10000$ is a constant chosen to spread the frequencies across many scales. Even dimensions ($2i$) use sine, odd dimensions ($2i+1$) use cosine.

Each dimension of the encoding oscillates at a different frequency. Low dimensions wiggle fast (capturing fine, nearby position differences); high dimensions wiggle slowly (capturing long-range position). Together they produce a unique "fingerprint" for every position.
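Here are the two formulas as a minimal plain-Python sketch, assuming an even $d$, followed by the addition step from earlier. The toy embedding values are made up.

```python
import math


def sinusoidal_encoding(pos, d):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(same angle). d must be even."""
    pe = [0.0] * d
    for i in range(d // 2):
        angle = pos / 10000 ** (2 * i / d)
        pe[2 * i] = math.sin(angle)        # even dimensions: sine
        pe[2 * i + 1] = math.cos(angle)    # odd dimensions: cosine
    return pe


token_embed = [0.1, 0.3, 0.4, 0.9, 0.8, 0.2]   # toy embedding, d = 6
model_input = [e + p for e, p in zip(token_embed, sinusoidal_encoding(12, 6))]
print([round(v, 2) for v in sinusoidal_encoding(12, 6)])
```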

Why design it this way? Three reasons:

Every position gets a unique signature — no two positions share the same encoding, because the combination of frequencies never repeats over the range you care about.

The model can read off relative distances. This is the deep reason for using trig functions. Thanks to the sine/cosine angle-addition identities, the encoding for position $pos + k$ can be written as a linear transformation (a rotation) of the encoding for position $pos$. That transformation depends only on the offset $k$, not on $pos$ itself. So "look $k$ positions back" becomes a simple linear operation the model can learn — relative position awareness comes basically for free. The waves are periodic, but using many different frequencies (by varying $i$) guarantees each position still gets a unique overall combination.

It generalizes past the training length. Since sine and cosine are defined for every input, you can compute the encoding for any position, even one longer than anything seen during training.
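To back up the relative-distance reason above, here is a numerical check for a single sine/cosine pair. A fixed 2×2 rotation matrix that depends only on the offset $k$, not on $pos$, maps the pair at $pos$ to the pair at $pos + k$. The frequency used is an arbitrary choice.

```python
import math


def pe_pair(pos, omega):
    """One (sin, cos) pair of the sinusoidal encoding at angular frequency omega."""
    return (math.sin(omega * pos), math.cos(omega * pos))


def shift(pair, k, omega):
    """Multiply the pair by M(k) = [[cos wk, sin wk], [-sin wk, cos wk]].

    M(k) depends only on the offset k, never on the absolute position.
    """
    s, c = pair
    ck, sk = math.cos(omega * k), math.sin(omega * k)
    return (ck * s + sk * c, -sk * s + ck * c)


omega = 1 / 10000 ** (2 * 3 / 16)   # frequency of pair i = 3 when d = 16 (arbitrary)
for pos in (0, 5, 40):
    moved = shift(pe_pair(pos, omega), 7, omega)
    target = pe_pair(pos + 7, omega)
    print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(moved, target)))
```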

Fig 5.16 — Sinusoidal positional encoding: a position × dimension heatmap covers positions 0 to 50, with dimensions running from fast to slow frequencies. Individual waves are shown for dimensions 0 (fast), 4 (medium), and 10 (slow), along with the resulting fingerprint at position 12, which is added to the token embedding. Different frequencies give every position a unique fingerprint and encode relative distance.

Modern positional encoding

Sinusoidal encodings are elegant, but most modern open LLMs (LLaMA, for example) use Rotary Positional Embeddings (RoPE) instead. Rather than adding a positional vector to the embedding, RoPE rotates pairs of dimensions inside the query and key vectors by an angle that depends on position — so when you later compute the dot product $Q \cdot K$ for attention, the rotations naturally produce a term that depends on the difference between the two positions, not their absolute values. ALiBi (Attention with Linear Biases) is another modern alternative that biases attention scores directly by relative distance. We'll cover both in more depth in the modern-divergences section.
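Here is a minimal RoPE sketch. It assumes the frequencies $\theta_i = 10000^{-2i/d}$ used for the sinusoidal encoding and pairs adjacent dimensions $(2i, 2i+1)$; some implementations pair the dimensions differently. The check at the end shows the key property: shifting both positions by the same amount leaves the query-key score unchanged. The toy vectors are made up.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotate each pair (vec[2i], vec[2i+1]) by the angle pos * base^(-2i/d)."""
    d = len(vec)
    out = list(vec)
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x * math.cos(theta) - y * math.sin(theta)
        out[2 * i + 1] = x * math.sin(theta) + y * math.cos(theta)
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.8, 0.5]      # toy query and key (d = 4)
k = [1.0, 0.4, -0.6, 0.9]

# Query at position 3, key at 1, versus query at 10, key at 8: same offset of 2.
s_near = dot(rope(q, 3), rope(k, 1))
s_far = dot(rope(q, 10), rope(k, 8))
print(math.isclose(s_near, s_far, abs_tol=1e-9))   # → True
```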

The core idea never changes: somehow inject position so the model knows the order. The specifics keep evolving.

Check your understanding

Without positional encoding, why would "I like cats" and "cats like I" look the same to a transformer?

Because self-attention is permutation-equivariant: it treats its input as an unordered set. The attention scores between a given pair of tokens depend only on their content vectors, not on where they sit in the sequence. Shuffling the tokens therefore just shuffles the outputs in the same way, and each token ends up with exactly the same representation. Positional encoding breaks this symmetry by adding position-specific information to each token's vector before attention sees it.

A full trace, to wrap up the architecture

Let's put every piece together. Here's exactly what happens when a sequence of tokens passes through a single transformer encoder layer:

1. Input: a sequence of token embeddings with positional encodings added, shape $(n, d)$.

2. Multi-head self-attention: project to Q, K, V across $h$ heads; compute scaled-dot-product attention per head; concatenate; project through $W_O$.

3. First Add & Norm: add the input back as a residual, then apply layer norm.

4. Feed-forward network: expand to $d_{ff}$, apply the nonlinearity, contract back to $d$.

5. Second Add & Norm: add the FFN's input (the output of step 3) back as a residual, then apply layer norm.

6. Output: same shape $(n, d)$, handed to the next layer.
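Here are the six steps as one compact plain-Python sketch, under simplifying assumptions. It uses random weights instead of trained ones, no biases, $\gamma = 1$ and $\beta = 0$ in LayerNorm, and toy sizes. It is meant to show the data flow and the fact that each layer maps $(n, d)$ to $(n, d)$. It is not an efficient or faithful reimplementation of the original model.

```python
import math
import random

random.seed(0)


def matmul(a, b):
    """(n, k) @ (k, m) -> (n, m) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def layer_norm(X, eps=1e-5):
    """Per-token LayerNorm with gamma = 1, beta = 0 (learnable params omitted)."""
    out = []
    for row in X:
        mu = sum(row) / len(row)
        var = sum((v - mu) ** 2 for v in row) / len(row)
        out.append([(v - mu) / math.sqrt(var + eps) for v in row])
    return out


def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])
    return matmul([softmax([s / math.sqrt(d_k) for s in r]) for r in scores], V)


def rand(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


def make_layer_params(d, h, d_ff):
    """Random stand-ins for one layer's learned weights (biases omitted)."""
    return {
        "heads": [(rand(d, d // h), rand(d, d // h), rand(d, d // h)) for _ in range(h)],
        "W_O": rand(d, d), "W1": rand(d, d_ff), "W2": rand(d_ff, d),
    }


def encoder_layer(X, p):
    # Step 2: multi-head self-attention, concat, project through W_O.
    heads = [attention(matmul(X, wq), matmul(X, wk), matmul(X, wv)) for wq, wk, wv in p["heads"]]
    concat = [[v for head in heads for v in head[i]] for i in range(len(X))]
    A = layer_norm(add(X, matmul(concat, p["W_O"])))                      # Step 3: Add & Norm
    hidden = [[max(0.0, v) for v in row] for row in matmul(A, p["W1"])]   # Step 4: expand + ReLU
    return layer_norm(add(A, matmul(hidden, p["W2"])))                    # Steps 4-5: contract, Add & Norm


n, d, h, d_ff = 5, 16, 4, 64          # toy sizes (the original used d = 512, h = 8, d_ff = 2048)
X = rand(n, d)                        # step 1: embeddings + positions (random stand-ins)
for params in [make_layer_params(d, h, d_ff) for _ in range(6)]:
    X = encoder_layer(X, params)      # step 6: (n, d) in, (n, d) out, so layers stack
print(len(X), len(X[0]))              # → 5 16
```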

That's one encoder layer. Stack 6 of them (or 12, or 96) and you have an encoder. A decoder layer adds one extra step in the middle: masked self-attention followed by Add & Norm, then cross-attention to the encoder output followed by Add & Norm, then the FFN and a final Add & Norm.

And here's the genuinely beautiful part: every layer has the same shape coming out as going in — $(n, d)$ to $(n, d)$. That means you can stack as many as you want — 6, 12, 96, 120 — and the data just flows straight through, each layer refining the representation a little more. That compositional simplicity is a huge part of why transformers scale so gracefully.

After the last decoder layer, you've got a sequence of $d$-dimensional vectors. To turn those into actual words, two more steps:

1. Linear projection to vocabulary size. A learned weight matrix maps each $d$-dimensional vector to a $V$-dimensional vector — one entry per vocabulary token. These raw scores are called logits.

2. Softmax over the vocabulary. Convert the logits into a probability distribution over the whole vocabulary. The token with the highest probability is the predicted next token.

At training time, you compute the cross-entropy loss between this predicted distribution and the true next token. At inference time, you sample from the distribution (or just take the argmax for greedy decoding) to pick the next token, then feed everything back through the decoder to predict the token after that, and so on — autoregression in action.
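Here is the output head as a minimal sketch for a single position, with a made-up 4-token vocabulary, $d = 2$, and hand-picked weights. It computes the logits, the softmax, the greedy (argmax) prediction, and the cross-entropy loss, assuming "cats" is the true next token.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def output_head(h, W_vocab):
    """Linear projection of a d-dim vector to V logits, then softmax over the vocabulary."""
    logits = [sum(x * w for x, w in zip(h, col)) for col in zip(*W_vocab)]
    return softmax(logits)


vocab = ["I", "like", "cats", "<eos>"]   # hypothetical toy vocabulary, V = 4
W_vocab = [[0.2, 1.5, -0.3, 0.0],        # shape (d, V) with d = 2; made-up weights
           [0.1, -0.2, 2.0, 0.4]]
h = [0.5, 1.0]                           # final decoder vector for one position

probs = output_head(h, W_vocab)
greedy = vocab[max(range(len(probs)), key=lambda j: probs[j])]   # argmax decoding
loss = -math.log(probs[vocab.index("cats")])                      # cross-entropy vs. true token "cats"
print(greedy, round(loss, 3))
```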

Fig 5.17 — One encoder layer and one decoder layer, traced with shapes. Encoder (×6): input $(n, d)$ → multi-head self-attention → Add & Norm → FFN → Add & Norm, all $(n, d)$. Decoder (×6): input $(m, d)$ (targets shifted right) → masked self-attention → Add & Norm → cross-attention (Q from the decoder, K and V from the encoder) → Add & Norm → FFN → Add & Norm, all $(m, d)$. Linear and Softmax, both $(m, V)$, then pick the next token. Same shape in, same shape out — that's why you can stack layers freely.

Check your understanding

What's the one structural difference between an encoder layer and a decoder layer, and what turns the decoder's final vectors into a predicted word?

A decoder layer inserts an extra sub-block: after its masked self-attention (+ Add & Norm), it has a cross-attention block (+ Add & Norm) that attends to the encoder's output, before the FFN. The encoder layer has no cross-attention. To produce a word, the decoder's final $d$-dimensional vector is sent through a linear projection to vocabulary size (giving logits), then softmax turns those logits into a probability distribution over the vocabulary, and the highest-probability token is chosen (or sampled).

Where Modern Models Have Diverged

The original 2017 transformer is still the conceptual foundation. But modern LLMs have evolved several of its components — and here's the reassuring thing: every one of these is an optimization of the original recipe, not a replacement. Once you understand the architecture we just built, every modern paper clicks into place, because it's almost always improving one specific piece while leaving the overall shape intact.

Here's the short list before we go deep on each:

  • Decoder-only architecture — GPT, LLaMA, Claude, and most modern LLMs drop the encoder entirely. For pure text generation you don't need a separate "input" stream, so the whole model is just a stack of decoder layers with causal masking.
  • Better positional encodings — RoPE and ALiBi instead of sinusoidal, for better behavior on long contexts.
  • RMSNorm instead of LayerNorm — slightly cheaper, comparable quality.
  • SwiGLU instead of ReLU — a gated activation in the FFN that consistently beats ReLU at scale.
  • Grouped-Query Attention (GQA) — fewer key/value heads than query heads, easing memory pressure at inference.
  • Flash Attention — a re-implementation of attention that's mathematically identical but uses the GPU memory hierarchy carefully. Dramatically faster, especially on long sequences.
  • Sparse / sliding-window attention — attend to only a local window instead of every token. Trades a little capability for big efficiency gains on long contexts.
  • Mixture of Experts (MoE) — replace the dense FFN with many smaller experts and route each token to a few. Massively more parameters without proportional compute.
  • Long context windows — the original handled hundreds of tokens; modern models handle millions.

Let's go through them.

Decoder-only transformers

The original transformer had two stacks: an encoder for the input language and a decoder for the output language. But when the task is pure text generation — predict the next token given everything so far — you don't really need two streams. The "input" and the "output" are the same sequence, just shifted by one position.

That realization gave us decoder-only models: GPT-1 in 2018, then GPT-2, GPT-3, GPT-4, Claude, LLaMA, and basically every modern frontier LLM. The encoder is gone. The cross-attention block is gone. What's left is a tall stack of decoder layers, each with masked self-attention and a feed-forward network.

It almost seems too simple. How can a model that just predicts the next token translate languages, write code, answer questions, and reason? The answer: next-token prediction over a huge, diverse corpus turns out to be an extraordinarily general training signal. To predict the next token of a Python function, the model has to understand programming. To predict the next token of a Shakespearean sonnet, it has to understand iambic pentameter. To predict the next token of a math proof, it has to understand algebra. By being forced to model everything people write, the model implicitly absorbs the structure of language, knowledge, and reasoning.

Tasks like translation, summarization, and Q&A become special cases. You just frame them as text, for example "Translate English to French: cheese =>" or "Summarize this article: [article text] Summary:", and let the model predict what comes next. That's the unified interface decoder-only models gave us: one architecture handling every text task by reframing it as completion.

The masked self-attention is what makes this work. Each token can only attend to itself and earlier tokens, never future ones — which preserves the autoregressive property and lets the same model both train (predicting all positions in parallel) and generate (one token at a time at inference).
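Here's a minimal sketch of that mask for a single head in plain Python. It assumes the raw scores $q_i \cdot k_j$ are already computed. The function name `causal_attention_weights` is ours, not a library API. Every row is computed in the same pass, which is the parallel-training case, and row $i$ never sees keys beyond position $i$.

```python
import math


def causal_attention_weights(scores):
    """Row-wise softmax over an n x n score matrix with a causal mask.

    scores[i][j] is the raw score of query i against key j. Keys j > i are
    future tokens, so they are excluded and get exactly zero weight.
    """
    n = len(scores)
    weights = []
    for i in range(n):
        visible = scores[i][: i + 1]            # keys 0..i only
        m = max(visible)                         # subtract the max for stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights


scores = [[1.0, 2.0, 3.0],
          [0.5, 0.5, 9.0],   # the 9.0 is a future key, so it must be ignored
          [0.0, 0.0, 0.0]]
for row in causal_attention_weights(scores):
    print([round(w, 3) for w in row])
```

Row 1 splits its weight evenly between positions 0 and 1, even though the future score at position 2 is huge. The mask decides what is visible, not the score.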

[Figure: Decoder-Only Transformer (GPT-style). Token + positional embeddings feed a stack of decoder layers (masked self-attention, Add & Norm, feed-forward, Add & Norm). Linear and Softmax on top give next-token probabilities. The encoder and cross-attention are removed, attention is causal (backward only), and each predicted token is fed back into the sequence.]
Fig 5.18 — Drop the encoder, keep masked self-attention, predict the next token. That's a modern LLM.

Check your understanding

A decoder-only model only ever learns to "predict the next token." How does that single objective produce a model that can translate, code, and reason?

Show answer ▸

Because next-token prediction over a massive, varied corpus forces the model to learn whatever structure makes the next token predictable. Predicting code well requires understanding syntax and logic; predicting poetry requires meter; predicting proofs requires math. The objective is narrow but the data is everything, so the model ends up internalizing language, facts, and reasoning patterns as a side effect. Specific tasks then become text-completion problems framed in the prompt.

Encoder-only transformers

Encoder-only transformers stack self-attention layers to process input sequences bidirectionally — every token attends to every other token, both forward and backward. This makes them great for language understanding, representation learning, and structured prediction, rather than text generation.

(Quick definition: representation learning is the set of techniques that automatically discover compact, structured representations — embeddings — for things like feature detection or classification, replacing hand-engineered features.)

An encoder takes a sequence of tokens and produces a contextualized representation for each one. For "I like cats," the encoder outputs three vectors — one per token — where each vector has soaked up information from the whole sentence. The output vector for "cats" isn't just "the cats embedding"; it's "the cats embedding, having paid attention to 'I' and 'like.'" These representations aren't predictions — they're rich features that downstream tasks can build on. The encoder is fundamentally a representation learner, not a generator.

Encoder vs. decoder: the attention difference. The encoder uses bidirectional self-attention — every token attends to every other token, in both directions. That's fine, because the encoder isn't generating anything; it just needs to understand the input, and looking ahead doesn't help you cheat if you're not predicting the next token. The decoder uses causal (masked) self-attention — each token attends only to itself and prior tokens — which is required for autoregressive generation. This single difference leads to wildly different training objectives.

You can't train an encoder with next-token prediction, because it sees the whole sequence at once and the task would be trivial (it could just read the answer). So the objective that made encoders work is Masked Language Modeling (MLM), introduced in BERT (Bidirectional Encoder Representations from Transformers). The procedure:

1. Take a sentence: "I like cats and dogs."

2. Randomly mask out about 15% of the tokens: "I [MASK] cats and [MASK]."

3. Train the encoder to predict the masked tokens from the surrounding context.

Because the encoder is bidirectional, filling in a "[MASK]" in the middle of a sentence requires looking at the words both before and after it. That forces the model to build deep, two-sided understanding — context flows in from everywhere.

BERT's exact recipe was a touch more elaborate. Of the 15% of tokens selected for masking, 80% get replaced with [MASK], 10% get replaced with a random token, and 10% are left unchanged. This trick stops the model from learning that "[MASK] is the only signal that a prediction is needed" — which would hurt it on real downstream tasks where no [MASK] tokens appear.
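The masking recipe is easy to sketch in a few lines of Python. This is an illustration, not BERT's actual data pipeline. The names are hypothetical, tokens are whole words, and each position is picked independently with probability 0.15 (a simplification of how real pipelines pick positions).

```python
import random

MASK = "[MASK]"


def bert_style_mask(tokens, vocab, mask_prob=0.15, rng=None):
    """Corrupt a token list with BERT's 80/10/10 rule.

    Returns (corrupted, labels). labels[i] is the original token at positions
    chosen for prediction and None elsewhere (no loss at those positions).
    Simplification: each position is chosen independently with mask_prob.
    """
    rng = rng or random.Random()
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue                             # not chosen: nothing to predict
        labels[i] = tok                          # model must recover this token
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK                  # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)     # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, labels


tokens = "I like cats and dogs .".split()
vocab = ["I", "like", "cats", "and", "dogs", ".", "fish", "runs"]
print(bert_style_mask(tokens, vocab, rng=random.Random(0)))
```

The labels stay attached even in the "unchanged" 10% case. The model is still asked to predict those tokens, so it can't assume that an unmasked token is never a prediction target.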

[Figure: Encoder-Only Transformer + MLM (BERT). The input "I [MASK] cats and [MASK] ." passes through L bidirectional encoder layers, and the MLM head predicts "like" and "dogs". Position 1 draws context from both sides. Masking rule for the 15% of selected tokens: 80% [MASK], 10% random token, 10% unchanged.]
Fig 5.19 — Bidirectional attention + predict the masked words = deep two-sided understanding.

BERT itself was a landmark: 12 encoder layers, 110M parameters in its base version, trained on Wikipedia and BookCorpus with MLM plus a next-sentence-prediction objective. From roughly 2018–2022, BERT and its descendants dominated NLP benchmarks.

Are encoders obsolete in the GPT era? Not at all — they've just specialized into the jobs where bidirectional understanding is the advantage:

Embedding models
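Many text embedding models are encoders. The sentence-transformers family and BGE are well-known examples, and hosted services such as OpenAI's text-embedding-3, Cohere Embed, and Voyage fill the same role. (Not all of these publish their architecture, and some recent embedding models are built on decoder LLMs instead.) Text goes in and a vector comes out, and that vector powers semantic search, RAG, clustering, and classification. The entire retrieval step of a RAG pipeline depends on these embedding models.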
Reranking
After a vector search returns candidate documents, a cross-encoder (an encoder that reads the query and a document jointly and outputs a relevance score) reranks them for higher quality. Slower than a vector lookup, but more accurate.
Classification and structured tasks
Sentiment analysis, intent detection, named-entity recognition, content moderation. When you have labels and just need a score or a class, a fine-tuned encoder is often faster, cheaper, and more accurate than a full LLM.
Encoder components in multimodal models
The vision side of vision-language models such as LLaVA is an encoder, typically a Vision Transformer (ViT) such as CLIP's image encoder. It produces image embeddings that are fed into the decoder LLM.

So encoders didn't disappear; they migrated to the parts of the pipeline where building representations beats generating text.

Check your understanding

Why can't you train an encoder with plain next-token prediction, and what objective is used instead?

Show answer ▸

Because the encoder is bidirectional — every token already sees every other token, including the "next" one. Next-token prediction would be trivial: the model could just look ahead and copy the answer, learning nothing. Instead, encoders use Masked Language Modeling: randomly hide ~15% of tokens and train the model to reconstruct them from both-side context, which forces genuine bidirectional understanding.

Encoder–decoder models

A few important models still use the full encoder–decoder structure:

T5 (Text-to-Text Transfer Transformer)
frames every task as text-to-text. Translation, summarization, question-answering — all become "given input text, produce output text." The encoder reads the input, the decoder generates the output. Surprisingly effective, and still competitive.
BART
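is another encoder–decoder model, pretrained as a denoising autoencoder: corrupt the input (masking, deleting, or infilling spans and shuffling sentences), then reconstruct the original. T5's own pretraining is also denoising (it corrupts spans), so BART's distinguishing feature is its wider mix of corruptions.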
Flan-T5
is T5 instruction-tuned across many tasks — a strong, compact alternative to LLM-style models for structured work.

Encoder–decoder shines when the input and output are clearly distinct sequences with different roles — especially translation and summarization. Decoder-only models can handle these too, by treating input + output as one continuous sequence, but the explicit encoder–decoder split gives the model clearer built-in assumptions about which part is which.

Check your understanding

When does the explicit encoder–decoder split (like T5) have an edge over a decoder-only model?

Show answer ▸

When the input and output are genuinely distinct sequences with different roles — translation (source language → target language) and summarization (long document → short summary) are the classic cases. The separate encoder gives the model a clean, dedicated representation of the input to attend to via cross-attention, which is a helpful inductive bias. Decoder-only models can do these tasks by concatenating input and output, but they don't get that explicit structural separation.

The component upgrades

Now the pieces modern models swap in. We'll take them one at a time, since each is a self-contained improvement to a part you already understand.

Modern positional encodings, in depth

The original sinusoidal scheme worked, but it had two weaknesses. It encoded absolute positions when what really matters is relative distance ("the token 3 places back"), and it didn't extrapolate well past the maximum training length. Modern LLMs use two main alternatives.

RoPE (Rotary Positional Embedding). RoPE doesn't add a positional vector to the embedding. Instead it rotates pairs of dimensions in the query and key vectors by an angle that depends on position. When you then compute the attention dot product $Q \cdot K$, the rotations naturally produce a term that depends on the difference between the two positions, not their absolute values.

Mathematically, treat each pair of dimensions as a 2D vector and apply a rotation matrix:

$$R_\theta = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}$$

where $\theta$ depends on both the position and the dimension. After rotation, the dot product between a rotated $Q$ at position $m$ and a rotated $K$ at position $n$ depends only on $(m - n)$, the relative offset. You get relative-position information for free, without changing the attention formula at all.

Why it matters in practice: it captures the relative position the model actually cares about. It can also be stretched to sequences longer than those seen in training by rescaling its rotation frequencies (NTK-aware scaling and YaRN do exactly this). And it's used by LLaMA, Mistral, Qwen, and most modern open LLMs.
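A small numerical check makes the relative-position property concrete. This sketch rotates adjacent pairs of dimensions with frequencies $\theta_i = 10000^{-2i/d}$, the schedule from the original RoPE paper. Production implementations may pair dimensions differently, for example the first half with the second half. The helper names `rope` and `dot` are ours.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotate each adjacent pair (vec[2i], vec[2i+1]) by angle pos * theta_i,
    with theta_i = base ** (-2i / d)."""
    d = len(vec)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = []
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        # Apply the 2x2 rotation matrix R_theta to the pair (x, y).
        out += [x * math.cos(theta) - y * math.sin(theta),
                x * math.sin(theta) + y * math.cos(theta)]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]
# Same offset m - n = 3 at very different absolute positions.
s1 = dot(rope(q, 5), rope(k, 2))
s2 = dot(rope(q, 105), rope(k, 102))
print(s1, s2)   # equal up to floating-point rounding
```

Shifting both positions by 100 leaves the score unchanged, because $R(m\theta)^\top R(n\theta) = R((n - m)\theta)$ for each pair of dimensions.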

[Figure: RoPE: Rotary Positional Embedding. A query at position m and a key at position n are each rotated by $R_\theta$. The score $q \cdot k$ depends only on $(m - n)$, not on $m$ or $n$ alone.]

Fig 5.20 — Rotate Q and K by a position-dependent angle; the dot product then sees only relative distance.

ALiBi (Attention with Linear Biases). ALiBi takes a different route: don't touch the embeddings at all. Instead, add a position-dependent bias directly to the attention scores. Tokens that are far apart get a more negative bias, making them harder to attend to:

$$\text{scores}_{ij} = q_i \cdot k_j - m \cdot |i - j|$$

The symbols: $q_i$ and $k_j$ are the query and key vectors, $|i - j|$ is the distance between positions $i$ and $j$, and $m$ is a head-specific slope. Because different heads get different slopes, some heads attend mostly to nearby tokens while others can still reach far away. ALiBi is simpler than RoPE and extrapolates extraordinarily well: models trained at 2k context can sometimes handle 16k at inference. The tradeoff is that it's a softer bias than RoPE and can be slightly weaker on tasks where exact distance matters.
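The bias is simple enough to write out directly. This sketch makes three assumptions: a causal decoder (keys $j \le i$ only), no $1/\sqrt{d}$ scaling (to match the formula above), and the ALiBi paper's geometric slope schedule for a power-of-two number of heads ($2^{-8/h}, 2^{-16/h}, \dots, 2^{-8}$). The helper names are ours.

```python
def alibi_slopes(num_heads):
    """Head-specific slopes m: 2^(-8/h), 2^(-16/h), ..., 2^(-8).

    This geometric schedule assumes num_heads is a power of two.
    """
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_scores(raw, slope):
    """scores[i][j] = raw[i][j] - slope * |i - j| for keys j <= i.

    raw[i][j] holds q_i . k_j; future keys (j > i) are masked with -inf.
    """
    n = len(raw)
    return [[raw[i][j] - slope * abs(i - j) if j <= i else float("-inf")
             for j in range(n)]
            for i in range(n)]


print(alibi_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

zeros = [[0.0] * 3 for _ in range(3)]
for row in alibi_scores(zeros, slope=0.5):
    print(row)
```

With 8 heads, the steepest head (slope 0.5) effectively looks only a few tokens back. The shallowest head (slope 1/256) barely penalizes distance at all.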

Check your understanding

Both RoPE and sinusoidal encoding inject position, but RoPE is preferred for long contexts. What's the key property RoPE gives you?

Show answer ▸

RoPE makes the attention dot product depend only on the relative distance $(m - n)$ between two tokens, not on their absolute positions. This happens because it rotates Q and K by position-dependent angles, and those angles subtract in the dot product. Relative distance is what the model actually needs. With frequency-scaling tricks like NTK-aware scaling and YaRN, a RoPE model can also be stretched to sequences much longer than it saw in training.

RMSNorm instead of LayerNorm

Recall LayerNorm does two things: subtract the mean (re-centering), then divide by the standard deviation (re-scaling). Empirical research turned up something surprising: the re-centering step barely matters. The model trains nearly as well if you only do the re-scaling. So RMSNorm (Root Mean Square Normalization) drops the mean subtraction entirely:

$$\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}}$$

The symbols: $x$ is the token's activation vector, $x_i$ is its $i$-th component, and $d$ is the dimension. The denominator is the root-mean-square of the features, $\epsilon$ is a small stability constant, and $\gamma$ is a learnable scale. Notice what's missing compared to LayerNorm: there's no mean $\mu$ and no shift $\beta$. RMSNorm just divides by the RMS of the features and applies a learnable scale.
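Side by side, on one token's activation vector stored as a plain Python list, the two norms differ only in the mean and the shift. This is a sketch with illustrative names. With a zero-mean input and $\beta = 0$ the two coincide, which is another way to see that RMSNorm is LayerNorm minus the re-centering.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-6):
    d = len(x)
    mu = sum(x) / d                                    # re-centering statistic
    var = sum((v - mu) ** 2 for v in x) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]


def rms_norm(x, gamma, eps=1e-6):
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)   # no mean, no beta
    return [g * v / rms for v, g in zip(x, gamma)]


gamma, beta = [1.0] * 4, [0.0] * 4
x = [1.0, -2.0, 3.0, 0.5]
y = rms_norm(x, gamma)
print(math.sqrt(sum(v * v for v in y) / len(y)))      # RMS of the output is ~1

x0 = [1.0, -1.0, 2.0, -2.0]                            # already zero-mean
print(layer_norm(x0, gamma, beta), rms_norm(x0, gamma))  # the two agree
```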

The benefits are modest but real. RMSNorm is cheaper to compute than LayerNorm (how much faster depends on the model and the implementation), has slightly fewer parameters, and matches LayerNorm's quality. It's used in LLaMA, Mistral, and many other modern LLMs. It's one of those small wins that genuinely adds up when you're training billions of parameters over trillions of tokens.

Check your understanding

What does RMSNorm drop relative to LayerNorm, and why is that okay?

Show answer ▸

It drops the mean-subtraction (re-centering) step and the learnable shift $\beta$. It only divides by the root-mean-square of the features and applies a learnable scale $\gamma$. Empirically, re-centering contributes very little to training quality, so removing it costs almost nothing. In return you save a bit of normalization compute and a few parameters.

SwiGLU instead of ReLU

The FFN in the original transformer was Linear → ReLU → Linear. Modern models almost universally use a gated variant called SwiGLU.

The idea comes from GLU (Gated Linear Units): instead of one linear projection followed by an activation, use two linear projections and multiply them together, with the activation applied to only one of them:

$$\text{SwiGLU}(x) = \text{Swish}(x W_1) \odot (x W_2)$$

Here $\text{Swish}(x) = x \cdot \sigma(x)$ (also called SiLU) is a smooth nonlinearity, $\odot$ is element-wise multiplication, and $W_1, W_2$ are two separate learned projections. The output is then sent through a third linear layer to project back down:

$$\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{Swish}(x W_1) \odot x W_2\right) W_3$$

where $W_3$ is the down-projection. Because this uses three matrices instead of two, modern models shrink $d_{ff}$ to keep the parameter count fair. The original used $d_{ff} = 4 \times d_{model}$, while SwiGLU models typically use $d_{ff} \approx \tfrac{8}{3} \times d_{model}$.
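The plain-Python sketch below makes the three-matrix structure and the parameter bookkeeping concrete. It uses no biases, and the helper names are ours. The value $d_{model} = 768$ is just an example size, chosen so that $\tfrac{8}{3} d_{model}$ is a whole number.

```python
import math


def silu(z):
    """Swish / SiLU, z * sigmoid(z), written to avoid overflow in exp."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def matvec(x, W):
    """Row vector x (length d_in) times W (d_in rows, d_out columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def swiglu_ffn(x, W1, W2, W3):
    gate = [silu(v) for v in matvec(x, W1)]          # Swish(x W1)
    value = matvec(x, W2)                             # x W2
    hidden = [g * v for g, v in zip(gate, value)]     # element-wise product
    return matvec(hidden, W3)                         # down-project with W3


def ffn_params(d_model, d_ff, num_matrices):
    """Weights in an FFN whose matrices are all d_model x d_ff (no biases)."""
    return num_matrices * d_model * d_ff


d = 768
print(ffn_params(d, 4 * d, 2))        # ReLU FFN, 2 matrices, d_ff = 4d:      4718592
print(ffn_params(d, 8 * d // 3, 3))   # SwiGLU FFN, 3 matrices, d_ff = 8d/3:  4718592
```

Three matrices at width $\tfrac{8}{3} d_{model}$ give $8 d_{model}^2$ weights, exactly what two matrices at width $4 d_{model}$ give.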

Why does it work better? Intuitively, the multiplicative gate lets the network express more complex functions per parameter — one projection can dynamically modulate the other. Empirically, SwiGLU outperforms ReLU and GELU at scale across many benchmarks. The candor of the field is worth preserving here: Noam Shazeer's 2020 paper introducing it for transformers ended with the memorable line that they offer no explanation for why these architectures work and chalk it up, like all else, to divine benevolence. Funny as that is, gated activations have become standard in LLaMA, PaLM, Mistral, and most modern LLMs.

[Figure: SwiGLU FFN vs ReLU FFN. ReLU FFN (2 matrices): x → Linear W1 → ReLU → Linear W2. SwiGLU FFN (3 matrices): x → Lin W1 → Swish, and x → Lin W2, multiplied element-wise, then Lin W3. Inset: Swish vs ReLU curves. SwiGLU shrinks d_ff to about (8/3)·d_model to keep the parameter count matched.]
Fig 5.21 — A multiplicative gate lets one projection modulate the other — more expressive per parameter.

Check your understanding

SwiGLU uses three weight matrices where the original FFN used two. How do modern models keep the parameter count fair?

Show answer ▸

They shrink the hidden width $d_{ff}$. The original transformer used $d_{ff} = 4 \times d_{model}$; SwiGLU models typically use about $\tfrac{8}{3} \times d_{model}$ instead. That smaller hidden dimension, spread across three matrices ($W_1$ and $W_2$ for the gated product and $W_3$ for the down-projection), lands at roughly the same total parameter count as the original two-matrix ReLU FFN.

Sparse and sliding-window attention

Full attention's cost grows as $O(n^2)$ in the sequence length $n$, and that quadratic growth is the single biggest barrier to long-context LLMs. For $n = 1{,}000$, that's totally fine. For $n = 1{,}000{,}000$, you're looking at $10^{12}$ query-key scores per layer per head, which is not fine. Sparse attention patterns trade a little flexibility for huge efficiency gains.

The simplest pattern is sliding-window attention: each token attends only to the $w$ tokens before it, not all of them. With a window of $w = 4096$ and a sequence of $n = 100{,}000$, you do work proportional to $n \times w \approx 4 \times 10^8$ operations instead of $n^2 = 10^{10}$. That's about a 25× reduction.
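Printing a small mask is a quick way to see the band structure. In this sketch (names are ours), query $i$ may attend to itself and the $w$ positions before it, so each row has at most $w + 1$ allowed keys. The last line redoes the $n^2$ versus $n \times w$ comparison above.

```python
def window_mask(n, w):
    """mask[i][j] is True when query i may attend to key j:
    j is i itself or one of the w positions immediately before it."""
    return [[i - w <= j <= i for j in range(n)] for i in range(n)]


for row in window_mask(6, 2):
    print("".join("x" if ok else "." for ok in row))

# Each row has at most w + 1 allowed keys, so cost grows like n * w, not n^2.
n, w = 100_000, 4096
print(round(n ** 2 / (n * w), 1))   # 24.4, the "about 25x" above
```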

But wait: if a token can only see 4,096 tokens back, how does a model ever use a 128k context? Here's the depth trick, and it's lovely. Stack layers. Layer 1 at a given position sees 4k tokens back. Layer 2's output at that position depends on Layer 1's outputs across its own 4k window, and each of those already absorbed 4k more tokens back. So after $L$ layers, the effective receptive field is $L \times w$ tokens, even though every individual layer stays cheap. Information flows like ripples: each layer sees nearby context, but that nearby context already soaked up slightly more distant context from the layer below. With 32 layers and a 4096 window, the effective field is $32 \times 4096 = 131{,}072$ tokens.

This is exactly the strategy Mistral uses. Mistral 7B has a 4096-token sliding window across 32 layers, which gives a theoretical attention span of about 131k tokens at modest compute.
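The depth trick can be checked by brute force on a toy case. This sketch (our own helper) tracks which input positions can influence one output position as layers stack. The reachable set stays a contiguous block whose left edge moves back by $w$ per layer, which is exactly the $L \times w$ rule.

```python
def reachable_positions(pos, w, num_layers):
    """Input positions that can influence `pos` after num_layers windowed layers.

    Brute force: each layer, every reachable position pulls in its own window
    (itself plus the w positions before it). Fine for small numbers only.
    """
    reachable = {pos}
    for _ in range(num_layers):
        reachable = {j for p in reachable for j in range(max(0, p - w), p + 1)}
    return reachable


field = sorted(reachable_positions(20, w=3, num_layers=4))
print(field[0], field[-1])   # 8 20: reaches 4 * 3 = 12 tokens back
print(32 * 4096)             # 131072, the Mistral-style configuration
```

Reachable only means that information can flow, not that it survives the trip intact.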

[Figure: Sliding-Window Attention + the Depth Trick. In one layer, each token attends to the w previous tokens, not all n. Stacking L layers widens the receptive field by w per layer (shown: L = 4, w = 3, giving 12 tokens of effective context). Cost per layer is ~n·w (windowed) vs n² (full). Real model: 32 layers × 4096 window ≈ 131,072 tokens.]

Fig 5.22 — A small local window per layer, stacked deep, reaches far — that's how long context stays cheap.

What sliding window sacrifices: direct attention to faraway tokens. A token at position 50,000 can't directly attend to one at position 100 — the information has to flow up through the layers. For tasks that need exact long-range, pointer-like retrieval, that hurts.

A few richer patterns build on the basic window:

Longformer's pattern: local + global. Longformer added something simple but powerful — a few designated "global" tokens that attend to everything and that everything attends to (typically the [CLS] token or the start of each document). Most tokens use the cheap sliding window; the few special tokens get full attention to and from everyone. This is great when you have a fixed query at the front of the input (like extractive QA): the query tokens get global attention while the document body stays windowed, giving strong performance at sub-quadratic cost.

BigBird: local + global + random. BigBird (2020) is the most elaborate combination: a local sliding window, a few global tokens, and random connections (each token attends to a few random others). The random links are the clever part. They let information hop across the sequence efficiently, like a small-world graph where every node is a few hops from every other. Even though no single token sees everything, after a few layers the information has mixed globally. BigBird was proven to retain the theoretical expressivity of full attention (under mild conditions) while running at $O(n)$ compute, which is the cleanest theoretical justification for sparse attention.

And several more you'll run into:

Strided / dilated attention
— attend to every $k$-th token in addition to nearby ones.
Block-sparse attention
— divide the sequence into blocks; attend within blocks plus a sparse pattern between them.
Sparse Transformers (OpenAI, 2019)
— strided patterns where each head attends to either nearby tokens or every $k$-th token, alternating. Used in early image and music generation.
Sliding window only
— Mistral and many recent models. Simple to implement, works well, easy to combine with Flash Attention.
Hybrid sliding window + full
— alternate layers between sliding window and full attention. Recent models like Gemma 2 and some Llama variants do this: efficiency from the windowed layers, full mixing from the occasional dense one.
Native sparse attention (DeepSeek)
— recent work training models with structured sparsity from scratch, reaching near-full-attention quality at much lower cost.

The theoretical landscape is rich, but the production reality is mostly "sliding window plus full attention every few layers." The 2×–10× speedups are great; chasing every theoretical optimization for another 5% has historically not been worth the engineering pain.

One honest caveat: sliding window doesn't literally extend context for free. Watch for these failure modes — needle-in-a-haystack at long distance (if the answer is at position 1,000 and the question at position 50,000, windowed models can struggle even with enough effective receptive field, because the info has to survive propagation through many layers without being overwritten); multi-hop reasoning over long contexts (chains of references spanning the whole document degrade); and long-range copying (copying a specific phrase from far back gets harder). That's why pure sliding window is often paired with periodic full-attention layers, or with attention sinks (always attend to the first few tokens) to soften these effects.

Check your understanding

If each layer only attends to a 4,096-token window, how can a 32-layer model effectively use ~128k tokens of context?

Show answer ▸

The depth trick. Each layer's output at a position summarizes its own 4k window, but the tokens in that window already summarized their 4k windows in the layer below. Stacking $L$ layers compounds this, so the effective receptive field grows to about $L \times w$ ($32 \times 4096 = 131{,}072$). Information ripples upward through the stack even though no single layer ever attends beyond its local window. The cost is that faraway tokens are reached only indirectly, which can hurt exact long-range retrieval.

KV cache (important!)

The KV cache is the single most important inference-time optimization in LLMs. Here's the setup. When you generate text autoregressively, you produce one token at a time. To generate token $n+1$, you compute attention for position $n$, and attention needs the $K$ and $V$ vectors of every previous position too.

The naive approach recomputes $K$ and $V$ for every position from scratch at each step. That's $O(n^2)$ work over the whole generation, and most of it is redundant: the $K$ and $V$ for position 4 don't change when you're generating position 100. Hugely wasteful.

The fix is to cache them. Once you compute the $K$ and $V$ vectors for a token, save them. When generating the next token, compute $Q$, $K$, $V$ only for the new position, then concatenate the new $K$, $V$ onto the cached ones. Now attention does work proportional to $n$ per step instead of $n^2$. This optimization is so fundamental that every modern inference engine has it. The "KV cache" is just two big tensors of shape (num_layers, num_kv_heads, sequence_length, head_dim), one for K and one for V. Each grows by one position along the sequence axis per generated token.
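Here's a toy check that caching changes the cost but not the answer. The sketch assumes one layer, one head, random weights, and random vectors standing in for the hidden states; the helper names are ours. The cached loop projects each token once, while the naive loop re-projects the whole prefix every step.

```python
import math
import random


def attend(q, keys, values):
    """One query's attention: softmax(q . k / sqrt(d)) over keys, weighted sum of values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[c] for e, v in zip(exps, values)) / total
            for c in range(len(values[0]))]


def project(x, W):
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) for j in range(len(W[0]))]


def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d = 4
Wq, Wk, Wv = (rand_matrix(rng, d, d) for _ in range(3))
hidden = rand_matrix(rng, 6, d)   # stand-in hidden states for 6 positions

# Cached: compute K and V once per new token, append, reuse on later steps.
k_cache, v_cache, cached_out = [], [], []
for x in hidden:
    k_cache.append(project(x, Wk))
    v_cache.append(project(x, Wv))
    cached_out.append(attend(project(x, Wq), k_cache, v_cache))

# Naive: recompute K and V for the whole prefix at every step.
naive_out = []
for t in range(1, len(hidden) + 1):
    keys = [project(x, Wk) for x in hidden[:t]]
    vals = [project(x, Wv) for x in hidden[:t]]
    naive_out.append(attend(project(hidden[t - 1], Wq), keys, vals))

print(cached_out == naive_out)   # True: same outputs, far fewer projections
```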

How big does it get? The KV cache size is:

$$\text{KV cache} = 2 \times \text{num\_layers} \times \text{num\_kv\_heads} \times \text{seq\_len} \times \text{head\_dim} \times \text{bytes\_per\_param}$$

The leading $2$ counts both K and V. Plug in a 70B model with 80 layers, 8 KV heads, head_dim 128, in FP16 (2 bytes), at sequence length 100k:

$$2 \times 80 \times 8 \times 100{,}000 \times 128 \times 2 \approx 32 \text{ GB}$$

That's for a single user's context. Now multiply by the number of concurrent users on a server and you see why KV cache management is the single biggest pressure point in production LLM serving. Innovations have piled up to deal with it:

GQA / MQA
— fewer KV heads, smaller cache (covered right below).
Quantized KV cache
— store K and V in 8-bit or even 4-bit precision.
PagedAttention (vLLM)
— manage the cache in pages, like virtual memory, so concurrent requests can share GPU memory efficiently.
Prefix caching
— when many requests share the same system prompt, cache its K and V once and reuse across requests.
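The sizing formula is worth having as a function. This sketch reproduces the 70B example above. The 64-KV-head line is a hypothetical variant, assuming 64 query heads that each get their own K/V head, to show how much the cache would grow without KV-head sharing.

```python
def kv_cache_bytes(num_layers, num_kv_heads, seq_len, head_dim, bytes_per_param):
    # The leading 2 counts the two tensors: one for K, one for V.
    return 2 * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_param


gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, seq_len=100_000,
                     head_dim=128, bytes_per_param=2)   # FP16
print(gqa / 1e9)    # 32.768 (decimal GB) for one 100k-token sequence

# Hypothetical comparison: same shape, but 64 KV heads (one per query head).
mha = kv_cache_bytes(80, 64, 100_000, 128, 2)
print(mha / gqa)    # 8.0
```

Dropping `bytes_per_param` to 1 (an 8-bit KV cache) halves the total, which is the quantized-cache item above in arithmetic form.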
[Figure: KV Cache. Naive decoding recomputes all K, V every step (O(n²), mostly redundant). Cached decoding computes K, V for the new token only and reuses the K-cache and V-cache (O(n) per step). Cache size = 2 · layers · kv_heads · head_dim · seq_len · bytes. At 80 layers, 8 KV heads, head_dim 128, 8,192 tokens, and 2 bytes, that's 2.5 GiB per sequence, or 32.5 GiB across 13 concurrent sequences.]
Fig 5.23 — Cache K and V once, reuse them every step. The cache size is why serving is hard.

Check your understanding

What exactly does the KV cache store, and why does caching turn per-step attention cost from $O(n^2)$ into $O(n)$?

Show answer ▸

It stores the key and value vectors of every token processed so far, in two tensors of shape num_layers × num_kv_heads × seq_len × head_dim. Without it, generating each new token would recompute K and V for all previous positions, and that redundant work grows quadratically over a generation. With the cache, each step computes K and V only for the single new token and concatenates them onto the stored ones. A step therefore costs work proportional to the current length $n$, not $n^2$.

Prefill vs. decode

The KV cache leads to an important distinction in how LLMs actually run, splitting inference into two phases with very different performance characters.

Prefill — processing the entire prompt at once. All prompt tokens are computed in parallel, building up the initial KV cache. This phase is compute-bound and fast per token. It's the latency you wait through before the first generated token appears.

Decode — generating tokens one at a time. Each step does one new token's worth of compute but has to read the entire KV cache (and the model weights) to do it. This phase is memory-bandwidth-bound and slow per token. It's what sets the speed at which the response streams out.

Optimizing each phase needs different tricks. Prefill loves big tensor cores chewing through dense matrix multiplies; decode lives or dies by memory bandwidth.
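Arithmetic intensity, the number of FLOPs done per byte moved, is a back-of-the-envelope way to see the split. This sketch models a single FP16 weight-matrix multiply with an illustrative 4096-wide projection. It counts one read of the weights and activations and one write of the output, and it ignores the KV cache and on-chip caching (all simplifying assumptions). One decode token uses each weight once; a 512-token prefill reuses it 512 times.

```python
def matmul_intensity(num_tokens, d_in, d_out, bytes_per_elem=2):
    """FLOPs per byte moved for a (num_tokens x d_in) @ (d_in x d_out) multiply.

    Counts 2 FLOPs per multiply-add; bytes cover reading the weights and the
    input activations once and writing the output once (FP16 by default).
    """
    flops = 2 * num_tokens * d_in * d_out
    bytes_moved = bytes_per_elem * (d_in * d_out + num_tokens * d_in + num_tokens * d_out)
    return flops / bytes_moved


print(matmul_intensity(1, 4096, 4096))     # decode step, one sequence: ~1 FLOP per byte
print(matmul_intensity(512, 4096, 4096))   # 512-token prefill: ~410 FLOPs per byte
```

Modern GPUs can perform on the order of hundreds of FLOPs in the time it takes to read one byte from HBM. At about one FLOP per byte, decode leaves the arithmetic units mostly idle waiting on memory. Prefill, at hundreds of FLOPs per byte, can keep them busy.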

[Figure: Prefill vs Decode. Prefill is parallel: all prompt tokens go through in one pass and fill the KV cache. It is compute-bound and sets time-to-first-token. Decode is sequential: one token per step, with each step reading the KV cache + weights. It is memory-bandwidth-bound and sets tokens/second.]

Fig 5.24 — Prompt processing is parallel and compute-bound; generation is sequential and bandwidth-bound.

Check your understanding

Why is the decode phase memory-bandwidth-bound while prefill is compute-bound?

Show answer ▸

In prefill, all prompt tokens are processed at once, so the GPU does large dense matrix multiplications that saturate its compute units — lots of arithmetic per byte read. In decode, you generate one token at a time, doing only a sliver of arithmetic, but each step must read the entire KV cache and the full model weights from memory. The work is dominated by moving data, not by computing on it, so memory bandwidth is the limiting factor.

Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)

Both of these attack the KV cache size directly by sharing key/value heads.

Multi-Query Attention (MQA) is the radical version: keep all $h$ query heads, but use only one shared $K$ and one shared $V$ across all of them. This shrinks the KV cache by a factor of $h$, so an 8-head model gets an 8× smaller cache. The cost is a noticeable quality drop, because the model has less flexibility in how different heads can attend.

Grouped-Query Attention (GQA) is the compromise. Group the $h$ query heads into $g$ groups, and have each group share one $K$ and $V$. With $h = 32$ heads and $g = 8$ groups, you get a 4× smaller KV cache than full multi-head attention, with almost no quality loss. GQA hits a sweet spot and is now the default in LLaMA 2 (its 34B and 70B models), LLaMA 3, Mistral, and many production LLMs.
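The bookkeeping behind grouping is one integer division. In this sketch (our own helper), query heads are assigned to K/V heads in contiguous groups, a common convention that we're assuming here. MHA and MQA fall out as the two extremes.

```python
def kv_head_for(query_head, num_q_heads, num_kv_heads):
    """Index of the K/V head that a given query head reads from.

    Assumes contiguous groups. MHA: num_kv_heads == num_q_heads.
    MQA: num_kv_heads == 1. GQA: anything in between that divides evenly.
    """
    assert num_q_heads % num_kv_heads == 0, "query heads must split evenly"
    group_size = num_q_heads // num_kv_heads
    return query_head // group_size


h, g = 32, 8
mapping = [kv_head_for(q, h, g) for q in range(h)]
print(mapping[:8])   # [0, 0, 0, 0, 1, 1, 1, 1]: heads 0-3 share K/V head 0
print(h // g)        # 4: KV cache shrink factor versus full multi-head attention
```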

[Figure: MHA vs GQA vs MQA with 8 query heads. MHA uses 8 K/V heads (full-size cache, best quality). GQA uses 2 K/V heads (4× smaller cache, near-best quality). MQA uses 1 K/V head (8× smaller cache, worse quality).]
Fig 5.25 — Share K/V across query heads to shrink the cache. GQA is the sweet spot.

Check your understanding

GQA sits between MHA and MQA. What does it trade, and why is it usually the default?

GQA groups the query heads and lets each group share one set of K/V heads, instead of every head having its own (MHA) or all heads sharing one (MQA). That shrinks the KV cache — e.g. 4× smaller with 32 query heads in 8 groups — while losing almost no quality, because there's still enough K/V diversity for heads to attend differently. MQA shrinks the cache more but visibly hurts quality, so GQA's balance of big memory savings for negligible quality loss makes it the common default.

Flash Attention

Standard attention has a memory problem. The intermediate attention matrix ($QK^\top$ before softmax) has shape (seq_len, seq_len). For seq_len = 100k, that's 10 billion entries, and there's one for every head in every layer. At FP16 a single one takes 20 GB; across the heads of even one layer it eats hundreds of GB. Worse, the standard algorithm writes that giant matrix out to the GPU's main memory (HBM), reads it back for softmax, then reads it again for the multiply with $V$. All that shuffling thrashes memory bandwidth.

Flash Attention (Tri Dao, 2022) is a re-implementation of attention that's mathematically identical to the original but treats the GPU memory hierarchy with care. The key insight: never materialize the full attention matrix. Instead, process attention in tiles small enough to fit in fast on-chip SRAM, and use a clever online softmax algorithm that computes the right answer without ever seeing all the values at once.

Here's the contrast. The standard algorithm: compute $S = QK^\top$ (huge matrix, write to HBM); compute $P = \text{softmax}(S)$ (read from HBM, write back); compute $O = PV$ (read both, write output). Flash Attention instead: load tiles of $Q$, $K$, $V$ into SRAM; compute partial scores and partial softmax statistics on each tile; accumulate the output incrementally while updating running statistics; and never materialize the full $S$ or $P$ matrix in HBM at all.
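
The online-softmax trick is easiest to see for a single query vector. The sketch below is plain Python, so there's no real SRAM or HBM here, only the arithmetic: it walks over $K$ and $V$ one tile at a time, keeps a running max $m$, a running sum $\ell$, and an unnormalized output, and rescales them whenever a new tile raises the max. It includes the usual $1/\sqrt{d}$ score scaling; the function names, tile sizes, and toy vectors are illustrative assumptions, and real kernels process whole blocks of queries in parallel on-chip.

```python
import math


def naive_attention(q, K, V):
    """Reference: compute every score, softmax them all at once, then weight the rows of V."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total for j in range(len(V[0]))]


def tiled_attention(q, K, V, tile=2):
    """Online softmax: visit K/V one tile at a time, never holding all scores at once."""
    d = len(q)
    m = -math.inf              # running max of the scores seen so far
    ell = 0.0                  # running sum of exp(score - m)
    acc = [0.0] * len(V[0])    # running sum of exp(score - m) * v (unnormalized output)
    for start in range(0, len(K), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s_t = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K_t]
        m_new = max(m, max(s_t))
        rescale = math.exp(m - m_new)   # shrink old statistics to the new max; 0.0 on the first tile
        ell *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(s_t, V_t):
            p = math.exp(s - m_new)
            ell += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / ell for a in acc]       # normalize once at the end


q = [0.5, -1.0, 0.25]
K = [[1.0, 0.0, 2.0], [0.3, 0.3, 0.3], [-1.0, 2.0, 0.0], [0.0, -0.5, 1.5], [2.0, 1.0, -1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
print(naive_attention(q, K, V))
print(tiled_attention(q, K, V, tile=2))   # same values, up to floating-point rounding
```

Because rescaling by $e^{m_{\text{old}} - m_{\text{new}}}$ is exact, the tiled result matches the reference for any tile size; only floating-point rounding differs.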

The result is 2–4× faster attention for typical sequence lengths and dramatic memory savings — linear in sequence length instead of quadratic. Flash Attention 2 and 3 refine it further with better parallelism and newer hardware features. Modern LLMs essentially all use some Flash Attention variant.

The takeaway is worth underlining because it's a recurring theme in modern ML: same math, radically different performance, purely by being smarter about memory access patterns. Many of the biggest practical gains come not from new algorithms but from better implementations of existing ones.

Fig 5.26 — Never build the full attention matrix — stream tiles through fast on-chip memory instead.

Check your understanding

Flash Attention computes exactly the same result as standard attention. Where does its speedup come from?

From memory access, not math. Standard attention materializes the full (seq_len × seq_len) score matrix in slow GPU main memory (HBM) and shuttles it back and forth for softmax and the value multiply. Flash Attention never builds that matrix: it streams small tiles of Q, K, V through fast on-chip SRAM and uses an online-softmax trick to accumulate the correct output incrementally. Far fewer slow-memory reads/writes means 2–4× speed and memory that scales linearly instead of quadratically.

Mixture of Experts (MoE)

Remember that most of a transformer's parameters live in the FFN. So here's the question MoE asks: what if you could have many FFNs but only run a few of them for each token?

The architecture replaces each FFN in the transformer with $N$ "expert" FFNs plus a router (also called a gating network). For each token, the router picks the top-$k$ experts (typically $k = 1$ or $2$) and sends the token through only those. The other $N - k$ experts sit idle for that token. If each expert is the size of the original dense FFN and you pick the top 2 of 8 per token, the layer holds 8× the FFN parameters but does only about 2× the FFN work per token, a quarter of what a dense FFN with the same total parameter count would cost. This is the sparse activation principle: total parameters huge, compute per token small. The router itself is just a small linear layer that produces scores over the experts; in a common setup (Mixtral's, for one) you take the top-$k$, softmax those chosen scores, and weight the experts' outputs by them.
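
Here is that routing step in miniature: scores over 8 experts, the top 2 chosen, a softmax over just those two, and a weighted sum of their outputs. It's a sketch under simplifying assumptions: the router scores are given directly instead of coming from a learned linear layer, and each "expert" is a hypothetical toy function standing in for an FFN.

```python
import math


def top_k_route(router_logits, k=2):
    """Pick the k highest-scoring experts and softmax-normalize only their scores."""
    chosen = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)[:k]
    m = max(router_logits[i] for i in chosen)
    exps = {i: math.exp(router_logits[i] - m) for i in chosen}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}   # expert index -> gate weight


def moe_forward(x, experts, router_logits, k=2):
    """Run x through only the chosen experts and combine their outputs by gate weight."""
    gates = top_k_route(router_logits, k)
    out = [0.0] * len(x)
    for i, g in gates.items():
        y = experts[i](x)                     # unchosen experts never run for this token
        out = [o + g * yj for o, yj in zip(out, y)]
    return out, gates


# Hypothetical toy experts: expert i scales its input by (i + 1).
experts = [lambda x, s=i + 1: [s * xj for xj in x] for i in range(8)]
logits = [1.2, -0.3, 0.1, 0.0, 1.7, -1.0, 0.4, -0.5]   # made-up router scores for one token
out, gates = moe_forward([1.0, 2.0], experts, logits, k=2)
print(gates)   # only experts 4 and 0 (zero-based) receive nonzero weight
print(out)
```

With these made-up scores the gates come out to about 0.62 and 0.38, and the six unchosen experts never execute, which is where the compute savings come from.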

Fig 5.27 — Many experts, few active per token: huge capacity, small per-token compute.

MoE gives you a massive capacity boost without the matching compute cost. Mixtral 8×7B has roughly 47B total parameters but activates only about 13B per token. GPT-4 is widely believed to be a large MoE, and most larger frontier models almost certainly are.

But MoE comes with real challenges:

Load balancing
If the router always picks expert 1, the other experts are wasted. Auxiliary "load-balancing losses" during training nudge usage to spread evenly across experts.
Memory
Even though compute is sparse, all experts must be loaded into memory. A 47B-param Mixtral needs roughly 47B params' worth of memory even though it only does 13B params of compute per token.
Training instability
Routing decisions are discrete (top-$k$), which makes gradients tricky. Tricks like soft routing, expert-capacity limits, and noise injection keep training stable.
Inference complexity
Routing different tokens to different experts is harder to batch efficiently, so production serving takes careful engineering.
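
The load-balancing loss mentioned above comes in several variants; this sketch uses the form popularized by the Switch Transformer, $N \sum_i f_i P_i$, where $f_i$ is the fraction of tokens whose top choice is expert $i$ and $P_i$ is the average router probability for expert $i$. It assumes top-1 routing and omits the small coefficient that scales the term in the total training loss; the logits are made up.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def load_balancing_loss(router_logits):
    """Switch-style auxiliary loss N * sum_i f_i * P_i for top-1 routing.

    router_logits: one list of N expert scores per token.
    f_i: fraction of tokens whose top-1 choice is expert i.
    P_i: mean router probability assigned to expert i.
    """
    n_tokens, n_experts = len(router_logits), len(router_logits[0])
    probs = [softmax(z) for z in router_logits]
    counts = [0] * n_experts
    for p in probs:
        counts[max(range(n_experts), key=lambda i: p[i])] += 1
    f = [c / n_tokens for c in counts]
    P = [sum(p[i] for p in probs) / n_tokens for i in range(n_experts)]
    return n_experts * sum(fi * Pi for fi, Pi in zip(f, P))


balanced = [[5.0 if i == t else 0.0 for i in range(4)] for t in range(4)]   # each token prefers a different expert
collapsed = [[5.0, 0.0, 0.0, 0.0] for _ in range(4)]                        # every token prefers expert 0
print(load_balancing_loss(balanced))    # about 1.0
print(load_balancing_loss(collapsed))   # about 3.9
```

The loss is 1.0 when both $f$ and $P$ are uniform and climbs toward $N$ as routing collapses onto one expert, so minimizing it nudges the router to spread tokens out.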

Despite all that, MoE is the path most frontier models have taken. It's how you build a model with effective trillion-parameter capacity without needing a trillion-parameter forward pass.

Check your understanding

An MoE model can have far more total parameters than a dense model yet cost less compute per token. How — and what's the catch on memory?

A router sends each token through only the top-$k$ experts (say 2 of 8), so compute scales with the active experts, not the total. That's why Mixtral 8×7B can hold ~47B parameters but only activate ~13B per token. The catch: even idle experts still have to be resident in GPU memory, so memory usage tracks the total parameter count (~47B), not the active count. You save compute, not memory.

Long context windows

The original transformer handled a few hundred tokens. Modern models routinely handle 200k or even 1M tokens, and a few advertise 10M. Getting there took innovations across every part of the stack, and notice that each one is a tool we've already met:

Positional encodings
Sinusoidal didn't extrapolate. RoPE plus scaling tricks (NTK-aware scaling, YaRN, position interpolation) let models trained at 8k handle 128k or more. ALiBi extrapolates natively.
Attention efficiency
Quadratic attention can't reach 1M tokens: that's $10^{12}$ query-key scores per layer per head. Flash Attention cuts the memory cost; sliding-window and sparse attention cut the compute cost.
KV cache compression
At 1M tokens, the KV cache alone can run to hundreds of GB. Quantization (FP8, INT4), eviction strategies (drop old or low-attention tokens), and compression schemes are all active research.
Training data
A model trained only on 4k-token examples won't suddenly use 1M tokens well, even if it technically can. Long-context ability requires curating long documents and using techniques like progressive length training.
Evaluation
"Needle in a haystack" tests check whether a model can retrieve a specific fact placed somewhere in a long context. More demanding tests like RULER and LongBench probe whether models actually reason over long context or merely retrieve.

The honest truth: many models claim long context but degrade substantially as the context fills up. A model advertising 1M-token context might really only use the first 32k well plus the last few thousand. This is improving fast, but it's still a real design consideration when you're building systems.
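
Position interpolation, one of the RoPE scaling tricks in the list above, fits in a few lines: instead of feeding RoPE positions the model never saw, multiply every position by $L_{\text{train}} / L_{\text{target}}$ so all rotation angles stay inside the trained range. This sketch assumes the original RoPE frequencies $\theta_i = 10000^{-2i/d}$ and illustrative lengths (trained at 8k, run at 128k); NTK-aware scaling and YaRN change the frequencies in more refined ways and aren't shown.

```python
def rope_angles(position, head_dim, base=10000.0, scale=1.0):
    """Rotation angle for each of the head_dim // 2 RoPE pairs at one position.

    scale < 1 implements position interpolation: position p is treated as p * scale.
    """
    p = position * scale
    return [p * base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


TRAIN_LEN, TARGET_LEN, HEAD_DIM = 8192, 131072, 128   # illustrative: trained at 8k, run at 128k
last = TARGET_LEN - 1

raw = rope_angles(last, HEAD_DIM)                                   # no scaling: about 16x past training
interp = rope_angles(last, HEAD_DIM, scale=TRAIN_LEN / TARGET_LEN)  # squeezed back into range
bound = rope_angles(TRAIN_LEN, HEAD_DIM)                            # edge of the trained range

print(all(a > b for a, b in zip(raw, bound)))      # True: every angle exceeds the trained range
print(all(a < b for a, b in zip(interp, bound)))   # True: every angle stays inside it
```

The unscaled angles all overshoot the trained range; that matters most for the slowly rotating, low-frequency pairs, since the fast ones wrap around many times during training anyway. The price of interpolation is that neighboring positions become harder to tell apart, which is one reason a short fine-tune at the longer length typically follows.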

Check your understanding

Name two distinct techniques from earlier in this guide that combine to make million-token context windows feasible, and what each one fixes.

For example: (1) RoPE with frequency scaling (NTK-aware scaling / YaRN) fixes the positional-encoding problem: sinusoidal encodings didn't extrapolate past the training length, while scaled RoPE can be stretched far beyond it. (2) Flash Attention and/or sliding-window attention fix the cost problem: full $O(n^2)$ attention is impractical at 1M tokens, so Flash cuts memory and windowed/sparse attention cuts compute. KV-cache quantization is a valid third answer, addressing the hundreds-of-GB cache that long contexts create.

Quantization

Quantization is a term you'll hear constantly in inference engineering. It's the process of mapping a large or continuous set of input values to a smaller, finite set of discrete output values. In plain terms: smaller numbers, faster math, less memory. Modern LLMs are deployed with aggressive quantization:

FP32 → FP16 / BF16
— already standard. 2× memory and bandwidth savings, minimal quality impact.
FP8
— newer. Another 2× savings on top of FP16, often used in training on H100s and newer hardware.
INT8
— common for inference. Weights quantized to 8 bits, sometimes activations too.
INT4 / NF4
— aggressive. Used to run large models on consumer hardware. Some quality loss, especially on harder tasks.
Per-group / per-channel quantization
— quantize different parts of the weights with different scales, for better precision than naive uniform quantization.

There are two main approaches. Post-training quantization (PTQ) trains in higher precision and quantizes afterward — cheap but lossy; common tools are GPTQ and AWQ. Quantization-aware training (QAT) simulates quantization during training so the model learns to be robust to it — more expensive, better quality.
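
To see what per-group scaling buys, here's round-to-nearest symmetric INT4 quantization, the simplest PTQ scheme there is (GPTQ and AWQ are considerably smarter about which errors to minimize). Each group of weights gets its own scale, chosen so the group's largest magnitude maps to 7. The group sizes and weights are illustrative assumptions; real schemes commonly use groups of 64 or 128.

```python
def quantize_int4(weights, group_size):
    """Symmetric round-to-nearest INT4: one float scale per group, integers clamped to [-8, 7]."""
    q, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / 7 if max_abs > 0 else 1.0    # largest magnitude maps to +/-7
        scales.append(scale)
        q.extend(max(-8, min(7, round(w / scale))) for w in group)
    return q, scales


def dequantize_int4(q, scales, group_size):
    """Map each integer back to a float using its group's scale."""
    return [qi * scales[i // group_size] for i, qi in enumerate(q)]


weights = [0.02, -0.01, 0.03, 0.015, 2.5, -1.2, 0.8, 0.1]   # made up: 4 small weights, then 4 large
for g in (8, 4):   # one shared scale vs. one scale per group of 4
    q, s = quantize_int4(weights, g)
    restored = dequantize_int4(q, s, g)
    small_err = max(abs(a - b) for a, b in zip(weights[:4], restored[:4]))
    print(f"group_size={g}: ints={q}, worst error on the small weights={small_err:.4f}")
```

With one shared scale, the 2.5 weight sets the step size to about 0.36, so all four small weights round to zero and are lost. Giving them their own group shrinks their worst error to about 0.002.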

Quantization is what makes it possible to run a 70B-parameter model on a single workstation GPU: with INT4, 70B params at 4 bits is about 35 GB of weights, which fits in a 48 GB card with room left for the KV cache. Without it, models this size would be out of reach for anyone without data-center hardware.
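
The arithmetic generalizes: weight memory is parameter count times bits per weight, divided by 8. A quick sketch that counts weights only (the KV cache, activations, and the per-group scales that INT4 formats store all add overhead on top):

```python
BITS = {"FP32": 32, "FP16/BF16": 16, "FP8/INT8": 8, "INT4/NF4": 4}


def weight_gb(n_params, bits):
    """Weight memory in GB (10**9 bytes), ignoring scales, KV cache, and activations."""
    return n_params * bits / 8 / 1e9


for name, bits in BITS.items():
    print(f"70B at {name:9s}: {weight_gb(70e9, bits):6.1f} GB")
```

For 70B parameters this gives 280, 140, 70, and 35 GB, which is why each halving of precision matters so much for what hardware a model fits on.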

Fig 5.28 — Fewer bits per weight: less memory and faster math, at some cost to precision.

Check your understanding

Roughly how does INT4 quantization let a 70B model fit on a 48 GB GPU, and what's the cost?

At 4 bits per parameter, 70B parameters take about $70 \times 10^9 \times 0.5$ bytes ≈ 35 GB, which fits in a 48 GB card (versus ~140 GB at FP16). The cost is precision: rounding each weight to one of only 16 levels introduces some error, which can show up as quality loss, especially on harder tasks. Techniques like per-group/per-channel scaling and quantization-aware training reduce that loss.

Speculative decoding

LLM inference is dominated by the decode phase — generating one token at a time, each step limited by memory bandwidth (reading the full KV cache and weights for one token's worth of compute). Speculative decoding is a clever way to claw back speed.

The trick rests on one observation: a small, fast model can propose several tokens cheaply, and a large model can verify them in a single forward pass — which is the same kind of pass the big model would have done for one token anyway. If the proposals are right, you got multiple tokens for the price of one. If they're wrong at some point, you fall back to the big model's prediction at that position. Concretely:

1. A small "draft" model generates $k$ tokens in sequence.

2. The big model processes those $k$ tokens in parallel (one forward pass, like prefill).

3. Compare the proposals to the big model's predictions and accept the longest matching prefix.

4. Continue from there with the big model's correction.
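
Those four steps can be sketched for greedy decoding, where "the big model's prediction" is simply its top token. This is a toy under strong assumptions: both models are hypothetical functions from a token prefix to a next token, and the big model's single parallel verification pass is simulated with a loop. Sampled decoding typically replaces exact matching with a probabilistic accept/reject rule, which isn't shown here.

```python
def speculative_decode(target_next, draft_next, prompt, n_new, k=4):
    """Greedy speculative decoding: output is identical to greedy decoding with target_next alone."""
    out = list(prompt)
    target_calls = 0
    while len(out) - len(prompt) < n_new:
        # 1. The draft model proposes k tokens, one at a time (cheap).
        draft = []
        for _ in range(k):
            draft.append(draft_next(out + draft))
        # 2. The target scores every position at once (simulated here by a loop;
        #    in practice this is a single parallel forward pass).
        target_calls += 1
        preds = [target_next(out + draft[:i]) for i in range(k + 1)]
        # 3. Accept the longest prefix where draft and target agree.
        n_ok = 0
        while n_ok < k and draft[n_ok] == preds[n_ok]:
            n_ok += 1
        # 4. Keep the accepted tokens plus the target's own token at the first mismatch
        #    (or a bonus token if every draft token was accepted).
        out += draft[:n_ok] + [preds[n_ok]]
    return out[len(prompt):len(prompt) + n_new], target_calls


# Hypothetical toy "models" over integer tokens: the target counts up by 1;
# the draft agrees except that it guesses wrong right after a multiple of 5.
def target_next(seq):
    return seq[-1] + 1


def draft_next(seq):
    return seq[-1] + (2 if seq[-1] % 5 == 0 else 1)


tokens, calls = speculative_decode(target_next, draft_next, [0], n_new=12, k=4)
print(tokens)   # [1, 2, ..., 12], the same as plain greedy decoding with the target
print(calls)    # 4 target passes instead of 12
```

Because every kept token is either confirmed by the target or produced by it, the output equals plain greedy decoding with the target; the only thing the draft changes is how many target passes it takes.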

When the draft model agrees with the big model most of the time — which is true for easy tokens like punctuation, common words, and predictable completions — you get a 2–3× speedup with no quality loss (the big model's distribution is always what's ultimately honored). This is now standard in production inference engines. Variants include Medusa (multiple prediction heads on the same model), EAGLE (improved drafting with feature reuse), and lookahead decoding (parallel verification of multiple candidate sequences).

Fig 5.29 — A small model guesses ahead; the big model checks them all at once. Free speed when guesses are right.

Check your understanding

Speculative decoding speeds up generation but is guaranteed not to change the output distribution. Why is the quality preserved?

Because the small draft model's tokens are only accepted when they match what the big (target) model would have produced — verification happens against the big model's own predictions, and any mismatch is corrected by the big model at that position. The draft model just lets the big model confirm several easy tokens in one parallel pass instead of one at a time. The final tokens always come from (or are validated against) the target model, so the distribution is identical; only the speed changes.

RAG (Retrieval-Augmented Generation) and tool use

Even the best LLM has limits baked in: it doesn't know facts from after its training cutoff, it can't see your private data, and it can confidently produce plausible-sounding nonsense (hallucinate). Two main approaches address these by hooking the model up to external systems.

RAG (Retrieval-Augmented Generation) gives the model a search step before it answers. The pipeline: take the user's query, embed it into a vector (using an encoder model — remember those?), search a vector database of document embeddings for the most similar chunks, retrieve the top matches, and stuff that retrieved text into the prompt as context. Now the model answers grounded in real, current, possibly-private documents rather than only its frozen training memory. This is why encoder models never went away — the retrieval step depends entirely on them, and a cross-encoder often reranks the candidates for extra precision.
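
A toy version of the retrieval step makes the pipeline concrete. To stay self-contained, this sketch swaps the encoder for a bag-of-words vector (lowercase word counts) compared by cosine similarity; that substitution is an assumption for illustration only, as a real system uses a trained encoder, a vector database, and often a reranker. The documents and query are made up.

```python
import math
import re
from collections import Counter


def embed(text):
    """Stand-in 'encoder': lowercase word counts (a real system uses a trained model)."""
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a, b):
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def build_prompt(query, docs, top_k=2):
    """Embed the query, rank documents by similarity, and put the top-k into the prompt."""
    q = embed(query)
    ranked = sorted(docs, key=lambda d: cosine(q, embed(d)), reverse=True)
    context = "\n".join(f"- {d}" for d in ranked[:top_k])
    return f"Context:\n{context}\n\nQ: {query}\nA:"


docs = [
    "Photosynthesis lets plants make food from light, water, and carbon dioxide.",
    "Chlorophyll in plants absorbs sunlight.",
    "Rockets burn fuel to produce thrust.",
    "Tax forms are due in April.",
]
print(build_prompt("How do plants make food?", docs))
```

Here the two plant-related documents score highest and end up in the prompt, while the rocket and tax documents are left out.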

Tool use (function calling) goes a step further: instead of only retrieving text, the model can call external tools — run a calculator, query a database, hit a web API, execute code — and fold the results back into its reasoning. The model is trained to emit a structured call ("call search(query)"), the system runs it, and the result comes back as more context for the next step. This is the foundation of modern agents.
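
On the system side, tool use is mostly plumbing: parse the model's structured call, run the matching function, and feed the result back as context. The sketch assumes a hypothetical JSON call format with "name" and "arguments" fields and a hypothetical calculator tool; real APIs each define their own schema, and the "model output" here is hard-coded rather than generated.

```python
import json
import operator

# Hypothetical tool: evaluates a tiny "a op b" expression (no eval() on model output).
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}


def calculator(expression: str) -> str:
    a, op, b = expression.split()
    return str(OPS[op](float(a), float(b)))


TOOLS = {"calculator": calculator}   # registry of tools the model is allowed to call


def run_tool_call(model_output: str, context: list) -> list:
    """Parse a JSON tool call, run the named tool, and append its result to the context."""
    call = json.loads(model_output)
    result = TOOLS[call["name"]](**call["arguments"])
    return context + [f"[tool {call['name']} returned: {result}]"]


context = ["User: What is 17 * 23?"]
# Hard-coded stand-in for what a function-calling model might emit.
model_output = '{"name": "calculator", "arguments": {"expression": "17 * 23"}}'
context = run_tool_call(model_output, context)
print(context[-1])   # [tool calculator returned: 391.0]
```

In a real loop the model would now be called again with the extended context so it can phrase the final answer. Dispatching through a fixed registry, rather than evaluating model output directly, keeps the model limited to the tools you chose to expose.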

Both share the same core idea: the transformer is the reasoning engine, but it doesn't have to contain every fact or capability. Wire it to retrieval and tools and you get something current, grounded, and far more capable than the weights alone.

Fig 5.30 — Retrieve relevant text first, then let the model answer grounded in it — no retraining needed.

Check your understanding

RAG depends on a component we covered much earlier in the guide. Which one, and what's its job in the pipeline?

It depends on encoder models. The retrieval step embeds both the user's query and the candidate documents into vectors using an encoder, then finds the most similar document chunks by comparing those vectors. (A cross-encoder often reranks the top candidates for higher precision.) That's exactly the "representation learning, not generation" role encoders specialized into — RAG is one of their biggest modern uses.

Pretraining vs. fine-tuning vs. RLHF

Building a modern instruction-following LLM happens in three phases:

Pretraining
— train on a huge corpus (trillions of tokens) with next-token prediction. This is the bulk of the compute. It produces a "base model" that can complete text but doesn't naturally follow instructions.
Supervised fine-tuning (SFT)
— train on curated examples of instructions paired with good responses. This teaches the format of being a helpful assistant.
Reinforcement learning from human feedback (RLHF)
— use human ratings to train a reward model, then optimize the LLM against it. This produces models that are helpful, harmless, and aligned with human preferences. Modern variants include DPO (Direct Preference Optimization), which skips the separate reward model and optimizes directly on preference pairs.

Each phase reuses the previous model's weights: pretraining builds knowledge, SFT teaches the assistant format, and RLHF aligns outputs with human preferences.

Fig 5.31 — Pretrain for knowledge, SFT for format, RLHF for alignment with human preferences.
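
DPO's objective is compact enough to write out. For a prompt with a preferred response $y_w$ and a rejected one $y_l$, it rewards the policy $\pi_\theta$ for raising its log-probability margin on $y_w$ relative to a frozen reference model $\pi_{\text{ref}}$ (typically the SFT model):

$$\mathcal{L} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w) - \log \pi_{\text{ref}}(y_w)) - (\log \pi_\theta(y_l) - \log \pi_{\text{ref}}(y_l))\big]\Big)$$

The sketch computes it for a single pair from summed response log-probabilities; the numbers and $\beta = 0.1$ are illustrative assumptions.

```python
import math


def softplus(z):
    """log(1 + exp(z)), arranged so exp never overflows."""
    return z + math.log1p(math.exp(-z)) if z > 0 else math.log1p(math.exp(z))


def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one preference pair.

    Inputs are summed log-probabilities of the whole chosen (w) and rejected (l)
    responses under the policy being trained and under the frozen reference model.
    """
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return softplus(-margin)   # -log(sigmoid(margin)) == log(1 + exp(-margin))


# Policy identical to the reference: the margin is 0, so the loss is log 2.
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# Policy has moved probability toward the chosen response and away from the rejected one.
print(dpo_loss(-10.0, -18.0, -12.0, -15.0))   # smaller loss
```

At initialization the policy equals the reference, so the loss starts at $\log 2 \approx 0.693$ and falls as the policy separates chosen from rejected responses.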

Check your understanding

What does each of the three training phases contribute, and which one uses the most compute?

Pretraining (next-token prediction on trillions of tokens) gives the model its broad knowledge and language ability and uses by far the most compute, producing a base model that completes text. SFT teaches it to follow instructions in an assistant format using curated instruction–response pairs. RLHF (or DPO) aligns it with human preferences — making it helpful and harmless — using human ratings to shape the model's behavior.

LoRA and parameter-efficient fine-tuning, in depth

Let's slow down here, because LoRA is one of the most useful ideas in practical ML — and it ties directly back to the rank concept from linear algebra.

Start with the motivation. The leading LLMs today contain upwards of a trillion parameters, pretrained on tens of trillions of tokens. A model like Gemini 3 is trained once on an enormous internet-scale corpus, which gives it a broad but shallow understanding across many domains. The trouble is that companies paying for these models usually don't want a generalist — they want a specialist that's excellent at their specific task.

So after pretraining comes post-training: small datasets meant to focus the model on a narrower domain of knowledge or a particular range of behavior. And now think about the mismatch. Doesn't it seem absurdly expensive to use a terabit of weights to absorb updates from just a gigabit or megabit of training data? It's a small, narrow lesson being written into an enormous network.

This is exactly where LoRA (Low-Rank Adaptation) comes in. It's a method of parameter-efficient fine-tuning (PEFT) — adjust a large network by updating only a small set of parameters. LoRA is the leading and most popular PEFT method. It works by replacing each weight matrix WW from the original model with a modified version:

$$W' = W + \gamma BA$$

where $B$ and $A$ are matrices that together have far fewer parameters than $W$, and $\gamma$ is a constant scaling factor. In effect, LoRA creates a low-dimensional representation of the updates that fine-tuning imparts.

Take a second to let that sink in. We're not changing $W$ at all; we're learning a small, cheap correction to add on top of it.

According to the Thinking Machines blog, LoRA offers advantages in the cost and speed of post-training, plus a few operational reasons to prefer it over full fine-tuning (which we'll call FullFT):

Multi-tenant serving
Since LoRA trains an adapter (the $A$ and $B$ matrices) while leaving the original weights untouched, a single inference server can keep many adapters (different specialized versions) in memory and sample from them all in a batched way. Modern engines like vLLM and SGLang implement this. (See Punica: Multi-Tenant LoRA Serving, Chen, Ye, et al., 2023.)
Smaller training footprint
When you fine-tune the whole model, you also have to store the optimizer state alongside the weights, often at higher precision (float32) than the bfloat16-or-lower used for inference. You need gradients and optimizer moments for all the weights. As a result, FullFT usually needs an order of magnitude more accelerators than just sampling from the same model does — a different hardware layout entirely. Because LoRA trains far fewer weights and uses far less memory, it can run on a layout only slightly bigger than what you'd use for sampling. That makes training more accessible and often more efficient.
Easy loading and transfer
With far fewer weights to store, LoRA adapters are quick to set up or move between machines.

These reasons explain LoRA's surging popularity since the original paper (LoRA: Low-Rank Adaptation of Large Language Models, Hu et al., 2021). Still, the literature was for a while unclear on how well LoRA performs relative to FullFT — which we'll get to.

What "rank" actually means. From linear algebra, the rank of a matrix is the dimension of the vector space spanned by its columns (or rows — they're always equal). It's the maximum number of linearly independent row or column vectors in the matrix. Think of it as a measure of the unique information in the matrix. Hold onto that, because it's the whole key.

In the original paper, the authors hypothesized, and then found experimentally, that the change matrix produced by fine-tuning has a low intrinsic rank. Meaning: even though that change matrix may be huge, the unique information stored in it actually lives in a tiny subspace. So the matrix can be closely approximated from a representation space much smaller than the matrix itself. In essence, the fine-tuning update can be stored with a small amount of data.

LoRA exploits this with a decomposition trick: it represents the weight update as the product of two skinny matrices $B$ and $A$. Say you had a weight matrix of shape (2048, 2048); that's 4,194,304 parameters. During fine-tuning, the original matrix stays frozen, and LoRA approximates the change as the product of a tall $B$ of shape (2048, $r$) and a wide $A$ of shape ($r$, 2048). The number $r$, the rank of the decomposition, decides how much of the update's structure the approximation can capture. It's usually small (8, 16, 32, or 64 are typical), though larger values are sometimes used. When you multiply $B$ by $A$, the result is back to the full (2048, 2048) shape, but it was reconstructed from $2 \times r \times 2048$ numbers instead of $2048^2$. For $r = 16$, that's $2 \times 16 \times 2048 = 65{,}536$ parameters, about 1.6% of the full matrix.

Fig 5.32 — Freeze W, learn a tiny low-rank correction B*A. The update's real information fits in a small subspace.

As for those billions of base-model parameters, they stay frozen, in the same form as after pretraining. During LoRA fine-tuning, only the two small matrices are learned. At inference, the model uses $W + \gamma BA$ rather than $W$ alone. The trained pair $(B, A)$ is called an adapter: the base weights stay intact, and the adapter adjusts the model's behavior on top of them. Different adapters, trained for different tasks, produce different $(B, A)$ pairs, all attaching to the same frozen base. One fixed model, many small additions.

The obvious question: how do you choose the rank? A smaller rank means a smaller adapter and faster inference, but less capacity to capture what fine-tuning is trying to teach. A larger rank means higher cost and slower inference. Picking $r$ is a quality-vs-cost tradeoff, and the right value depends on the task.

There's also the hyperparameter $\alpha$ (alpha), which scales how strongly the adapter modifies the base. The computation is really $W + (\alpha / r) \cdot BA$, so the scaling factor $\gamma$ is $\alpha / r$; a common convention sets $\alpha$ to twice the rank. Its purpose is to give a stable way to control the adapter's influence independently of $r$, since the raw magnitude of $BA$ would otherwise swing a lot as you change the rank.
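
Here's the adapter as a layer, in a dependency-free sketch. It uses a common initialization ($A$ uniform in $\pm 1/\sqrt{d_{in}}$, $B$ zero), so the adapter starts as an exact no-op; the class name, toy sizes, and the hand-set entry of $B$ standing in for a training step are illustrative assumptions. The merged_weight method shows why a single trained adapter needn't slow inference: $(\alpha/r)BA$ can be folded into $W$ once.

```python
import random


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x): W is frozen, only A and B would be trained."""

    def __init__(self, W, r, alpha, seed=0):
        d_out, d_in = len(W), len(W[0])
        rng = random.Random(seed)
        bound = 1 / d_in ** 0.5
        self.W = W                                                                       # d_out x d_in, frozen
        self.A = [[rng.uniform(-bound, bound) for _ in range(d_in)] for _ in range(r)]  # r x d_in
        self.B = [[0.0] * r for _ in range(d_out)]                                       # d_out x r, starts at zero
        self.scale = alpha / r

    def __call__(self, x):
        base = matvec(self.W, x)
        delta = matvec(self.B, matvec(self.A, x))   # two skinny products; B A is never formed
        return [b + self.scale * d for b, d in zip(base, delta)]

    def merged_weight(self):
        """Fold the adapter into one matrix, W + (alpha / r) * B A."""
        r = len(self.A)
        return [[self.W[i][j] + self.scale * sum(self.B[i][k] * self.A[k][j] for k in range(r))
                 for j in range(len(self.W[0]))]
                for i in range(len(self.W))]


W = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, 0.0, 0.0], [0.0, -2.0, 1.0, 3.0]]   # toy 3x4 base weight
layer = LoRALinear(W, r=2, alpha=4)
x = [1.0, 2.0, 3.0, 4.0]
print(layer(x) == matvec(W, x))   # True: with B = 0 the adapter starts as an exact no-op
layer.B[0][1] = 0.3               # pretend an optimizer step changed B
print(layer(x))
print(matvec(layer.merged_weight(), x))   # same values, up to floating-point rounding
```

Zero-initializing $B$ means fine-tuning starts exactly at the pretrained model, and the forward pass only ever multiplies by the skinny $A$ and $B$, so the full $d_{out} \times d_{in}$ update is never stored during training.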

Connecting back to attention: LoRA is classically applied to the query and value projections, the matrices we called q_proj and v_proj.

A concrete sizing example. Llama-3.1-8B has 32 layers, so an adapter targeting q_proj and v_proj contains 64 individual $(B, A)$ pairs in total, one pair per targeted matrix per layer. Each pair contributes on the order of 100,000 parameters at rank 16, so the full adapter is a few million parameters and occupies about 13 MB on disk in fp16. Compared against the ~16 GB needed to store a fully fine-tuned copy of the same base model, that's a ratio of roughly 1,200 to 1.

One implementation wrinkle worth knowing: modern Llama models use Grouped-Query Attention (GQA), under which the value projection's output is smaller than the query projection's. As a result, the $B$ matrix on v_proj has shape (1024, 16) rather than (4096, 16), which is why the adapter lands at about 13 MB rather than the roughly 16 MB that naive shape arithmetic (treating v_proj as 4096 × 4096) would predict. (Nice to see GQA show up again, isn't it? The pieces interlock.)
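
The sizing arithmetic above is easy to reproduce. This sketch assumes Llama-3.1-8B's published shapes (hidden size 4,096, 32 layers, 8 key/value heads of dimension 128, so v_proj maps 4,096 to 1,024), rank 16, fp16 storage, and no file-format overhead.

```python
HIDDEN = 4096          # d_model of Llama-3.1-8B
KV_DIM = 8 * 128       # 8 key/value heads x head_dim 128 (GQA): the output size of v_proj
N_LAYERS, RANK, BYTES_PER_PARAM = 32, 16, 2   # rank-16 adapter stored in fp16


def lora_pair_params(d_in, d_out, r):
    """A is r x d_in and B is d_out x r."""
    return r * d_in + d_out * r


per_layer = lora_pair_params(HIDDEN, HIDDEN, RANK) + lora_pair_params(HIDDEN, KV_DIM, RANK)
total = per_layer * N_LAYERS
naive = 2 * lora_pair_params(HIDDEN, HIDDEN, RANK) * N_LAYERS   # as if v_proj were 4096 x 4096

print(f"GQA-aware: {total:,} params, {total * BYTES_PER_PARAM / 2**20:.1f} MiB")   # 6,815,744 params, 13.0 MiB
print(f"naive:     {naive:,} params, {naive * BYTES_PER_PARAM / 2**20:.1f} MiB")   # 8,388,608 params, 16.0 MiB
print(f"16 GB full copy / adapter: {16e9 / (total * BYTES_PER_PARAM):,.0f}x")     # 1,174x
```

The 13.0 MiB result is the "about 13 MB" figure, and the ratio against a 16 GB full copy lands near the quoted 1,200 to 1.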

A final property to emphasize: an adapter is permanently bound to a specific base model. A LoRA trained on Llama-3.1-8B will only work with Llama-3.1-8B, because the shapes of $B$ and $A$ are determined by that base model's weight dimensions, and the learned correction only makes sense relative to those particular weights. A Llama adapter can't be applied to Qwen, or even to a different size of Llama. That constraint is exactly what makes multi-tenant serving coherent: every adapter on the server attaches to the same known base.

(And QLoRA combines LoRA with quantization — fine-tuning a 4-bit-quantized base with LoRA adapters on top — for even cheaper fine-tuning on consumer hardware.)

Check your understanding

LoRA freezes $W$ and learns $W' = W + (\alpha/r)BA$ with $r$ small. What linear-algebra insight makes this work, and what does $\alpha$ do?

The insight is that the update produced by fine-tuning has low intrinsic rank: its unique information lives in a tiny subspace, so it can be well-approximated by the product of two skinny matrices $B$ ($d \times r$) and $A$ ($r \times d$) with $r$ much smaller than the matrix dimension. You reconstruct a full-size update from only $2 \times r \times d$ trainable numbers. $\alpha$ scales the adapter's influence: using $(\alpha/r)$ keeps the effective update magnitude stable as you change $r$, so the optimal learning rate doesn't shift much with rank.

Findings from Thinking Machines: "LoRA Without Regret"

For a while the open question was whether LoRA could actually match full fine-tuning. The Thinking Machines work, "LoRA Without Regret," answers it with refreshing specificity. The whole article condenses into two requirements for matching FullFT.

Condition 1 — apply LoRA to all layers, especially the MLP/MoE layers that hold most of the parameters. Attention-only LoRA underperforms even when you match the number of trainable parameters by cranking up the rank. Concretely, on Llama-3.1-8B, attention-only at rank 256 (0.25B params) underperforms MLP-only at rank 128 (0.24B params) despite roughly equal parameter counts, so the gap isn't a parameter-count issue; it's where you apply the adapter. They also found that applying LoRA to the attention matrices shows no extra benefit beyond applying it to the MLPs alone. (This is a nice twist on the original paper's advice to target q_proj/v_proj: at scale, the FFN/MLP is where the action is, which lines up with the fact that the FFN holds most of the parameters.)

Condition 2 — stay out of the capacity-constrained regime. Keep the number of trainable parameters above the information content of the dataset. When a dataset exceeds LoRA's capacity, LoRA doesn't slam into a hard loss floor; instead it shows worse training efficiency, depending on the ratio of model capacity to dataset size. Lower-rank adapters "fall off" the optimal loss curve once they run out of capacity.

On the practical settings, several findings are worth stating flat out:

  • The optimal learning rate for LoRA is consistently about 10× higher than for FullFT, across both supervised learning and RL. This 10× ratio showed up in every U-shaped plot of performance against learning rate, and their multi-model fit landed on a multiplier of 9.8. That makes transferring a known FullFT learning rate to LoRA almost mechanical. For very short runs (under ~100 steps), preliminary evidence suggests a higher multiplier around 15×, converging to 10× for longer runs.
  • The optimal learning rate is approximately independent of rank, thanks to the $1/r$ scaling in the $W' = W + (\alpha/r)BA$ parametrization. The optimal LR changes by less than a factor of 2 between rank 4 and rank 512, though rank 1 wants a somewhat lower LR. Early in training, the learning curves for different ranks are nearly identical.
  • One caution: LoRA is in some settings less tolerant of large batch sizes than FullFT, with the loss penalty growing as batch size increases, and raising the rank does not fix it. They attribute this to the optimization dynamics of the $BA$ product parametrization rather than to a capacity limit.
  • For their settings they used $\alpha = 32$ and the standard Hugging Face PEFT initialization (a uniform distribution for $A$ scaled by $1/\sqrt{d_{in}}$, zero initialization for $B$, the same learning rate for both matrices) and reported they couldn't improve on these. A useful simplification: although LoRA nominally has four hyperparameters ($\alpha$, $LR_A$, $LR_B$, $\text{init}_A$), invariances in the training dynamics mean only two degrees of freedom actually matter.

The standout RL result. The most striking finding for reinforcement learning: LoRA fully matches FullFT for policy-gradient RL even at ranks as low as 1. The reasoning is information-theoretic. Policy-gradient methods learn from the advantage function, which provides only $O(1)$ bits per episode, roughly 1000× less information per token than supervised learning. In their MATH example, training on ~10,000 problems with 32 samples each needs to absorb about 320,000 bits, while a rank-1 LoRA on Llama-3.1-8B already has 3M parameters, nearly 10× that capacity. They also observed that LoRA has a wider band of well-performing learning rates in RL. The lesson: when the learning signal is thin (as in RL), you barely need any adapter capacity at all.
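The capacity comparison in the MATH example is simple arithmetic. The sketch below assumes exactly 1 bit per episode as a stand-in for $O(1)$ and, like the comparison above, treats each adapter parameter as able to store at least about one bit; the 3M parameter count is the figure quoted above, not something the code derives.

```python
problems = 10_000
samples_per_problem = 32
bits_per_episode = 1            # assumption: O(1) bits per episode, taken as exactly 1
rank1_lora_params = 3_000_000   # rank-1 LoRA on Llama-3.1-8B, as quoted in the text

bits_needed = problems * samples_per_problem * bits_per_episode
headroom = rank1_lora_params / bits_needed  # how many times over the adapter covers the signal

print(bits_needed)         # 320000
print(round(headroom, 1))  # 9.4, i.e. "nearly 10x"
```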

[Interactive figure: LoRA Without Regret: When LoRA Matches Full Fine-Tuning. Loss vs. log10(learning rate) for FullFT and LoRA, with LoRA's optimal learning rate about 10× higher (fit: 9.8×, ~15× for very short runs), alongside the five findings and the RL result.]
Fig 5.33 — LoRA matches full fine-tuning if you cover all layers and keep enough capacity — and for RL, rank 1 is plenty.

Check your understanding

Why does rank-1 LoRA suffice to match full fine-tuning for policy-gradient RL, but not always for supervised learning?

It's about how much information the training signal carries. Policy-gradient RL learns from the advantage function, which delivers only about $O(1)$ bits per episode, roughly 1000× less information per token than supervised learning. So there's very little to "store," and even a rank-1 adapter (a few million parameters, e.g. 3M on Llama-3.1-8B versus the ~320,000 bits needed in their MATH example) has ample capacity. Supervised fine-tuning pushes far more information into the weights, so it can exceed a tiny adapter's capacity and require higher rank.

Scaling laws

Empirical research (Kaplan et al., and then Chinchilla) found that LLM performance follows predictable power laws in model size, dataset size, and compute. The Chinchilla scaling law is especially influential: for a given compute budget, the optimal model size and dataset size grow together at specific rates. Earlier models like GPT-3 were "undertrained" by Chinchilla standards — for their parameter count, they should have been trained on more data. Modern models (LLaMA, Mistral, and others) take this seriously, training smaller models on far more tokens. The practical upshot is that "make it bigger" isn't the whole story — you have to scale data alongside parameters to spend compute optimally.
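The Chinchilla result is often summarized with two approximations that are not stated above, so treat them as assumptions: training compute $C \approx 6ND$ FLOPs for $N$ parameters and $D$ tokens, and a compute-optimal ratio of roughly 20 tokens per parameter. Under those assumptions, a sketch of the compute-optimal split and the GPT-3 comparison looks like this (GPT-3: 175B parameters on ~300B tokens; Chinchilla: 70B parameters on 1.4T tokens).

```python
import math

FLOPS_PER_PARAM_PER_TOKEN = 6  # approximation: C ~= 6 * N * D for dense transformers
TOKENS_PER_PARAM = 20          # approximate Chinchilla compute-optimal ratio


def compute_optimal(compute_flops):
    """Split a FLOP budget into (params, tokens), assuming C = 6ND and D = 20N, i.e. C = 120 N^2."""
    n_params = math.sqrt(compute_flops / (FLOPS_PER_PARAM_PER_TOKEN * TOKENS_PER_PARAM))
    return n_params, TOKENS_PER_PARAM * n_params


def training_flops(n_params, n_tokens):
    return FLOPS_PER_PARAM_PER_TOKEN * n_params * n_tokens


# Chinchilla's own budget: 70B params on 1.4T tokens.
budget = training_flops(70e9, 1.4e12)
n_opt, d_opt = compute_optimal(budget)
print(f"budget {budget:.2e} FLOPs -> {n_opt / 1e9:.0f}B params, {d_opt / 1e12:.1f}T tokens")

# GPT-3: 175B params on ~300B tokens, far below ~20 tokens per parameter.
print(f"GPT-3 tokens/param: {300e9 / 175e9:.1f}")
```

Under this rule of thumb, GPT-3's 175B parameters would call for about 3.5T tokens, more than ten times what it was trained on.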

Check your understanding

What did the Chinchilla scaling law reveal about models like GPT-3?

That they were undertrained for their size. Chinchilla showed that, for a fixed compute budget, model size and training-data size should grow together at specific rates — and GPT-3 had too many parameters relative to the number of tokens it saw. The takeaway reshaped modern training: prefer smaller models trained on far more data, rather than just inflating parameter counts.

Constitutional AI and RLAIF

Human feedback (the "HF" in RLHF) is expensive and slow. Constitutional AI offers an alternative: instead of relying on humans to rate every output, you have the AI critique its own outputs against a written set of principles — a "constitution." The model rewrites its responses to be more aligned, then trains on those rewrites. This is how Claude was trained, and the broader family of approaches — RLAIF (RL from AI Feedback) — is now common in modern alignment work. The idea scales feedback the way pretraining scaled supervision: let the model generate the signal it learns from, guided by principles rather than per-example human labels.
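The critique-and-revise loop can be sketched in a few lines. This is a minimal illustration, assuming a hypothetical `model` callable that maps a prompt string to a completion; the prompt wording, the example principles, picking one principle at random per round, and the function name are all illustrative choices, not Anthropic's actual templates or pipeline. The resulting (prompt, revised response) pairs are what the model would then be fine-tuned on.

```python
import random
from typing import Callable, List, Tuple


def critique_and_revise(
    model: Callable[[str], str],  # hypothetical text-in, text-out model
    prompt: str,
    constitution: List[str],
    n_rounds: int,
    rng: random.Random,
) -> Tuple[str, str]:
    """Return a (prompt, revised_response) pair for supervised fine-tuning."""
    response = model(prompt)
    for _ in range(n_rounds):
        principle = rng.choice(constitution)  # one principle per round
        critique = model(
            f"Critique the response according to this principle: {principle}\n\n"
            f"Prompt: {prompt}\nResponse: {response}"
        )
        response = model(
            "Rewrite the response so it addresses the critique.\n\n"
            f"Critique: {critique}\nResponse: {response}"
        )
    return prompt, response


# Hypothetical example principles, not an actual constitution.
constitution = [
    "Choose the response that is least likely to cause harm.",
    "Choose the response that is most honest about its uncertainty.",
]
```

Each round costs two model calls (critique, then rewrite), on top of the one call for the initial draft.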

Check your understanding

What does Constitutional AI replace in the standard RLHF recipe, and how?

It replaces (much of) the expensive human feedback. Instead of humans rating outputs to train a reward model, the model critiques and revises its own responses against a written set of principles (a "constitution"), then trains on those self-revisions. This is the basis of RLAIF — RL from AI Feedback — and it's how Claude was trained.

Multimodal models

Modern frontier models are no longer text-only. Vision is the most common extension: an image encoder (often a Vision Transformer) produces image tokens that get interleaved with text tokens in the same transformer. The model treats images as just another modality of input. Audio, video, and even more exotic modalities (proteins, code-execution traces) work the same way.

The unifying insight is genuinely deep: anything you can tokenize and embed, a transformer can attend to. The architecture is modality-agnostic — the same attention-over-a-sequence machinery you've now fully understood doesn't care whether the tokens came from words, image patches, or audio frames. That's why the vision side of these models is, as we noted earlier, an encoder feeding into the decoder LLM.
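Here is a minimal sketch of how an image becomes tokens in the same sequence as text, in the spirit of a Vision Transformer: cut the image into non-overlapping $P \times P$ patches, flatten each patch, and linearly project it to the model width. It assumes a grayscale image whose sides divide evenly by the patch size and a hypothetical projection matrix. A real ViT encoder would also add position information and run transformer layers over the patches before handing them to the LLM, and real models place image tokens at specific positions among the text rather than always putting them first.

```python
from typing import List

Matrix = List[List[float]]


def image_to_patches(image: Matrix, patch: int) -> Matrix:
    """Split an H x W image into flattened, non-overlapping patch x patch tiles (row-major)."""
    h, w = len(image), len(image[0])
    assert h % patch == 0 and w % patch == 0, "sides must divide by the patch size"
    patches = []
    for top in range(0, h, patch):
        for left in range(0, w, patch):
            patches.append([image[top + i][left + j] for i in range(patch) for j in range(patch)])
    return patches


def project(vec: List[float], weight: Matrix) -> List[float]:
    """Linear projection; weight is d_model x len(vec)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def build_sequence(image: Matrix, patch: int, weight: Matrix, text_embeddings: Matrix) -> Matrix:
    """Image tokens followed by text tokens: one sequence the transformer attends over."""
    image_tokens = [project(p, weight) for p in image_to_patches(image, patch)]
    return image_tokens + text_embeddings


image = [[float(4 * r + c) for c in range(4)] for r in range(4)]  # toy 4x4 "image"
weight = [[0.1] * 4, [0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, -1.0]]  # hypothetical d_model=3 projection
text = [[0.2, 0.1, 0.0], [0.5, -0.3, 0.7]]  # two already-embedded text tokens
seq = build_sequence(image, 2, weight, text)
print(len(seq), len(seq[0]))  # 6 tokens (4 image + 2 text), each of width 3
```

Once projected, the image tokens are indistinguishable in form from text embeddings, which is the whole point: attention treats them all the same way.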

Check your understanding

What makes the transformer architecture able to handle images, audio, and text with essentially the same machinery?

The architecture only ever operates on a sequence of embedded tokens and attention between them — it doesn't care where those tokens came from. As long as you can tokenize and embed a modality (image patches via a Vision Transformer, audio frames, etc.) into vectors, the same self-attention machinery can mix them, even interleaved with text tokens. The transformer is modality-agnostic.

Inference engines

Production LLM serving doesn't use raw PyTorch — it uses specialized inference engines. The big ones:

  • vLLM: PagedAttention, high throughput, dynamic batching.
  • TensorRT-LLM: NVIDIA's heavily optimized engine.
  • SGLang: flexible structured generation with constraint enforcement.
  • llama.cpp: runs quantized models on CPUs and consumer GPUs.

These typically hit 5–10× the throughput of naive PyTorch inference, through continuous batching, paged KV cache, fused kernels, and quantization — which is to say, through exactly the optimizations we've been walking through this whole section, packaged up and engineered hard.
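The paged KV cache idea behind vLLM's PagedAttention is easiest to see as bookkeeping. Below is a toy sketch, not vLLM's code or API: the cache is a pool of fixed-size blocks, each sequence keeps a block table mapping its logical positions to physical blocks, a new block is grabbed only when the current one fills, and a finished sequence returns its blocks to the pool for other requests. The class name and sizes are hypothetical, and the actual key/value tensors are omitted.

```python
from typing import Dict, List, Tuple


class PagedKVCache:
    """Toy block allocator for a paged KV cache (bookkeeping only, no tensors)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_tables: Dict[str, List[int]] = {}  # seq id -> physical block ids
        self.lengths: Dict[str, int] = {}             # seq id -> tokens stored

    def append_token(self, seq_id: str) -> Tuple[int, int]:
        """Reserve a slot for one new token; return (physical block, offset in block)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free_blocks:
                raise MemoryError("KV cache out of blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, seq_id: str) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(10):
    cache.append_token("req-a")  # 10 tokens -> 3 blocks (4 + 4 + 2)
cache.append_token("req-b")      # 1 token -> 1 block
print(len(cache.block_tables["req-a"]), len(cache.free_blocks))  # 3 4
cache.free("req-a")
print(len(cache.free_blocks))  # 7
```

Committing memory one small block at a time, instead of reserving a maximum-length buffer per request up front, is roughly what lets many more concurrent sequences share the GPU, which is where continuous batching pays off.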

Check your understanding

Production inference engines like vLLM get ~5–10× the throughput of naive PyTorch. Name two of the techniques (covered earlier) that get them there.

Any two of: paged KV cache (PagedAttention — managing the cache like virtual memory so concurrent requests share GPU memory), continuous/dynamic batching, fused kernels (e.g. Flash Attention), and quantization. They're not new algorithms — they're the inference optimizations from this section, implemented carefully and combined.

Generative Models (Beyond the Transformer)

The transformer isn't the only way to generate data — and the other major families are worth understanding, because they each take a fundamentally different angle on the same goal: learn the structure of data well enough to produce new examples. We'll cover three: GANs, autoencoders, and VAEs. They build on each other in a clean progression, much like the history section did.

Generative Adversarial Networks (GANs)

The cleanest way to understand a GAN is with an analogy. Imagine a counterfeiter trying to print fake banknotes and a detective trying to catch them. At first the counterfeiter is terrible and the detective spots every fake. But each time the detective rejects a note, the counterfeiter learns a little about what gave it away and improves. And each time the counterfeiter improves, the detective has to get sharper too. They improve together. If this arms race runs long enough, the counterfeiter's fakes become so good that the detective can do no better than flip a coin — at which point the fakes are, by definition, indistinguishable from real money.

That's exactly what a GAN does: two neural networks locked in competition.

1. A Generator turns random noise into fake data, trying to fool its opponent.

2. A Discriminator looks at a sample and judges real or fake, trying not to be fooled.

The generator never sees a real image — it learns entirely from the discriminator's reactions. (And note: because it's not minimizing a reconstruction error against a target, there's none of the averaging-toward-blur that reconstruction losses tend to cause.)

Let's make the two networks precise. The **generator $G$** is a function from noise to data: feed it a random vector $z$ (usually drawn from a simple Gaussian) and it outputs something the same shape as a real sample (an image, say). Different noise vectors give different outputs, so once trained, $G$ is your sampler: pick fresh noise, get a fresh image. The **discriminator $D$** is just a binary classifier. It takes a sample and outputs a single number between 0 and 1: the probability that the sample is real. The detail that makes it all work is the feedback path: the generator only ever improves by chasing the discriminator's verdict. It has no other teacher.

Training alternates between two phases, and keeping them straight is the key to understanding GANs. In each phase you freeze one network and train the other. Why freeze one at a time? Because they have opposite goals: if you moved both at once they'd fight over the same gradient and nothing would stabilize. Alternating lets each adapt to the other's current skill level. And notice the target: success isn't "$D$ wins" or "$G$ wins," it's a stalemate where $D$ is reduced to a coin flip.

[Interactive figure: GAN: Generator vs Discriminator. Noise z → generator G → fake sample; the discriminator D scores real and fake samples with P(real) and sends its gradient back to G, its only teacher. Training alternates (Phase 1: freeze G, train D) until D's accuracy falls toward the 0.5 target.]
Fig 5.34 — Two networks in an arms race; the goal is a stalemate where the detective can only guess.

The mechanics of how GANs work

Everything above is captured by one equation, the value function that $D$ wants to push up and $G$ wants to push down:

$$\min_{G}\ \max_{D}\ V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log(1 - D(G(z)))\big]$$

Let's break down every symbol:

  • $x \sim p_{\text{data}}$: a real sample $x$ drawn from the true data distribution $p_{\text{data}}$.
  • $z \sim p_z$: a random noise vector $z$ drawn from a simple prior $p_z$ (typically a Gaussian).
  • $G(z)$: a fake sample, produced by running noise through the generator.
  • $D(\cdot)$: the discriminator's estimated probability that its input is real (between 0 and 1).
  • $\mathbb{E}[\cdot]$: the expected value (the average over many samples).
  • $\min_G \max_D$: out front, this just says "$D$ tries to maximize $V$; $G$ tries to minimize it."

The first term, $\log D(x)$, is over real samples. $D$ wants $D(x)$ close to 1 here (it's real, so call it real), which makes $\log D(x)$ close to 0 instead of very negative. So $D$ maximizes this term by correctly trusting real data.

The second term, $\log(1 - D(G(z)))$, is over fakes. $D$ wants $D(G(z))$ close to 0 (it's fake, so call it fake), which again pushes the term up. But $G$ touches only this term, and $G$ wants the opposite: $D(G(z))$ close to 1, meaning fakes that pass as real. That single shared term, pulled in two directions, is the adversarial game.
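To see the two phases and the value function working together, here is a deliberately tiny one-dimensional sketch. It assumes a hypothetical affine generator $G(z) = az + b$ and a logistic discriminator $D(x) = \sigma(wx + c)$ (with $\sigma$ the logistic sigmoid), hand-derived gradients, and batch averages standing in for the expectations. Phase 1 freezes $G$ and takes one gradient ascent step on $V$ in $D$'s parameters; phase 2 freezes $D$ and takes one gradient descent step in $G$'s parameters, which only touches the $\log(1 - D(G(z)))$ term. It shows the mechanics only: a real GAN uses deep networks, fresh noise each step, and many alternating steps, and this toy makes no claim about convergence.

```python
import math


def sigmoid(s: float) -> float:
    return 1.0 / (1.0 + math.exp(-s))


def D(x, w, c):
    """Logistic discriminator: probability that x is real."""
    return sigmoid(w * x + c)


def G(z, a, b):
    """Affine generator: maps noise z to a fake sample."""
    return a * z + b


def value(real, noise, w, c, a, b):
    """Batch estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]."""
    v_real = sum(math.log(D(x, w, c)) for x in real) / len(real)
    v_fake = sum(math.log(1.0 - D(G(z, a, b), w, c)) for z in noise) / len(noise)
    return v_real + v_fake


def d_step(real, noise, w, c, a, b, lr):
    """Phase 1: G frozen; gradient ASCENT on V in D's parameters (w, c)."""
    gw = gc = 0.0
    for x in real:
        g = 1.0 - D(x, w, c)          # d/ds log(sigmoid(s))
        gw += g * x / len(real)
        gc += g / len(real)
    for z in noise:
        fake = G(z, a, b)
        g = -D(fake, w, c)            # d/ds log(1 - sigmoid(s))
        gw += g * fake / len(noise)
        gc += g / len(noise)
    return w + lr * gw, c + lr * gc


def g_step(noise, w, c, a, b, lr):
    """Phase 2: D frozen; gradient DESCENT on V in G's parameters (a, b).
    Only the log(1 - D(G(z))) term depends on G."""
    ga = gb = 0.0
    for z in noise:
        g = -D(G(z, a, b), w, c) * w  # d/dG of log(1 - sigmoid(w * G + c))
        ga += g * z / len(noise)
        gb += g / len(noise)
    return a - lr * ga, b - lr * gb


real = [2.8, 3.0, 3.2]    # toy "dataset"
noise = [-1.0, 0.0, 1.0]  # fixed noise draws so the example is deterministic
w, c, a, b = 0.1, 0.0, 1.0, 0.0

v0 = value(real, noise, w, c, a, b)
w, c = d_step(real, noise, w, c, a, b, lr=0.1)  # D gets sharper: V goes up
v1 = value(real, noise, w, c, a, b)
a, b = g_step(noise, w, c, a, b, lr=0.1)        # G fights back: V goes down
v2 = value(real, noise, w, c, a, b)
print(f"V: {v0:.4f} -> {v1:.4f} (after D step) -> {v2:.4f} (after G step); b = {b:.3f}")
```

Note that the generator never touches `real`: its only signal is the frozen discriminator's response to its own fakes, and here that pushes `b`, and so the fakes, toward the real data.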

Why does this converge to reality? Goodfellow's 2014 paper proves the satisfying part. If you hold $G$ fixed and find the best possible discriminator, it turns out to be:

$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$$

where $p_g$ is the distribution of the generator's outputs. This is intuitive: the ideal detective's confidence at a point is just the fraction of stuff there that's genuinely real. Now substitute $D^*$ back into the value function, and the generator's objective simplifies to minimizing the Jensen–Shannon divergence between $p_g$ and $p_{\text{data}}$, a measure of how different two distributions are. That divergence hits its minimum at exactly one place:

$$p_g = p_{\text{data}}$$

The generator's distribution has become the real data distribution. And right there, $D^*(x) = \tfrac{1}{2}$ everywhere: the coin flip. The math says that, perfectly played, this game recovers the true data distribution.
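The optimal-discriminator result can be checked numerically on a small discrete sample space. The sketch below uses natural logs and assumes both distributions are given as lists of nonzero probabilities over the same points. It computes $D^*$, evaluates $V$ at $D^*$, and confirms the identity behind the argument above, $V(D^*, G) = -\log 4 + 2\,\mathrm{JSD}(p_{\text{data}} \,\|\, p_g)$, so the generator's best achievable value is $-\log 4$, reached exactly when $p_g = p_{\text{data}}$.

```python
import math
from typing import List


def optimal_discriminator(p_data: List[float], p_g: List[float]) -> List[float]:
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)) at each point of the sample space."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]


def value(p_data, p_g, d):
    """V(D, G) = sum_x p_data(x) log D(x) + sum_x p_g(x) log(1 - D(x))."""
    return sum(pd * math.log(dx) + pg * math.log(1.0 - dx) for pd, pg, dx in zip(p_data, p_g, d))


def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence (natural log): average KL to the midpoint mixture."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


p_data = [0.5, 0.3, 0.2]
for p_g in ([0.2, 0.3, 0.5], [0.5, 0.3, 0.2]):  # mismatched, then matched
    d_star = optimal_discriminator(p_data, p_g)
    v = value(p_data, p_g, d_star)
    print(p_g, [round(d, 3) for d in d_star], round(v, 4),
          round(-math.log(4) + 2 * jsd(p_data, p_g), 4))
```

In the matched case every entry of $D^*$ is 0.5 and $V$ equals $-\log 4$; any mismatch leaves $V$ strictly higher, which is exactly the gap the generator is trying to close.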

[Interactive figure: Why GAN Training Converges. Over a 1-D sample space, p_data, p_g, and D*(x) = p_data / (p_data + p_g) move toward equilibrium, where p_g matches p_data, D* = 1/2, and the Jensen–Shannon divergence falls to zero.]
Fig 5.35 — The optimal discriminator becomes a coin flip exactly when the fakes match reality.

Where GANs get hard

The clean theory assumes perfect play and infinite capacity. Reality is messier, and three issues are worth knowing.

Vanishing gradient / saturation. This shows up early in training. When $G$ is bad, $D$ rejects its fakes with total confidence, which means the term $\log(1 - D(G(z)))$ flattens out and gives $G$ almost no gradient to learn from: the forger gets no useful feedback exactly when it needs it most. The standard fix is to train $G$ to maximize $\log D(G(z))$ instead (the "non-saturating" loss). Same goal (fool $D$), but with strong gradients when $G$ is struggling. (Notice this is our recurring vanishing-gradient villain again, in yet another costume.)
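The saturation problem shows up directly in the gradient with respect to the discriminator's logit $s$ on a fake sample, where $D(G(z)) = \sigma(s)$. The original term $\log(1 - \sigma(s))$ has derivative $-\sigma(s)$, which vanishes as $D$ confidently rejects the fake, while the non-saturating loss $-\log \sigma(s)$ has derivative $-(1 - \sigma(s))$, which stays near $-1$ there. A minimal sketch comparing the two (the function names are illustrative):

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def grad_saturating(s):
    """d/ds of log(1 - sigmoid(s)), the term G minimizes in the original objective."""
    return -sigmoid(s)


def grad_non_saturating(s):
    """d/ds of -log(sigmoid(s)), the loss G minimizes in the non-saturating variant."""
    return -(1.0 - sigmoid(s))


for d_fake in (0.5, 0.1, 0.001):           # D's probability that the fake is real
    s = math.log(d_fake / (1.0 - d_fake))  # the logit that produces that probability
    print(f"D(G(z))={d_fake}: saturating {grad_saturating(s):+.4f}, "
          f"non-saturating {grad_non_saturating(s):+.4f}")
```

Both gradients point the same way (toward making the fake look more real); only their size differs, and the difference is largest exactly when $G$ is losing badly.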

Mode collapse is the most famous GAN failure. $G$ discovers that one particular output reliably fools the current $D$, so it just keeps producing that one thing (or a few). It's "winning" the game while ignoring most of the data's variety; imagine a counterfeiter who only ever makes flawless twenty-dollar bills and never learns the other denominations. The samples look real but lack diversity.

Instability comes from the fact that you're not minimizing a fixed loss; you're chasing a moving equilibrium between two networks. If $D$ gets too strong too fast, $G$'s gradients die; if $G$ overshoots, $D$ scrambles to catch up. Training can oscillate instead of settling. A lot of GAN research (Wasserstein GAN, spectral normalization, gradient penalties, careful learning-rate balancing) exists precisely to tame this.

Check your understanding

What is "mode collapse," and why does the standard GAN objective get swapped for the "non-saturating" loss?

Mode collapse is when the generator finds one (or a few) outputs that reliably fool the current discriminator and just keeps producing those, ignoring the variety in the real data: realistic but not diverse. The non-saturating loss fix is a separate issue: with the original $\log(1 - D(G(z)))$ term, when $G$ is bad and $D$ rejects everything confidently, that term flattens and gives $G$ almost no gradient (saturation). Training $G$ to maximize $\log D(G(z))$ instead gives strong gradients precisely when $G$ is struggling, so it can actually learn early on.

Autoencoders

An autoencoder is all about representation learning. It tries to find a way to squeeze data, like a 784-pixel image, down into a small set of numbers $z$ that capture what matters. That small set lives in a compact space called the latent space.

At a high level, an autoencoder is a neural network trained to copy its input to its output, through two halves:

1. **Encoder $f$**: compresses input $x$ into a latent code $z = f(x)$, where $z$ is much smaller than $x$.

2. **Decoder $g$**: reconstructs the input from the code, $\hat{x} = g(z)$.

The squeeze in the middle is the whole point. If $z$ had the same size as $x$, the network could just copy numbers straight through and learn nothing. By forcing $z$ to be small, we make it keep only the essential structure of the data.

Training simply minimizes how different the reconstruction is from the original. For continuous data we use squared error:

$$\mathcal{L}(\theta, \phi) = \frac{1}{N}\sum_{i=1}^{N} \big\lVert\, x_i - g_\theta\!\big(f_\phi(x_i)\big) \big\rVert^2$$

The symbols:

  • $x_i$: the $i$-th training example.
  • $\phi$ (phi): the encoder's weights.
  • $\theta$ (theta): the decoder's weights.
  • $f_\phi(x_i)$: the encoder applied to $x_i$, producing the latent code $z$.
  • $g_\theta(\cdot)$: the decoder applied to that code, producing the reconstruction $\hat{x}$.
  • $\lVert \cdot \rVert^2$: the squared distance between the original and its reconstruction.
  • $\frac{1}{N}\sum$: average this over all $N$ examples.

We backpropagate this loss through both halves at once. That's it — there's no label, just the input acting as its own target, which is why this is called self-supervised learning.
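The loss and the joint backpropagation can be sketched end to end with the smallest possible autoencoder: a linear encoder $f_\phi$ from 2-D inputs to a 1-D code and a linear decoder $g_\theta$ back to 2-D, trained by gradient descent on the squared-error loss above with gradients derived by hand. The toy data (points near the line $x_2 = x_1$), the initial weights, and the learning rate are arbitrary assumptions; real autoencoders use nonlinear networks and autodiff, but the shape of the update is the same.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def encode(phi: List[float], x: Point) -> float:
    """Linear encoder f_phi: 2-D input -> 1-D latent code z."""
    return phi[0] * x[0] + phi[1] * x[1]


def decode(theta: List[float], z: float) -> Point:
    """Linear decoder g_theta: 1-D code -> 2-D reconstruction x_hat."""
    return theta[0] * z, theta[1] * z


def loss(phi, theta, data):
    """(1/N) * sum_i ||x_i - g_theta(f_phi(x_i))||^2"""
    total = 0.0
    for x in data:
        x_hat = decode(theta, encode(phi, x))
        total += (x[0] - x_hat[0]) ** 2 + (x[1] - x_hat[1]) ** 2
    return total / len(data)


def train_step(phi, theta, data, lr):
    """One gradient-descent step on both halves at once (hand-derived backprop)."""
    g_phi, g_theta = [0.0, 0.0], [0.0, 0.0]
    n = len(data)
    for x in data:
        z = encode(phi, x)
        x_hat = decode(theta, z)
        r = (x_hat[0] - x[0], x_hat[1] - x[1])      # reconstruction residual
        for k in range(2):
            g_theta[k] += 2 * r[k] * z / n            # dL/dtheta_k = 2 r_k z
        dz = 2 * (r[0] * theta[0] + r[1] * theta[1])  # dL/dz, flowing back into the encoder
        for j in range(2):
            g_phi[j] += dz * x[j] / n                 # dL/dphi_j = dL/dz * x_j
    phi = [p - lr * g for p, g in zip(phi, g_phi)]
    theta = [t - lr * g for t, g in zip(theta, g_theta)]
    return phi, theta


data = [(1.0, 1.1), (2.0, 1.9), (-1.0, -0.9), (0.5, 0.6), (-2.0, -2.1)]
phi, theta = [0.3, 0.1], [0.2, 0.4]
initial = loss(phi, theta, data)
for _ in range(2000):
    phi, theta = train_step(phi, theta, data, lr=0.01)
final = loss(phi, theta, data)
print(f"loss {initial:.3f} -> {final:.4f}")
```

No labels appear anywhere: each input is its own target, and the encoder only learns through the gradient `dz` passed back from the decoder.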

[Interactive figure: Autoencoder. x → encoder f → latent z → decoder g → x̂: a 784-pixel input is squeezed through a small latent code and rebuilt, scored by the loss ||x − x̂||²; a tighter code drops more detail.]
Fig 5.36 — Squeeze the input through a small code and rebuild it — the squeeze forces it to keep only what matters.

What's it good for? Dimensionality reduction (a nonlinear cousin of PCA), denoising (feed in a corrupted $x$, train it to output the clean $x$), and anomaly detection (things it can't reconstruct well are unusual).
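Anomaly detection falls straight out of the reconstruction error. In this minimal sketch the weights are hand-set, assuming an autoencoder that has already learned that normal data lies near the line $x_2 = x_1$ (encoder $z = (x_1 + x_2)/2$, decoder $\hat{x} = (z, z)$). The threshold of 0.5 is a hypothetical cutoff that in practice you would tune on held-out normal data.

```python
from typing import Sequence

PHI = (0.5, 0.5)    # assumed trained encoder: z = (x1 + x2) / 2
THETA = (1.0, 1.0)  # assumed trained decoder: x_hat = (z, z)
THRESHOLD = 0.5     # hypothetical cutoff, tuned on normal data in practice


def reconstruction_error(x: Sequence[float]) -> float:
    z = sum(p * xi for p, xi in zip(PHI, x))
    x_hat = [t * z for t in THETA]
    return sum((xi - hi) ** 2 for xi, hi in zip(x, x_hat))


def is_anomaly(x: Sequence[float]) -> bool:
    """Inputs the autoencoder cannot rebuild well are flagged as unusual."""
    return reconstruction_error(x) > THRESHOLD


for point in [(3.0, 3.0), (1.1, 0.9), (3.0, -3.0)]:
    print(point, round(reconstruction_error(point), 3), is_anomaly(point))
```

The point $(3, -3)$ is not far from the origin at all; it is flagged because it points in a direction the code never learned to represent.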

But here's the catch that motivates everything next. A plain autoencoder learns a code $z$, but it learns *nothing about how $z$ is distributed*. The latent space is full of holes. If you pick a random $z$ and decode it, you usually get garbage, because the decoder only ever saw the specific scattered points the encoder happened to produce. So a vanilla autoencoder cannot generate new data reliably. Fixing exactly that is the job of the VAE.

Check your understanding

Why can't a plain autoencoder reliably generate new data, even though it reconstructs its training data well?

Because it only learns to map specific inputs to specific latent points and back — it learns nothing about how the latent codes are distributed. The latent space ends up as scattered points with large empty gaps between them. If you pick a random point (especially in a gap) and decode it, the decoder has never seen anything like it and produces garbage. Reconstruction works on points the encoder actually produced; generation requires the whole space to be meaningful, which a plain autoencoder doesn't guarantee.

The Variational Autoencoder (VAE)

The VAE keeps the encoder–decoder shape but reframes everything probabilistically, so the latent space becomes smooth and samplable — which is exactly what plain autoencoders lacked.

The key change: instead of mapping $x$ to a single point $z$, the encoder maps $x$ to a distribution over $z$, specifically a Gaussian with a mean $\mu$ and a spread $\sigma$. We then sample $z$ from that little cloud and decode it.

Here's the intuition, and it's worth picturing carefully. A plain autoencoder learns to squash each image down to a short code (a single point), then rebuild it. The trouble is what the space of codes looks like. Picture every training image getting assigned a dot on a map. The autoencoder only ever learns about the exact dots it placed — and there are huge empty gaps between them. If you stand in a gap and ask the decoder "what's here?", it has no idea, and you get garbage. That's why a plain autoencoder can't generate: there's nowhere safe to stand.

A VAE fixes this with one move: instead of placing each image at a single dot, it places each image as a little fuzzy cloud, and it forces all the clouds to crowd together into one tidy region with no gaps. Now every point on the map sits inside some cloud, so every point decodes to something sensible.

[Interactive figure: Why a VAE can generate: dots vs clouds. Side-by-side latent maps for a plain autoencoder (scattered dots) and a VAE (overlapping clouds); clicking either map samples that latent point and decodes it.]
Fig 5.37 — Dots leave gaps you can't generate from; overlapping clouds fill the space smoothly.

How the VAE works

Now the only math you really need. The encoder, instead of outputting a code directly, outputs two things that describe a cloud:

  • $\mu$ (mu): the center of the cloud, i.e. where this image roughly lives on the map.
  • $\sigma$ (sigma): the width of the cloud, i.e. how fuzzy or spread out it is.

Then the "sample" step picks one random point from that cloud. In symbols:

$$z \sim \mathcal{N}(\mu,\ \sigma^2)$$

which reads "draw $z$ randomly from a bell-shaped (Gaussian) cloud centered at $\mu$ with variance $\sigma^2$." That's the entire mathematical content of the forward pass. (There's one small technical trick to make this trainable, the reparameterization trick, where you write the random point as $z = \mu + \sigma \cdot \varepsilon$, with $\varepsilon$ being the random part drawn from a standard Gaussian. This pushes the randomness "off to the side" so gradients can still flow through $\mu$ and $\sigma$ during backpropagation.)
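The reparameterization trick is a one-liner, which is why it is easy to miss what it buys. In this minimal pure-Python sketch (function names are illustrative, and the encoder outputs $\mu$ and $\sigma$ are just given numbers), all randomness lives in $\varepsilon \sim \mathcal{N}(0, 1)$, so $z$ is a deterministic function of $\mu$ and $\sigma$ with $\partial z/\partial \mu = 1$ and $\partial z/\partial \sigma = \varepsilon$; those are the local gradients backpropagation uses. Real implementations usually have the encoder output $\log \sigma^2$ and rely on an autodiff framework.

```python
import random
import statistics
from typing import List, Tuple


def sample_latent(mu: List[float], sigma: List[float], rng: random.Random) -> Tuple[List[float], List[float]]:
    """z = mu + sigma * eps with eps ~ N(0, 1) per latent dimension. Returns (z, eps)."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, eps


def local_grads(eps: List[float]) -> Tuple[List[float], List[float]]:
    """dz/dmu = 1 and dz/dsigma = eps per dimension: the randomness sits 'off to the side'."""
    return [1.0] * len(eps), list(eps)


rng = random.Random(0)
mu, sigma = [1.1, -0.3], [0.52, 1.0]
draws = [sample_latent(mu, sigma, rng)[0] for _ in range(50_000)]
for j in range(2):
    col = [z[j] for z in draws]
    print(f"dim {j}: mean {statistics.fmean(col):.3f}, std {statistics.pstdev(col):.3f}")
```

The sampled means and spreads land close to $\mu$ and $\sigma$, so the trick changes nothing about the distribution of $z$; it only changes where the randomness enters the computation.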

Training balances two pulls, and the whole VAE is really just the tug-of-war between them.

Pull #1: "rebuild it correctly." This is the same reconstruction goal as a plain autoencoder: the output $\hat{x}$ should match the input $x$. On its own, this pull wants each image to get a very precise, tightly pinned location so it can be rebuilt perfectly, which would recreate the scattered-dots problem all over again.

Pull #2: "stay near the center and stay fuzzy." This is a penalty whose formal name is the KL term (Kullback–Leibler divergence). It measures how far each cloud has drifted from a standard reference cloud sitting at the center of the map, a unit Gaussian, $\mathcal{N}(0, 1)$. For a Gaussian encoder output, this term has a clean closed form:

$$D_{\text{KL}}\big(\mathcal{N}(\mu, \sigma^2)\ \|\ \mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_{j=1}^{d}\Big(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\Big)$$

Let's read every piece, since this is the term that does the magic:

  • $\mu_j$: the center of the cloud along latent dimension $j$. The $\mu_j^2$ term punishes the cloud for wandering away from the origin; the farther off-center, the bigger the penalty. This is what crowds all the clouds toward the middle.
  • $\sigma_j^2$: the variance (width) of the cloud along dimension $j$. The $\sigma_j^2 - \log \sigma_j^2$ combination is minimized when $\sigma_j^2 = 1$: if the cloud shrinks toward a sharp point ($\sigma_j^2 \to 0$), the $-\log \sigma_j^2$ term blows up and punishes it; if it spreads too wide, the $+\sigma_j^2$ term punishes it. So this keeps every cloud puffy, neither a sharp dot nor a smeared mess.
  • The $-1$ is just a constant that makes the whole expression equal exactly 0 when the cloud is a perfect match to the reference ($\mu_j = 0$, $\sigma_j^2 = 1$).
  • $\sum_{j=1}^{d}$: sum this over all $d$ latent dimensions; $\frac{1}{2}$ is a scaling factor that falls out of the math.

So Pull #2, in one breath: it constantly nudges every cloud back toward the center of the map and keeps it from collapsing to a sharp point.
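The closed form is easy to sanity-check. This pure-Python sketch (function names are illustrative) implements the formula per latent dimension, taking the variances $\sigma_j^2$ directly, and compares one dimension against a brute-force trapezoidal integral of $p \log(p/q)$ with $p = \mathcal{N}(\mu, \sigma^2)$ and $q = \mathcal{N}(0, 1)$. The integration range and step count are arbitrary choices that are accurate enough for these values.

```python
import math
from typing import Sequence


def kl_to_standard_normal(mu: Sequence[float], var: Sequence[float]) -> float:
    """Closed form: 0.5 * sum_j (mu_j^2 + var_j - log var_j - 1)."""
    return 0.5 * sum(m * m + v - math.log(v) - 1.0 for m, v in zip(mu, var))


def normal_pdf(x: float, mean: float, var: float) -> float:
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)


def kl_numeric_1d(mean: float, var: float, lo: float = -10.0, hi: float = 10.0,
                  steps: int = 20_000) -> float:
    """Trapezoidal estimate of the integral of p(x) * log(p(x) / q(x)), with q = N(0, 1)."""
    h = (hi - lo) / steps
    total = 0.0
    for i in range(steps + 1):
        x = lo + i * h
        p = normal_pdf(x, mean, var)
        q = normal_pdf(x, 0.0, 1.0)
        term = p * math.log(p / q) if p > 0 else 0.0
        total += term * (0.5 if i in (0, steps) else 1.0)
    return total * h


print(kl_to_standard_normal([0.0], [1.0]))   # perfect match to the reference: 0.0
print(kl_to_standard_normal([1.0], [0.25]))  # off-center and too narrow
print(kl_numeric_1d(1.0, 0.25))              # brute-force estimate of the same quantity
print(kl_to_standard_normal([0.0], [0.01]), kl_to_standard_normal([0.0], [4.0]))  # sharp dot vs smeared mess
```

The last line shows both failure modes from the list above: a cloud collapsing toward a dot and a cloud spreading far too wide each pay a positive penalty, even when centered.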

Put the two pulls together. Pull #1 wants distinct, precise locations; Pull #2 wants everything soft and piled at the center. The compromise they settle on is exactly the right picture — clouds different enough to rebuild their own images, but overlapping and centered enough to leave no gaps. The map fills in smoothly. That balancing act is the VAE's whole secret.
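The tug-of-war can be written as a single number to minimize: reconstruction error plus the KL term. This sketch assumes squared error as the reconstruction term (as in the autoencoder section) and adds an optional weight $\beta$ on the KL term, which the text above does not use; $\beta = 1$ is the standard VAE, and other values give the $\beta$-VAE variant. The two example encodings are made-up numbers chosen to contrast a tightly pinned, off-center dot with a puffy, centered cloud. In a real model the relative size of the two terms also depends on how many output dimensions the reconstruction sums over.

```python
import math
from typing import Sequence, Tuple


def vae_loss(x: Sequence[float], x_hat: Sequence[float],
             mu: Sequence[float], var: Sequence[float], beta: float = 1.0) -> Tuple[float, float, float]:
    """Return (total, reconstruction, kl) for one example."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))                      # Pull #1
    kl = 0.5 * sum(m * m + v - math.log(v) - 1.0 for m, v in zip(mu, var))  # Pull #2
    return recon + beta * kl, recon, kl


x = [0.8, 0.2]
# Pinned dot: a perfect rebuild, but far from the origin with almost no spread.
pinned = vae_loss(x, x_hat=[0.8, 0.2], mu=[3.0], var=[0.01])
# Puffy cloud: a slightly worse rebuild, but centered and close to unit variance.
puffy = vae_loss(x, x_hat=[0.6, 0.4], mu=[0.5], var=[0.8])
print("pinned: total %.3f  recon %.3f  kl %.3f" % pinned)
print("puffy:  total %.3f  recon %.3f  kl %.3f" % puffy)
```

Minimizing the total prefers the puffy cloud here even though its reconstruction is worse, which is the compromise described above.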

Once training is done, the encoder has served its purpose. To generate something brand new, you skip it entirely: pick a random point from the center of the map (sample from $\mathcal{N}(0,1)$) and hand it to the decoder.

[Interactive figure: VAE: Generation and the Two Pulls. Forward pass x → encoder → (μ, σ) → z = μ + σε with ε ~ N(0, 1) → decoder → reconstruction. A 2-D latent map shows Pull #1 (reconstruction) parking clouds tight and off-center and Pull #2 (the KL term, with weight β) pulling them toward the N(0, I) prior, alongside KL = ½ Σ (μ² + σ² − log σ² − 1). In generation mode the encoder is discarded and z ~ N(0, I) is decoded.]
Fig 5.38 — Encode to a cloud, balance reconstruction against the KL pull, then generate by decoding random center points.

The one cost of all this softness: VAE images tend to come out slightly blurry. Because each image is represented as a spread-out cloud rather than a precise point, the decoder is effectively averaging over a little neighborhood, which smooths out fine detail. That blurriness is the main reason sharper methods like GANs and diffusion exist.

Check your understanding

In the VAE's KL term, the $\mu_j^2$ and the $\sigma_j^2 - \log\sigma_j^2$ pieces each do a specific job. What are they, and why does this make generation possible?

$\mu_j^2$ penalizes a cloud for drifting away from the origin, so it pulls every cloud toward the center of the latent map. $\sigma_j^2 - \log\sigma_j^2$ is minimized at $\sigma_j^2 = 1$: it punishes clouds that collapse to a sharp point (via $-\log\sigma_j^2$ blowing up) and clouds that spread too wide (via $+\sigma_j^2$), keeping each one appropriately "puffy." Together they crowd overlapping, fuzzy clouds into one gap-free region, so any random point sampled from the center lands inside some cloud and decodes to something sensible. That is exactly what lets a VAE generate, where a plain autoencoder couldn't.

Wrapping Up: The Whole Picture

Let's zoom all the way back out, because you've now traveled a long road and it's worth seeing it as one connected line.

We started with the feedforward network — powerful for fixed inputs, helpless with sequences. To get memory and handle order, we added recurrence and got the RNN — which gave us memory but choked on the vanishing-gradient problem and couldn't reach far back. To protect the gradient, we built the LSTM and GRU with their gated memory highway — better, but still sequential (so slow) and still limited in range, and in translation setups still forced to cram everything through one fixed-size vector. To relieve that cram, we bolted attention onto Seq2Seq — and discovered that letting the decoder look anywhere it wanted was the real breakthrough. Then someone asked the obvious question: if attention is the good part, why keep the slow recurrence at all? And the transformer was born.

From there we cracked the transformer open: tokenization and embeddings to get text into vectors; self-attention (Q, K, V, scaled dot product, softmax) as the core; multi-head attention to capture many relationships at once; masked attention to keep generation honest; cross-attention to connect decoder to encoder; the FFN to process what attention gathered; layer norm and residual connections to keep deep stacks trainable; and positional encoding to put order back in. Stack those layers, add a linear-plus-softmax head, and you can read or write any sequence.

Then we watched the architecture evolve without ever being replaced: decoder-only models that turned next-token prediction into a universal interface; encoder-only models (BERT) that specialized into representations, search, and RAG; better positional schemes (RoPE, ALiBi); cheaper normalization (RMSNorm); better activations (SwiGLU); cheaper attention (sliding window, sparse, Flash Attention); cheaper memory (KV cache, GQA/MQA, quantization); more capacity for less compute (MoE); faster generation (speculative decoding); grounding and capability (RAG and tools); and the training pipeline that turns raw next-token prediction into a helpful assistant (pretraining → SFT → RLHF, plus LoRA for cheap specialization and Constitutional AI for scalable alignment). Notice how often the same villain (vanishing gradients) and the same hero (attention, and the idea of adding cheap corrections instead of rebuilding from scratch) kept reappearing. That's not a coincidence — it's the connective tissue of the whole field.

Finally we stepped outside the transformer to the other generative families — GANs (an adversarial game that recovers the data distribution), autoencoders (squeeze and rebuild, great for representation, useless for generation), and VAEs (clouds instead of dots, so the latent space fills in and you can generate). Each one is, again, a specific fix for a specific limitation of the thing before it.

Here's a compact comparison of the three generative families, since it's the kind of thing worth having on one screen:

| | Autoencoder | VAE | GAN |
|---|---|---|---|
| At a glance | dots + gaps | smooth clouds | noise → data |
| Core idea | encode → decode | clouds + KL | generator vs discriminator |
| Latent space | gaps, holes | smooth, packed | implicit noise |
| Can generate? | No (gaps) | Yes | Yes |
| Sample quality | reconstruction only | blurry | sharp |
| Training stability | stable | stable | unstable |
| Best at | compression | smooth generation | realism |
Fig 5.39 — Each family trades something: autoencoders give representation, VAEs give smooth generation, GANs give sharpness.

And that's the arc — from a network that couldn't even tell "dog bites man" from "man bites dog," all the way to models that write code, hold conversations, see images, and generate worlds. Every step was someone looking at the previous model's weakest point and asking, "what if we fixed just that?" Now you can read any of those papers and know exactly which point they're fixing.