ML Bible · Chapter 6
Vision
From pixels and tensors through convolutions and CNNs, to object detection (YOLO), segmentation (FCN, U-Net, Mask R-CNN, SAM), Vision Transformers, and vision-language models.
Everything we've built so far — the whole transformer story — was about sequences of words. Now we point the same machinery at the visual world, and a surprising amount of it carries straight over: the same vanishing-gradient villain, the same residual-connection hero, the same "tokenize everything and let attention sort it out" punchline. But vision also has its own beautiful set of ideas, built around one operation — the convolution — that was purpose-designed for images.
Here's the road we'll travel, and like last time, every step is a fix for the thing before it. We start with what an image even is to a computer. Then we ask the obvious question — "can't I just feed pixels to a regular neural network?" — watch it fail, and that failure gives us the CNN. We stack CNNs into a feature hierarchy, then put them to work: first detecting objects with YOLO, then labeling every pixel with segmentation (FCN → U-Net → Mask R-CNN), then the foundation model that segments anything (SAM). Finally we come full circle: drop the convolution entirely, chop the image into patches, and feed them to a transformer — the Vision Transformer — which leads straight to vision-language models, where images and text live in the same stream of tokens.
I'm assuming you know what a neural network is and roughly how a GPU works, and that you've seen the transformer (we'll lean on it once we reach ViTs). Everything else, we build from the ground up — and since this is the visual chapter, we'll lean hard on diagrams. Let's go.
Image Fundamentals
Before any neural network can touch an image, we have to agree on what an image is to a computer. Spoiler: it's just a big grid of numbers. Once you see that clearly, everything else clicks into place.
Pixels, channels, and tensors
A digital image is a grid of pixels. Each pixel is a tiny block of color described by numbers — usually three of them: the red, green, and blue (RGB) intensities. Each one is an integer from 0 to 255 (that's 8 bits per channel). So a single pixel is a 3-number vector like (200, 50, 80) — a pinkish red.
Stack that grid up and an image becomes a 3D tensor with three axes:
- Height (H): the number of pixel rows.
- Width (W): the number of pixel columns.
- Channels (C): 3 for RGB, 1 for grayscale, 4 if there's an alpha (transparency) channel.
Here's the thing that trips people up, and it's worth getting straight right now: the channel dimension is conceptually way more important than just "RGB." For the input image, sure, channels are colors. But after the very first convolutional layer, channels stop meaning colors and start meaning learned features — maybe one channel lights up on horizontal edges, another on red blobs, another on diagonal stripes. By the time you're deep in a CNN, each channel stands for some abstract pattern the network taught itself to look for, and the channel count typically grows from 3 at the input to hundreds or thousands deep in. Hold onto that — it's the heart of how CNNs work.
A concrete example: a 1024×768 RGB image is a tensor of shape (3, 768, 1024) in PyTorch's convention (channels first) or (768, 1024, 3) in TensorFlow/NumPy's convention (channels last). Same data, different bookkeeping.
When you stack multiple images into a batch for training, you add a fourth axis:
(N, C, H, W) — N images, C channels, height H, width W.
That's the input tensor shape your CNN expects, and basically every other shape in this chapter eventually traces back to it.
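As a quick check on the shape bookkeeping, here is a minimal NumPy sketch. The array names and the random pixel values are illustrative stand-ins, not real image data.

```python
import numpy as np

# A fake 1024-wide, 768-tall RGB image in channels-last (H, W, C) layout,
# filled with random 8-bit values purely for illustration.
rng = np.random.default_rng(0)
image_hwc = rng.integers(0, 256, size=(768, 1024, 3), dtype=np.uint8)

# Same data in channels-first (C, H, W) layout: move the channel axis to the front.
image_chw = np.transpose(image_hwc, (2, 0, 1))

# Stack 4 copies into a batch to get the (N, C, H, W) layout.
batch_nchw = np.stack([image_chw] * 4, axis=0)

print(image_hwc.shape)   # (768, 1024, 3)
print(image_chw.shape)   # (3, 768, 1024)
print(batch_nchw.shape)  # (4, 3, 768, 1024)
```

The transpose only changes how the axes are indexed; the pixel at row 10, column 20 of the red channel is the same number in both layouts.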
Let's see all of that in one picture:
Check your understanding
What does the shape (32, 3, 224, 224) describe, and which number will balloon as the image moves deeper into a CNN?
It's a batch of 32 images, each with 3 channels (RGB), 224 pixels tall and 224 wide — the (N, C, H, W) layout. As the image flows deeper into a CNN, the channel count (the 3) grows — to 64, 128, 256, and beyond — because each layer adds more learned feature detectors, while the spatial H and W usually shrink. After the first layer those channels no longer mean colors; they mean learned patterns.
Normalization
Raw pixel values in [0, 255] make terrible inputs to a neural network. They're too large and not centered around zero, which makes optimization jumpy and unstable. So we always preprocess, in two steps:
1. Divide by 255 to squeeze every value into [0, 1].
2. Subtract the dataset mean and divide by the dataset standard deviation, computed per channel. This re-centers the data around zero with a consistent spread.
For models trained on ImageNet, the standard normalization uses mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225) — one number per RGB channel. The golden rule: normalize the exact same way at inference as you did at training. If you trained on normalized inputs and then feed raw pixels at test time, the model sees garbage.
Stage 1 divides by 255 into [0,1]; stage 2 subtracts the per-channel mean and divides by the per-channel std, centering the histogram on 0.
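The two stages can be sketched in a few lines of NumPy. The function name `normalize` and the flat gray test image are assumptions made for illustration; the mean and std values are the ImageNet statistics quoted above.

```python
import numpy as np

# ImageNet per-channel statistics quoted in the text (RGB order).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image_hwc_uint8):
    """Map an (H, W, 3) uint8 image to a standardized float (3, H, W) array."""
    x = image_hwc_uint8.astype(np.float32) / 255.0   # stage 1: scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD           # stage 2: per-channel standardize
    return np.transpose(x, (2, 0, 1))                # channels first

# Illustrative input: a flat mid-gray image.
gray = np.full((4, 4, 3), 128, dtype=np.uint8)
out = normalize(gray)
print(out.shape)     # (3, 4, 4)
print(out[:, 0, 0])  # three small values near zero, one per channel
```

Whatever function does this at training time should be the same one called at inference, which is the golden rule in code form.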
If that felt a little abstract — why these numbers matter, why zero-centering helps — hang tight. The next sections wire it all back into the network, and it should click.
Check your understanding
Why subtract the mean and divide by the standard deviation instead of just dividing by 255?
Dividing by 255 fixes the scale (values land in [0, 1]) but leaves the data sitting at a positive average (roughly 0.4 to 0.5 per channel for ImageNet), not centered on zero. Subtracting the per-channel mean re-centers the inputs around zero, and dividing by the per-channel standard deviation gives them a consistent spread. Zero-centered, unit-ish-variance inputs make the loss surface better behaved, so optimization is more stable and can use higher learning rates. And you must use the same mean/std at inference, or the model sees a distribution it never trained on.
Convolutional Neural Networks
So: how do we actually feed images to a neural network? Your first instinct is probably the natural one — an image is just numbers, so why not flatten it and pour it into a plain fully-connected network (an MLP), like any other input? Let's try exactly that and watch it fall apart, because the way it fails tells us precisely what to build instead.
Why a plain MLP doesn't work
Take a 224×224×3 image. Flatten it into one long vector and you get 150,528 numbers. Feed that to a fully-connected network. Two problems make this a disaster.
Parameter explosion. The first layer's weight matrix would be 150,528 × hidden_dim. For even a modest hidden_dim of 1,000, that's 150 million parameters in the first layer alone. Networks that big do train, but you'd be burning your entire parameter budget just learning to look at individual pixels.
No translation invariance. A dog in the top-left corner and the same dog in the bottom-right corner would activate completely different input neurons. The network would have to learn "what a dog looks like" separately for every possible position in the image. With any finite amount of training data, that's hopeless.
The fix is a specialized architecture: the Convolutional Neural Network. CNNs are built around two facts about images that the MLP ignored:
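- Locality: pixels near each other are correlated and form meaningful patterns; far-apart pixels mostly aren't. A nose is a local arrangement of pixels, not a relationship between opposite corners.
- Translation invariance: a feature (an edge, a texture, an eye) means the same thing no matter where it appears. (Strictly speaking, a convolution is translation equivariant: shift the input and the feature map shifts with it. Pooling later turns that into approximate invariance.)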
The convolution operation is engineered around exactly these two facts. Let's build it.
Check your understanding
The two MLP failures map onto the two facts CNNs exploit. Which fix addresses which failure?
Reusing one small filter across all positions (the locality + weight-sharing idea) fixes the parameter explosion — instead of 150M weights you have a few dozen, reused everywhere. And sliding that same filter across the whole image gives translation invariance — the detector responds to a feature identically wherever it appears, so the network doesn't have to relearn "dog" for each corner.
The convolution operation
A convolution slides a small filter (also called a kernel) across the image, computing a dot product at each position. A filter is a small tensor of weights — typically 3×3 or 5×5 in spatial size — with the same number of channels as its input. At each spatial position you run this little algorithm:
1. Multiply the filter element-wise with the patch of input it's sitting on.
2. Sum all those products.
3. Add a bias.
4. Write the result to that position in the output.
5. Slide the filter one step over and repeat, until you've covered the whole input.
The output is a 2D feature map showing how strongly that filter's pattern is detected at each location. Mathematically, for an input $X$ and filter $W$:

$$Y[i, j] = \sum_{m} \sum_{n} \sum_{c} X[i+m,\, j+n,\, c] \cdot W[m, n, c] + b$$

Reading the symbols:
- $Y[i, j]$: the value written at output position $(i, j)$.
- $X[i+m, j+n, c]$: the input pixel at row $i+m$, column $j+n$, channel $c$ (the patch under the filter).
- $W[m, n, c]$: the filter weight at offset $(m, n)$ in channel $c$.
- $b$: a single scalar bias added at the end.
- The triple sum runs over the filter's spatial dimensions $(m, n)$ and across all input channels $c$.
That triple sum is the whole operation. Here's the slide-and-multiply in motion:
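Two quick notes for completeness. In full generality, with explicit kernel size $k$ and input channels $C_{in}$:

$$Y[i, j] = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \sum_{c=0}^{C_{in}-1} X[i+m,\, j+n,\, c] \cdot W[m, n, c] + b$$

where $Y$ is the output, $X$ the input, $W$ the filter (kernel size $k \times k$ with the same channel count as the input), and $b$ a single learned bias added to every output position.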
And a technical aside worth knowing: what we just described is technically cross-correlation, not true convolution. A real mathematical convolution would flip the filter before applying it. In deep learning everyone calls it convolution anyway, because the filter is learned — flipping or not flipping makes no difference, the network just learns whatever weights work either way.
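The five-step recipe translates almost line for line into code. Below is a minimal NumPy sketch for one filter, assuming no padding, stride 1, and an (H, W, C) input layout; the function name `conv2d_single_filter` and the random inputs are illustrative. Real frameworks compute the same thing with far faster kernels.

```python
import numpy as np

def conv2d_single_filter(x, w, b=0.0):
    """Cross-correlate one filter with an input (no padding, stride 1).

    x: input of shape (H, W, C); w: filter of shape (k, k, C); b: scalar bias.
    Returns a 2D feature map of shape (H - k + 1, W - k + 1).
    """
    H, W, C = x.shape
    k = w.shape[0]
    assert w.shape == (k, k, C), "filter depth must match input channels"
    out = np.zeros((H - k + 1, W - k + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = x[i:i + k, j:j + k, :]      # region currently under the filter
            out[i, j] = np.sum(patch * w) + b   # multiply, sum, add bias
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 6, 3))   # tiny 3-channel input
w = rng.normal(size=(3, 3, 3))   # one 3x3 filter with matching depth
y = conv2d_single_filter(x, w, b=0.5)
print(y.shape)  # (4, 4)
```

Note that the filter is applied as-is, without flipping, which is the cross-correlation convention described in the aside.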
Check your understanding
A filter has the same number of channels as its input. For an RGB input, how many weights does one 3×3 filter have (ignoring bias)?
3 × 3 × 3 = 27 weights. The filter is 3×3 in space and must match the input's 3 channels in depth, so it's a (3, 3, 3) tensor. Adding the single bias makes 28 learned numbers for that one filter, and those same 28 numbers get reused at every spatial position, which is exactly where the parameter savings and translation invariance come from.
A convolution worked out from scratch
Let's make this concrete with real numbers. Take a tiny 5×5 grayscale input:
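And this 3×3 filter, roughly a horizontal-edge detector (it computes the sum of the patch's bottom row minus the sum of its top row):

$$\begin{bmatrix} -1 & -1 & -1 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \end{bmatrix}$$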
To get the output at position (0,0) — the top-left output cell — overlay the filter on the top-left 3×3 patch:
Multiply element-wise and sum:
So output(0,0) = 5. Now slide one step right:
So output(0,1) = 10. Keep sliding across and down. A 5×5 input convolved with a 3×3 filter gives a 3×3 output by default — we'll see exactly why in a moment.
What is this filter actually doing? It returns a big positive value wherever the bottom row of the patch is brighter than the top row — that is, wherever there's a horizontal edge running from dark-on-top to light-on-bottom. The value 5 at (0,0) means a moderate horizontal edge there; the 10 at (0,1) means an even stronger one just to the right. That's what "detecting a pattern" means: the output is highest where the input matches what the filter is looking for.
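To replay this kind of computation yourself, here is a small NumPy sketch. The filter is the bottom-row-minus-top-row detector described above. The 5×5 input is NOT the chapter's original matrix: its values are made up for illustration and chosen so that the first two outputs come out to 5 and 10, matching the text.

```python
import numpy as np

# Illustrative 5x5 grayscale input (hypothetical values, see note above).
image = np.array([
    [3, 0, 1, 5, 2],
    [1, 4, 2, 3, 6],
    [1, 1, 7, 8, 4],
    [2, 5, 3, 9, 1],
    [0, 6, 2, 4, 7],
])

# Horizontal-edge filter: (sum of bottom row) - (sum of top row).
kernel = np.array([
    [-1, -1, -1],
    [ 0,  0,  0],
    [ 1,  1,  1],
])

out_size = image.shape[0] - kernel.shape[0] + 1   # 5 - 3 + 1 = 3
output = np.zeros((out_size, out_size), dtype=int)
for i in range(out_size):
    for j in range(out_size):
        patch = image[i:i + 3, j:j + 3]           # 3x3 region under the filter
        output[i, j] = np.sum(patch * kernel)     # element-wise multiply, then sum

print(output)
# [[ 5 10 11]
#  [ 3  8  2]
#  [-1 -4 -6]]
```

The negative values in the last row mark patches where the top is brighter than the bottom, the opposite polarity of the edge this filter looks for.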
Check your understanding
Why does the output value of 10 mean a "stronger edge" than the output of 5?
This filter computes (sum of bottom row) − (sum of top row) for each patch. A larger positive result means a bigger jump in brightness from the top of the patch to the bottom — i.e. a sharper dark-to-light horizontal transition. At (0,1) the bottom row (1, 7, 8) is much brighter than the top row (0, 1, 5), giving +10; at (0,0) the contrast is milder, giving +5. The filter's output magnitude directly measures how strongly the patch matches the pattern it detects.
Multiple filters and channels
One filter produces one feature map — one kind of pattern detected across the image. But the world has many kinds of patterns: horizontal edges, vertical edges, diagonals, color blobs, textures. So a conv layer uses many filters in parallel, each hunting for a different thing.
If a layer has $K$ filters and the input has $C_{in}$ channels, then:
- Each filter has shape $(C_{in}, k, k)$: the same depth as the input, with spatial size $k \times k$.
- All filters together form a weight tensor of shape $(K, C_{in}, k, k)$.
- The output has $K$ channels, one feature map per filter.
So for an input of shape (3, 224, 224) (RGB, 224 pixels a side), a conv layer with 64 filters of size 3×3, stride 1, and padding 1 produces an output of shape (64, 224, 224). Each of those 64 output channels is one filter's response across the whole image.
This is the key reason CNNs scale so gracefully: the input has 3 channels (R, G, B), but after one layer you have 64 abstract feature channels. After another, 128. After another, 256. The network keeps building richer and richer descriptions of the image. And — exactly as promised earlier — after that first layer "channels" no longer mean colors. Maybe channel 7 fires on vertical edges, channel 23 on red blobs, channel 41 on diagonal textures. Deeper still and channels mean eyes, faces, wheels, paws. The network discovers all of this from data; nobody assigns the meanings.
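Here is a hedged sketch of a whole conv layer with several filters, again in plain NumPy. It assumes stride 1 and zero padding, uses a (C_in, H, W) layout, and shrinks the sizes (a 16×16 input and 8 filters instead of 224×224 and 64) so the loops stay fast. The name `conv_layer` is illustrative, not a library function.

```python
import numpy as np

def conv_layer(x, weights, biases, padding=1):
    """Apply K filters to a (C_in, H, W) input with stride 1 and zero padding.

    weights: (K, C_in, k, k); biases: (K,). Returns (K, H_out, W_out).
    """
    K, C_in, k, _ = weights.shape
    assert x.shape[0] == C_in, "filter depth must match input channels"
    xp = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    H_out = xp.shape[1] - k + 1
    W_out = xp.shape[2] - k + 1
    out = np.zeros((K, H_out, W_out))
    for i in range(H_out):
        for j in range(W_out):
            patch = xp[:, i:i + k, j:j + k]   # (C_in, k, k) region under the filters
            # Contract all three patch axes against every filter at once.
            out[:, i, j] = np.tensordot(weights, patch, axes=3) + biases
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 16, 16))          # small stand-in for (3, 224, 224)
weights = rng.normal(size=(8, 3, 3, 3))   # 8 filters instead of 64
biases = np.zeros(8)
features = conv_layer(x, weights, biases)
print(features.shape)  # (8, 16, 16)
```

The output depth equals the number of filters, and the padding of 1 keeps the 16×16 spatial size unchanged.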
Check your understanding
A conv layer takes a (3, 224, 224) input and has 64 filters of size 3×3. What's the shape of its weight tensor, and what's the output shape (with same padding)?
The weight tensor is (64, 3, 3, 3) — 64 filters, each matching the input's 3 channels with a 3×3 spatial size. The output is (64, 224, 224): 64 channels (one feature map per filter), with the 224×224 spatial size preserved by same padding. The output depth equals the number of filters, full stop.
Counting the parameters
Let's actually count, because the savings are staggering. Input (C_in, H, W) = (3, 224, 224), filter size 3×3, output channels K = 64.
- Per filter: 3 × 3 × 3 = 27 weights, plus 1 bias = 28.
- Whole layer: 64 × 28 = 1,792 parameters.
Now compare to a fully-connected layer mapping the flattened input to just 1,000 outputs: 150,528 × 1,000 ≈ 150.5 million parameters. That's roughly an 84,000× reduction. And here's the kicker: those 1,792 parameters are reused at every one of the 224 × 224 = 50,176 spatial positions. The CNN gets translation invariance for free out of this exact design: the same filter looks for the same pattern everywhere.
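The arithmetic is easy to verify. The two helper names below are made up for this sketch, and the fully-connected count ignores biases, as the comparison in the text does.

```python
def conv_params(c_in, k, n_filters):
    """Weights plus one bias per filter, times the number of filters."""
    per_filter = c_in * k * k + 1
    return n_filters * per_filter

def dense_params(n_inputs, n_outputs, bias=False):
    """Parameter count of a fully-connected layer."""
    return n_inputs * n_outputs + (n_outputs if bias else 0)

conv = conv_params(c_in=3, k=3, n_filters=64)
dense = dense_params(3 * 224 * 224, 1000)

print(conv)           # 1792
print(dense)          # 150528000
print(dense // conv)  # 84000
```

Notice that `conv_params` never takes the image height or width: the conv layer's parameter count does not depend on image size.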
Check your understanding
Where does the CNN's translation invariance actually come from in this parameter-counting picture?
From weight sharing. The same small filter (e.g. 28 numbers) is applied at every spatial position rather than learning separate weights per location. Because the identical detector slides across the whole image, a pattern produces the same response wherever it sits — that's translation invariance — and it's also why the parameter count is tiny and independent of image size.
Padding, stride, and output size
Back in the worked example, a 5×5 input gave a 3×3 output. Why did it shrink? Because we only placed the filter where it fully fit inside the input — and a 3×3 filter only has 3 valid starting positions along a 5-wide axis. In a deep network that shrinking is a real problem: after a few layers your image would dwindle to nothing.
Padding fixes it. Add a border of zeros around the input before convolving. With padding $p$, an $n_{in} \times n_{in}$ input effectively becomes $(n_{in} + 2p) \times (n_{in} + 2p)$. With $p = 1$ for a 3×3 filter, the output comes out the same spatial size as the input. This is called "same" padding, and nearly every modern CNN uses it.
Stride is how far the filter jumps between positions. Stride 1 means "slide one pixel at a time." Stride 2 means "skip every other position," which halves the output's spatial size and is a common way to downsample.
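The output spatial size formula ties it all together:

$$n_{out} = \left\lfloor \frac{n_{in} + 2p - k}{s} \right\rfloor + 1$$

where $n_{in}$ is the input size, $p$ is padding, $k$ is kernel size, $s$ is stride, and $\lfloor \cdot \rfloor$ is the floor (round down). Let's plug in a few cases:
- 5×5 input, 3×3 filter, no padding, stride 1: ⌊(5 + 0 − 3) / 1⌋ + 1 = 3. Output 3×3. (That's our worked example.)
- 224×224 input, 3×3 filter, padding 1, stride 1: ⌊(224 + 2 − 3) / 1⌋ + 1 = 224. Output 224×224 (same padding).
- 224×224 input, 3×3 filter, padding 1, stride 2: ⌊(224 + 2 − 3) / 2⌋ + 1 = ⌊111.5⌋ + 1 = 112. Output 112×112 (halved).
- 224×224 input, 7×7 filter, padding 3, stride 2: ⌊(224 + 6 − 7) / 2⌋ + 1 = ⌊111.5⌋ + 1 = 112. Also halved.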
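The output-size rule fits in a one-line helper, which is handy for sanity-checking an architecture before building it. The function name is an assumption for this sketch; integer floor division plays the role of the floor.

```python
def conv_output_size(n_in, kernel, padding=0, stride=1):
    """Spatial output size of a convolution (or pooling) along one axis."""
    return (n_in + 2 * padding - kernel) // stride + 1

# The four cases from the list above.
print(conv_output_size(5, kernel=3))                         # 3
print(conv_output_size(224, kernel=3, padding=1))            # 224
print(conv_output_size(224, kernel=3, padding=1, stride=2))  # 112
print(conv_output_size(224, kernel=7, padding=3, stride=2))  # 112
```

The same helper works for pooling layers, since a pooling window follows the same sliding arithmetic.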
Check your understanding
You have a 56×56 feature map and apply a 3×3 conv with padding 1 and stride 2. What's the output spatial size?
⌊(56 + 2 − 3) / 2⌋ + 1 = ⌊27.5⌋ + 1 = 28. So 28×28: the stride-2 step halves the spatial size, while the padding of 1 keeps the arithmetic clean. This is exactly the downsampling pattern used between CNN stages.
Activation: ReLU
A conv layer is a linear function of its input. And stacking linear layers just gives you another linear function — no matter how many you pile up, the whole thing collapses into one single linear transformation. That's useless for the rich, nonlinear patterns in images. (If this argument feels familiar, it's the same reason the transformer's feed-forward block needed a nonlinearity — the villain and the fix recur across architectures.)
So after every conv layer we apply a nonlinear activation function, almost always ReLU (Rectified Linear Unit):

$$\text{ReLU}(x) = \max(0, x)$$
That's the whole thing: any negative value becomes 0, positives pass through unchanged. ReLU is dirt cheap to compute, it doesn't saturate (its gradient is 1 everywhere positive, so signal flows nicely during backprop), and empirically it trains deep networks far better than the older sigmoid and tanh. Modern variants include Leaky ReLU (a small negative slope instead of a hard zero), GELU (a smooth version used in transformers), and SiLU/Swish (smooth, used in some recent CNNs) — but for classic CNN work, plain ReLU is the default.
The standard conv-layer pattern, end to end, is Conv → BatchNorm → ReLU. That batch normalization step (introduced in 2015) normalizes activations to zero mean and unit variance across the batch, which dramatically stabilizes training and is nearly universal in modern CNNs.
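The "stacked linear layers collapse" argument can be checked numerically. In this sketch small matrices stand in for conv layers (a convolution is a linear map too); the specific matrices and the input vector are chosen by hand for illustration.

```python
import numpy as np

def relu(x):
    """Negative values become 0, positives pass through unchanged."""
    return np.maximum(0.0, x)

A = np.array([[1.0, 1.0]])   # second "layer": sums its two inputs
B = np.eye(2)                # first "layer": identity
x = np.array([1.0, -1.0])

# Two stacked linear layers are the same as one layer with weights A @ B.
stacked_linear = A @ (B @ x)
collapsed = (A @ B) @ x
print(stacked_linear, collapsed)  # [0.] [0.]

# Put ReLU between them and the result can no longer be written as (A @ B) @ x.
with_relu = A @ relu(B @ x)
print(with_relu)                  # [1.]
```

The ReLU zeroed out the negative component before the second layer saw it, which is something no single linear map applied to `x` could do for every input.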
Check your understanding
Why does removing the activation function defeat the purpose of stacking many conv layers?
Because a composition of linear maps is itself a single linear map. A conv layer with no nonlinearity is linear, so stacking ten of them is mathematically equivalent to one linear conv — the depth buys you zero extra expressive power. ReLU (or any nonlinearity) breaks that collapse, letting each added layer learn genuinely more complex, non-linear features. It's the same logic as the nonlinearity in a transformer's FFN.
Pooling
After several conv layers you usually want to downsample — shrink the spatial resolution to focus on what matters and cut the compute. There are two ways:
Strided convolution — use stride 2 in a conv layer to halve the spatial size while learning the downsampling (the filter weights decide what to keep).
Pooling — a fixed, non-learned downsampling. The most common is max pooling: slide a small window (typically 2×2 with stride 2) across the feature map and keep only the maximum value in each window. For a 2×2 max pool with stride 2, every 2×2 patch becomes one output value — the max of those four numbers — so both spatial dimensions halve.
Why max instead of average? Max pooling preserves the strongest response — if some pixel in the window fires hard for "horizontal edge," that signal survives. Average pooling would water it down. Max pooling also hands you a little translation invariance for free: nudge a feature by a pixel within a 2×2 window and the max is unchanged.
There's also global average pooling, used at the very end of a network: average each entire feature map down to a single scalar, collapsing a (C, H, W) tensor to a length-C vector. This replaced the heavy fully-connected layers of older CNNs and was popularized by architectures such as GoogLeNet and ResNet.
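Both pooling operations are short enough to write out. This NumPy sketch assumes a (C, H, W) feature map with even H and W; the function names and the 4×4 example values are illustrative.

```python
import numpy as np

def max_pool_2x2(fmap):
    """2x2 max pooling with stride 2 on a (C, H, W) array; H and W must be even."""
    C, H, W = fmap.shape
    # Split each spatial axis into (blocks, 2) so every 2x2 window gets its own axes.
    windows = fmap.reshape(C, H // 2, 2, W // 2, 2)
    return windows.max(axis=(2, 4))

def global_average_pool(fmap):
    """Collapse a (C, H, W) array to a length-C vector of per-channel means."""
    return fmap.mean(axis=(1, 2))

fmap = np.array([[[1, 3, 2, 0],
                  [4, 2, 1, 1],
                  [0, 0, 5, 6],
                  [1, 2, 7, 8]]], dtype=float)

pooled = max_pool_2x2(fmap)
print(pooled.shape)               # (1, 2, 2)
print(pooled[0].tolist())         # [[4.0, 2.0], [2.0, 8.0]]
print(global_average_pool(fmap))  # [2.6875]
```

Each output value of the max pool is the largest of four inputs, so the strong responses (the 4 and the 8) survive while the spatial size halves.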
Check your understanding
Why is max pooling usually preferred over average pooling inside a CNN, and what does global average pooling replace?
Max pooling keeps the strongest activation in each window, so a sharp feature response survives downsampling instead of being diluted by neighboring low values (as average pooling would do); it also gives a little translation invariance, since shifting a feature within the window doesn't change the max. Global average pooling, used at the network's end, collapses each feature map to a single scalar — replacing the bulky fully-connected layers that older CNNs used before their classifier.
The hierarchy of features
Here's the deepest idea about CNNs — the reason they work as well as they do. By stacking conv + pool blocks, the network builds a hierarchy of features of increasing complexity. Visualize what the filters in a trained CNN respond to and you see a clear progression:
- Early layers (1–2): simple low-level features such as edges at various orientations, color blobs, and basic textures. These look almost exactly like the filters that classical computer-vision researchers hand-designed before deep learning (Gabor filters, Sobel operators). The network rediscovers them on its own.
- Middle layers (3–5): combinations of those into mid-level patterns such as corners, stripes, simple shapes, eye-like patterns, and wheel-like patterns.
- Late layers (6+): high-level concepts such as whole objects, animal parts, faces, and scene types. Individual deep neurons often fire on startlingly specific things: "dog face," "car wheel," "vertical text."
Nobody programs this hierarchy in. It emerges from training. The network discovers, by itself, that the way to recognize a dog is to first find edges, combine edges into textures and shapes, combine shapes into dog-parts, and combine parts into a whole dog. That compositional structure is exactly how vision researchers — and probably biological visual systems — think about the problem. And the reason deeper layers can represent bigger concepts comes down to one idea: their receptive fields are larger. Let's unpack that.
Receptive field
The receptive field of a neuron is the region of the input image that affects its value. Early-layer neurons see only a small patch: a first-layer neuron looks at just a 3×3 region of the input. But deeper neurons combine outputs from many earlier neurons, so they end up seeing much more of the original image.
Here's how it grows for stacked 3×3 conv layers:
- After 1 layer: receptive field 3×3.
- After 2 stacked layers: each output depends on a 3×3 patch of the previous layer, each of which depended on a 3×3 patch of the input — so the field is 5×5.
- After 3 stacked: 7×7.
- After $n$ stacked 3×3 layers: $(2n + 1) \times (2n + 1)$.
Pooling and strided convolutions speed this up dramatically: after a stride-2 layer, neighboring positions in the feature map are two input pixels apart, so every later layer grows the receptive field twice as fast. By the end of a deep CNN like ResNet-50, individual neurons can have receptive fields covering the entire input image. That's how deep neurons can "see" whole objects: their window is finally big enough to contain them.
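Receptive-field growth can be tracked with a short loop. This is a sketch of the standard recurrence, not something spelled out in the text: each layer adds (kernel − 1) times the current "jump" (the distance, in input pixels, between neighboring positions of the current feature map), and each stride multiplies the jump. The function name is an assumption.

```python
def receptive_field(layers):
    """Receptive field (along one axis) after a stack of conv/pool layers.

    layers: list of (kernel_size, stride) pairs, ordered from input to output.
    """
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # widen by the kernel's reach at the current spacing
        jump *= stride              # later layers step over more input pixels
    return rf

print(receptive_field([(3, 1)] * 1))       # 3
print(receptive_field([(3, 1)] * 3))       # 7
print(receptive_field([(3, 1)] * 5))       # 11
print(receptive_field([(3, 2), (3, 1)]))   # 7, versus 5 without the stride
```

With stride 1 everywhere this reduces to the 3, 5, 7 pattern above; a single stride-2 layer makes every later layer widen the field twice as fast.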
Check your understanding
After 5 stacked 3×3 conv layers (stride 1, no pooling), what's the receptive field, and why does this matter for recognizing whole objects?
2 × 5 + 1 = 11, so an 11×11 receptive field. It matters because a neuron can only respond to a concept that fits inside the region of the input it can "see." Early neurons (3×3) can only detect tiny local features like edges; deeper neurons, with their much larger receptive fields (accelerated further by pooling/striding), take in enough of the image to represent whole objects, which is why high-level concepts live in the late layers.
A complete CNN, layer by layer
Let's trace a full CNN — a simplified ResNet-style classifier — from input to output, watching the tensor shapes change at every step. Keep one eye on the pattern: spatial size shrinks while channel count grows.
Input: a single 224×224 RGB image, shape (3, 224, 224).
Block 1 — Stem.
- Conv 7×7, 64 filters, stride 2, padding 3 → (64, 112, 112)
- BatchNorm + ReLU
- MaxPool 3×3, stride 2, padding 1 → (64, 56, 56)
Block 2.
- Conv 3×3, 64 filters, stride 1, padding 1 → (64, 56, 56)
- Conv 3×3, 64 filters, stride 1, padding 1 → (64, 56, 56) (BatchNorm + ReLU between)
Block 3.
- Conv 3×3, 128 filters, stride 2, padding 1 → (128, 28, 28)
- Conv 3×3, 128 filters, stride 1, padding 1 → (128, 28, 28)
Block 4.
- Conv 3×3, 256 filters, stride 2, padding 1 → (256, 14, 14)
- Conv 3×3, 256 filters, stride 1, padding 1 → (256, 14, 14)
Block 5.
- Conv 3×3, 512 filters, stride 2, padding 1 → (512, 7, 7)
- Conv 3×3, 512 filters, stride 1, padding 1 → (512, 7, 7)
Head.
- Global average pooling: average each of the 512 feature maps to one scalar → (512,) vector
- Fully-connected layer, 512 → 1000 (for 1000-class ImageNet) → (1000,) logits
- Softmax → class probabilities
Notice the pattern: as we go deeper, spatial dimensions shrink (224 → 7) while channel count grows (3 → 512). The network is trading spatial resolution for semantic depth. Early on you have lots of pixels each described by 3 numbers (R, G, B). At the end you have just 7×7 = 49 spatial positions, but each is described by 512 abstract feature dimensions. Then global pooling throws away spatial info entirely and the fully-connected layer produces a class score. This is the canonical CNN recipe — every classifier from AlexNet to EfficientNet follows some variant of it.
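The shape trace above can be reproduced mechanically from the kernel, stride, and padding of each layer. This sketch uses pure Python; the layer labels are invented for readability, and BatchNorm and ReLU are left out because they do not change shapes.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial output size along one axis."""
    return (n + 2 * padding - kernel) // stride + 1

# (label, out_channels, kernel, stride, padding); None keeps the channel count.
layers = [
    ("stem conv 7x7", 64, 7, 2, 3),
    ("max pool 3x3", None, 3, 2, 1),
    ("block2 conv a", 64, 3, 1, 1),
    ("block2 conv b", 64, 3, 1, 1),
    ("block3 conv a", 128, 3, 2, 1),
    ("block3 conv b", 128, 3, 1, 1),
    ("block4 conv a", 256, 3, 2, 1),
    ("block4 conv b", 256, 3, 1, 1),
    ("block5 conv a", 512, 3, 2, 1),
    ("block5 conv b", 512, 3, 1, 1),
]

channels, size = 3, 224
shapes = []
for label, out_ch, k, s, p in layers:
    channels = out_ch if out_ch is not None else channels
    size = conv_out(size, k, s, p)
    shapes.append((label, (channels, size, size)))
    print(f"{label:15s} -> {(channels, size, size)}")

# Global average pooling then leaves one number per channel.
print("global avg pool ->", (channels,))
```

The final conv shape comes out as (512, 7, 7), and the pooled vector has length 512, matching the head described above.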
Check your understanding
Across the CNN, spatial size goes 224 → 7 while channels go 3 → 512. In one phrase, what is the network doing, and what does global average pooling do at the end?
It's trading spatial resolution for semantic depth — converting many pixels described by 3 color numbers into a few spatial positions described by hundreds of abstract feature dimensions. Global average pooling then collapses each of the 512 feature maps to a single scalar, discarding spatial layout entirely and producing a 512-length vector that the final fully-connected layer turns into class scores.
Training a CNN, and the three big ideas
Training is exactly what you saw in earlier chapters: forward pass, compute loss, backprop, update weights. The only CNN-specific wrinkle is how gradients flow through the convolution. For each training example you (1) run the image forward to get logits, apply softmax, and compute cross-entropy loss against the true label; (2) backpropagate the gradient of the loss to every parameter — for conv layers, how each filter weight should change; and (3) update with an optimizer like SGD-with-momentum or Adam. Neat fact: the gradient of a conv layer is itself computed via a convolution (with flipped filters), which is why frameworks like PyTorch implement conv backprop efficiently on GPU.
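Step (1) and the very start of step (2) can be sketched without any framework. The snippet below computes softmax, cross-entropy, and the gradient of the loss with respect to the logits for a single example; the logits and label are made-up values, and a real training loop would let an autograd library carry that gradient back through the conv layers.

```python
import numpy as np

def softmax(logits):
    shifted = logits - np.max(logits)   # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def cross_entropy(logits, label):
    """Return the loss and its gradient with respect to the logits."""
    probs = softmax(logits)
    loss = -np.log(probs[label])
    grad = probs.copy()
    grad[label] -= 1.0                  # d(loss)/d(logits) = probs - one_hot(label)
    return loss, grad

logits = np.array([2.0, 0.5, -1.0])     # hypothetical 3-class scores
loss, grad = cross_entropy(logits, label=0)
print(loss)
print(grad)
```

The gradient is negative for the true class and positive for the others, so a gradient-descent step raises the correct logit and lowers the rest.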
The training tricks that matter most for CNNs: data augmentation (random crops, horizontal flips, color jitter, mixup — essential; without it CNNs overfit most datasets), batch normalization (normalize activations within each batch — stabilizes training, allows higher learning rates), learning-rate schedules (start high, decay via cosine annealing or step decay — critical for converging well), and pretraining (train on ImageNet first, then fine-tune on your task — this transfer-learning approach is overwhelmingly the default; training from scratch on a small dataset is rarely the right move).
To wrap up CNNs, here are the three ideas — beyond convolution itself — that turned them from a 1990s curiosity into the dominant vision architecture for a decade:
1. ReLU activations (AlexNet, 2012). Replaced saturating sigmoid/tanh with $\max(0, x)$. Gradients flow much better; training is faster.
2. Batch normalization (2015). Normalizes activations within each mini-batch. Enables deeper networks, higher learning rates, and far less sensitivity to weight initialization.
3. Residual connections (ResNet, 2015). Skip connections that add a block's input back to its output. They let gradients flow back through arbitrarily deep networks — and this is, no exaggeration, the single most important architectural idea since convolution itself. It's the exact same idea used in transformers today. (Remember the skip path from the transformer chapter? Same hero, different architecture.)
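A residual connection is one addition. In this minimal sketch, small matrices stand in for the convolutions of a real ResNet block and BatchNorm is omitted; the function name and weights are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(0.0, x)

def residual_block(x, W1, W2):
    """y = x + F(x), where F is a small two-layer transform."""
    fx = W2 @ relu(W1 @ x)   # the block's learned "change"
    return x + fx            # skip connection: add the input back

x = np.array([1.0, -2.0, 3.0])
zeros = np.zeros((3, 3))

# With F contributing nothing, the block is exactly the identity.
print(residual_block(x, zeros, zeros))  # [ 1. -2.  3.]
```

Because the input is added back unchanged, the signal (and, during backprop, the gradient) always has a direct path through the block, no matter what F does.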
Together with the convolution operation, these are the load-bearing ideas in every modern vision CNN. Understand them and you understand why every detection model (YOLO, R-CNN), every segmentation model (U-Net, Mask R-CNN), and even diffusion models (which run on U-Nets) share the same conceptual DNA.
Check your understanding
Residual connections are called "the same idea used in transformers." What problem do they solve, and what's the mechanism?
They solve the problem of training very deep networks, where gradients have to flow back through many layers and tend to vanish. The mechanism is a skip connection that adds a block's input directly to its output ($y = F(x) + x$). That creates a direct path for gradients during backprop, so they reach early layers intact, and it lets each block learn only a small change to its input rather than a full transformation. It's identical to the residual/Add step inside transformer layers.
YOLO: You Only Look Once
Funny name, huh? YOLO became one of the most popular computer-vision architectures around, and it's worth understanding why the design is so good. But first we need to know what came before it — because, just like in the transformer story, YOLO is best understood as a reaction to the slow, clunky thing it replaced.
A quick framing of the task. Classification answers "what's in this image?" with one label. Detection is harder: it answers "what objects are here, and where?" — drawing a bounding box around each object and labeling it. That "where" is the whole challenge.
Before deep learning: DPM
The dominant detection method before deep learning was DPM (Deformable Parts Model). It worked by sliding window: take a classifier for your target object, slide it across the image at evenly spaced locations and at multiple scales, and at each stop ask "is the object here?" For each window the model would extract hand-crafted features (typically HOG — Histogram of Oriented Gradients), score the window against a learned template for the object's overall shape, score against templates for the object's parts (legs, head, wheels), and combine the scores while allowing some deformation between parts.
The problems were severe: the pipeline was complex (feature extraction, root filter, part filters, deformation cost, post-processing — each piece designed or trained separately, no joint optimization), it was slow (even the fastest variant couldn't really do real-time general detection), and the hand-crafted HOG features worked for some objects like pedestrians but couldn't capture the diversity of natural images the way learned features could.
The first deep wave: R-CNN
The first deep-learning detectors — R-CNN (2014), then Fast R-CNN, then Faster R-CNN — replaced hand-crafted features with CNNs but kept a two-stage structure:
- Stage 1: Region proposals. Generate ~2,000 candidate boxes per image that might contain objects. Original R-CNN used Selective Search (a classical, non-learned algorithm); Faster R-CNN later replaced it with a small Region Proposal Network.
- Stage 2: Classify each proposal. Run a CNN on every candidate, classify it as an object class (or background), and refine the box coordinates.
Then, on top of that: a separate linear model to refine boxes, NMS to remove duplicates, and rescoring based on context. The numbers tell the story — original R-CNN took more than 40 seconds per image at test time; even Fast R-CNN at 0.5 frames per second was nowhere near real-time. Too many parts, each trained separately, each adding its own cost.
Both DPM and R-CNN share one deep flaw: they repurpose a classifier to do detection. They take a classifier, run it at many locations, and post-process the results. That means each component is trained and tuned separately; the classifier never sees the whole image (only local patches or proposals), so it can't reason about context; and the pipeline is slow because so many separate operations have to run.
The YOLO authors' core insight: detection should be one regression problem, optimized end-to-end on the actual goal — good bounding boxes and class predictions — not a stack of separately trained classifiers.
Check your understanding
What single design flaw do both DPM and R-CNN share, and what did YOLO propose instead?
Both repurpose a classifier for detection — they run a classifier at many locations or region proposals and stitch the results together with separate, individually-trained stages. This makes them slow and prevents reasoning over the whole image at once. YOLO reframes detection as a single end-to-end regression problem: one network looks at the entire image once and directly outputs all the boxes and class probabilities, trained jointly on the real goal.
The YOLO algorithm
The YOLO inference recipe is only three steps:
1. Resize the input image to 448×448.
2. Run a single CNN on the image.
3. Threshold the resulting detections by confidence (and apply NMS).
One image goes in, a list of detection boxes comes out — a genuine one-shot pipeline. Unlike R-CNN, which effectively looks at the image thousands of times, with YOLO you only look once. (NMS, by the way, is Non-Maximum Suppression — a post-processing step that removes duplicate boxes and keeps the best detections. We'll cover it properly below.)
Resize, one CNN pass, threshold + NMS. The whole detector is a single forward pass.
Now the architecture at a high level — it's organized around a grid:
Step 1 — divide the image into an S × S grid. For PASCAL VOC, S = 7, so the 448×448 image is conceptually broken into a 7×7 grid, each cell covering a 64×64-pixel region.
Step 2 — each cell predicts B bounding boxes plus C class probabilities. For VOC, B = 2 (each cell proposes two candidate boxes) and C = 20 (twenty object classes).
Step 3 — the "responsible" cell. If the center of an object falls inside a grid cell, that cell is responsible for detecting it. This is the critical rule: an object belongs to exactly one cell, regardless of how big it is. A dog whose center lands in cell (3, 4) is detected by cell (3, 4), even if its body sprawls across many cells.
So the model's prediction is entirely spatially organized. The output isn't just a flat list of boxes — it's a 3D tensor laid out spatially, where each (row, column) slot predicts what's centered in the corresponding region of the image.
Check your understanding
A large truck's body covers 12 of the 7×7 grid cells, but its center sits in cell (2, 5). How many cells are "responsible" for detecting it, and why does this matter?
Exactly one — cell (2, 5), the cell containing the object's center, regardless of how many cells the body spans. This "one object → one cell" rule is what makes YOLO's output a clean, spatially-organized tensor (each cell predicts what's centered there) and what largely prevents duplicate detections. It's also the source of a limitation we'll hit later: if two object centers fall in the same cell, the cell can struggle to report both.
The architecture
The network has 24 convolutional layers followed by 2 fully-connected layers (the paper's Figure 3). It's inspired by GoogLeNet but simpler — instead of inception modules, YOLO alternates 1×1 reduction layers with 3×3 convolutions.
What's a 1×1 convolution? It mixes channels at a single spatial position without looking at neighbors, which makes it a cheap tool for dimensionality reduction. If a feature map has 512 channels and you want to cut it to 256 before an expensive 3×3 conv, a 1×1 conv is the cheapest way to do it. YOLO did not invent the 1×1 → 3×3 pattern (the paper credits Network in Network, and GoogLeNet's inception modules use 1×1 reductions too), but it is now a standard building block.
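To put a number on "cheapest", here is a rough multiply-accumulate count for the 512 → 256 example above. This is a back-of-the-envelope sketch: the 28×28 feature-map size and the choice of 512 output channels are assumptions picked for illustration, and bias terms are ignored.

```python
def conv_multiply_adds(h, w, c_in, c_out, k):
    """Multiply-accumulate count for a k x k conv, stride 1, 'same' padding."""
    return h * w * c_in * c_out * k * k


H = W = 28  # hypothetical feature-map size, chosen only for illustration

# Direct 3x3 conv: 512 channels in, 512 channels out.
direct = conv_multiply_adds(H, W, 512, 512, 3)

# Bottleneck: a 1x1 conv reduces 512 -> 256, then the 3x3 conv maps 256 -> 512.
reduce_1x1 = conv_multiply_adds(H, W, 512, 256, 1)
conv_3x3 = conv_multiply_adds(H, W, 256, 512, 3)
bottleneck = reduce_1x1 + conv_3x3

# Fraction of the direct cost that the bottleneck saves.
savings = 1 - bottleneck / direct
```

Under these assumptions the bottleneck costs 5/9 of the direct convolution, so `savings` comes out to roughly 44%, and the 1×1 layer itself is only a small slice of the total.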
Reading the shape progression off the paper's Figure 3:
| Stage | Layer | Filter / config | Stride | Output shape |
|---|---|---|---|---|
| Input | — | — | — | 448 × 448 × 3 |
| 1 | Conv | 7×7×64 | 2 | 224 × 224 × 64 |
| 1 | Maxpool | 2×2 | 2 | 112 × 112 × 64 |
| 2 | Conv | 3×3×192 | 1 | 112 × 112 × 192 |
| 2 | Maxpool | 2×2 | 2 | 56 × 56 × 192 |
| 3 | Conv | 1×1×128, 3×3×256, 1×1×256, 3×3×512 | 1 | 56 × 56 × 512 |
| 3 | Maxpool | 2×2 | 2 | 28 × 28 × 512 |
| 4 | Conv | (1×1×256, 3×3×512) ×4, then 1×1×512, 3×3×1024 | 1 | 28 × 28 × 1024 |
| 4 | Maxpool | 2×2 | 2 | 14 × 14 × 1024 |
| 5 | Conv | (1×1×512, 3×3×1024) ×2, then 3×3×1024, 3×3×1024 | 1 (last conv: 2) | 7 × 7 × 1024 |
| 6 | Conv | 3×3×1024, 3×3×1024 | 1 | 7 × 7 × 1024 |
| 7 | Fully connected | — | — | 4096 |
| 8 | Fully connected | — | — | 7 × 7 × 30 |
The final output is reshaped to a 7 × 7 × 30 tensor — that's the entire detection prediction for the image. Where does 30 come from? For VOC, S = 7, B = 2, C = 20, and each cell predicts: B = 2 boxes × 5 numbers each (x, y, w, h, confidence) = 10, plus C = 20 class probabilities, for a total of 30 channels per cell. The 7 × 7 spatial layout corresponds directly to the 7×7 grid over the image, so cell (i, j) of the output describes grid cell (i, j) of the input.
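The arithmetic behind the table and the 30-channel output can be checked in a few lines. This is a minimal sketch: the order of values inside a cell's vector (boxes first, then classes) is an assumed layout for illustration, not necessarily the memory order used by the paper's implementation.

```python
# Spatial trace: each stride-2 layer in the table halves the feature map.
# Six halvings: the 7x7 conv, four maxpools, and the final stride-2 conv.
size = 448
trace = [size]
for stride in [2, 2, 2, 2, 2, 2]:
    size //= stride
    trace.append(size)

# Channel count per cell and total output size.
S, B, C = 7, 2, 20
values_per_cell = B * 5 + C               # 2 boxes * 5 numbers + 20 classes
total_outputs = S * S * values_per_cell   # length of the final FC layer's output


def split_cell(cell_vector, B=2, C=20):
    """Split one cell's flat prediction into B boxes and one class vector.

    Assumed layout: [x, y, w, h, conf] repeated B times, then C class probs.
    """
    assert len(cell_vector) == B * 5 + C
    boxes = [tuple(cell_vector[b * 5:(b + 1) * 5]) for b in range(B)]
    class_probs = list(cell_vector[B * 5:])
    return boxes, class_probs


# Dummy cell vector 0.0, 1.0, ..., 29.0 so the slicing is easy to eyeball.
example = [float(i) for i in range(values_per_cell)]
boxes, class_probs = split_cell(example)
```

The trace ends at 7, matching the 7×7 grid, and the overall 64× reduction is the same 64 that gives each grid cell its 64×64-pixel footprint.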
The network's convolutional layers are pretrained on ImageNet classification first. (ImageNet is a massive set of over 14 million images across more than 20,000 categories; the paper pretrains on its 1000-class competition subset.) Then those pretrained conv layers are transferred to detection. In the original paper, the authors take the first 20 conv layers, append an average-pooling layer and a fully-connected layer, train on ImageNet at 224×224 for about a week, and hit 88% top-5 accuracy (comparable to GoogLeNet). Then they convert to detection by adding 4 more conv layers and 2 fully-connected layers (randomly initialized) and doubling the input resolution to 448×448, since detection needs finer detail than classification. This pretrain-then-fine-tune approach is now standard for nearly every detection model.
(Side note — Fast YOLO uses the same training setup but only 9 conv layers instead of 24, with fewer filters. It runs at 155 FPS on a Titan X GPU, trading some accuracy for speed while staying more than 2× as accurate as any prior real-time detector.)
Activation: Leaky ReLU
Every layer except the final one uses Leaky ReLU:
$$\phi(x) = \begin{cases} x, & \text{if } x > 0 \\ 0.1x, & \text{otherwise} \end{cases}$$
Standard ReLU outputs zero for any negative input, which can cause dead neurons: neurons stuck outputting zero that never recover, because zero output means zero gradient. Leaky ReLU keeps a small slope (0.1) for negative inputs, so a little gradient still flows even when the neuron isn't firing. The final layer uses a linear activation because its outputs (box coordinates, confidences, and class probabilities) are regression targets trained with squared error, so the network emits raw values rather than passing them through a clamping nonlinearity.
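A tiny scalar sketch makes the difference concrete. The function names here are illustrative, and the gradient helper is written by hand rather than taken from any framework.

```python
def relu(x):
    """Plain ReLU: negatives are clamped to zero."""
    return max(0.0, x)


def leaky_relu(x, negative_slope=0.1):
    """Leaky ReLU with the 0.1 negative slope described above."""
    return x if x > 0 else negative_slope * x


def leaky_relu_grad(x, negative_slope=0.1):
    """Derivative of leaky_relu with respect to x (taking the slope at x <= 0)."""
    return 1.0 if x > 0 else negative_slope


def relu_grad(x):
    """Derivative of relu: zero for non-positive inputs, which is the dead-neuron risk."""
    return 1.0 if x > 0 else 0.0
```

For a negative input such as -3.0, `relu_grad` returns 0.0 while `leaky_relu_grad` returns 0.1, so the leaky version still passes a small gradient back.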
Check your understanding
Why does YOLO use Leaky ReLU in its hidden layers but a linear activation in the final layer?
Leaky ReLU's small negative slope (0.1) keeps a trickle of gradient flowing through neurons that would otherwise be stuck outputting zero (the "dead neuron" problem with plain ReLU), helping the deep stack keep learning. The final layer is linear because its outputs (box coordinates, confidences, and class probabilities) are regression targets trained with squared error; a ReLU-style nonlinearity would reshape those raw values and distort the regression.
Target outputs: what each box actually predicts
To reiterate, YOLO produces bounding boxes with coordinates x, y, w, h. Each of the B = 2 boxes per cell is described by four spatial numbers plus a confidence:
x, y — the box center, relative to the grid cell, normalized to [0, 1]. So x = 0.5, y = 0.5 puts the center smack in the middle of the cell; x = 0 is the left edge, y = 1 the bottom edge. By construction, the box center cannot leave its responsible cell.
w, h — width and height, relative to the whole image, normalized to [0, 1]. So w = 0.5 means the box spans half the image's width. Note these are not relative to the cell — they're relative to the entire image, because objects can be far larger than one cell.
This is a careful, deliberate design. It means (x, y) are tightly bound to [0, 1] and the network just learns small offsets from a known cell center, while (w, h) can express any object size from tiny to image-spanning.
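Here is a minimal sketch of that encoding, going from a pixel-space box to YOLO's normalized targets and back. The function names and the (row, col) indexing convention are assumptions for illustration; the 448-pixel image and S = 7 come from the text above.

```python
def encode_box(cx, cy, bw, bh, image_size=448, S=7):
    """Pixel-space box (center cx, cy; width bw; height bh) -> YOLO targets.

    Returns (row, col, x, y, w, h): the responsible cell, the center offset
    inside that cell in [0, 1], and the size relative to the whole image.
    """
    cell_size = image_size / S
    # Clamp so a center exactly on the far edge still lands in the last cell.
    col = min(int(cx // cell_size), S - 1)
    row = min(int(cy // cell_size), S - 1)
    x = cx / cell_size - col
    y = cy / cell_size - row
    w = bw / image_size
    h = bh / image_size
    return row, col, x, y, w, h


def decode_box(row, col, x, y, w, h, image_size=448, S=7):
    """Invert encode_box: recover the pixel-space (cx, cy, bw, bh)."""
    cell_size = image_size / S
    return ((col + x) * cell_size, (row + y) * cell_size,
            w * image_size, h * image_size)


# A box centered at pixel (300, 200), 224 wide and 112 tall.
encoded = encode_box(cx=300, cy=200, bw=224, bh=112)
decoded = decode_box(*encoded)
```

The example center lands in cell (3, 4) under the (row, col) convention, with offsets (0.6875, 0.125) inside the cell, while the width and height become 0.5 and 0.25 of the image.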
The fifth number per box is the confidence C, which the paper defines as:
$$C = \Pr(\text{Object}) \times \text{IOU}^{\text{truth}}_{\text{pred}}$$
In words: confidence is the probability that an object exists in this box, multiplied by how good the box is — measured as the IOU (Intersection over Union) between the predicted box and the ground-truth box. So:
- No object → $\Pr(\text{Object}) = 0$ → confidence should be 0.
- Object present and the box matches perfectly → IOU = 1 → confidence should be 1.
- Object present but the box is sloppy → low IOU → low confidence.
A single number thus encodes both existence and accuracy: a high-confidence box is more likely both to contain an object and to localize it well. Let's pin down IOU, since it shows up everywhere in detection:
$$\text{IOU}(A, B) = \frac{\text{area}(A \cap B)}{\text{area}(A \cup B)}$$
IOU = 1 means the boxes are identical; IOU = 0 means they don't overlap at all; IOU > 0.5 usually means "significant overlap, probably the same object."
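IOU is short enough to write out by hand. The sketch below assumes axis-aligned boxes given as corner coordinates `(x1, y1, x2, y2)`, which is one common convention; YOLO's own (center, size) format would need converting first.

```python
def iou(box_a, box_b):
    """Intersection over Union for boxes given as (x1, y1, x2, y2) corners."""
    # Corners of the overlap rectangle (may be empty).
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    # Clamp at zero so non-overlapping boxes give zero intersection.
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


identical = iou((0, 0, 10, 10), (0, 0, 10, 10))
disjoint = iou((0, 0, 10, 10), (20, 20, 30, 30))
half_shift = iou((0, 0, 10, 10), (5, 0, 15, 10))  # overlap 50, union 150
```

Note how quickly IOU drops: shifting a box by half its width already brings it down to one third.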
Class probabilities
The remaining C = 20 numbers in each cell are conditional class probabilities:
$$\Pr(\text{Class}_i \mid \text{Object})$$
The "conditioned on object" part is the key: these answer "given that an object is in this cell, what class is it?" They don't need to be zero when there's no object — they're only meaningful when an object is present. And importantly, YOLO predicts one class-probability vector per cell, regardless of how many boxes B that cell predicts. Both of a cell's boxes share the same class vector — which leads to a limitation we'll discuss shortly.
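At inference, the network hands you confidence and class probabilities separately. To get the final class-specific score for a box, just multiply:
$$\Pr(\text{Class}_i \mid \text{Object}) \times \Pr(\text{Object}) \times \text{IOU}^{\text{truth}}_{\text{pred}} = \Pr(\text{Class}_i) \times \text{IOU}^{\text{truth}}_{\text{pred}}$$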
This single score per (box, class) pair combines the probability of that class, and how well the box fits. Now you have, for every box in every cell, a confidence score for every class — you filter out the low-confidence ones, apply NMS, and you're done.
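As a sketch of that multiply-then-filter step for a single cell: the numbers, the three-class setup (VOC has 20), and the 0.2 threshold below are all made up for illustration.

```python
def class_scores(box_confidence, class_probs):
    """Class-specific scores for one box: confidence times each class prob."""
    return [box_confidence * p for p in class_probs]


# One hypothetical cell: two boxes share a single class-probability vector.
class_probs = [0.7, 0.2, 0.1]
box_confidences = [0.9, 0.05]

scores = [class_scores(conf, class_probs) for conf in box_confidences]

# Keep (box index, class index, score) triples above a hypothetical threshold.
threshold = 0.2
kept = [
    (b, k, s)
    for b, row in enumerate(scores)
    for k, s in enumerate(row)
    if s >= threshold
]
```

Only the first box's top class survives here; the second box has such low confidence that every one of its class scores falls below the threshold.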
For a 7 × 7 grid with B = 2 boxes per cell, the network outputs 7 × 7 × 2 = 98 bounding boxes per image. Compare that to R-CNN's ~2,000 region proposals! Far fewer candidates, all generated in parallel by one network. Most of those 98 will have very low confidence (no object in that cell) and get filtered out at the thresholding step. Cool, right?
Check your understanding
YOLO outputs 98 boxes per image versus R-CNN's ~2,000 proposals. Where does 98 come from, and why are most of them discarded?
$7 \times 7 \times 2 = 98$: the 7×7 grid with B = 2 boxes per cell. Most cells don't contain an object center, so their boxes get a near-zero confidence ($\Pr(\text{Object}) \approx 0$) and are removed at the confidence-thresholding step (then NMS cleans up the rest). The point is that all 98 are produced in a single parallel forward pass, unlike R-CNN's ~2,000 separately-processed proposals, and that's the speed win.
Non-Maximum Suppression (NMS)
The grid design already prevents most duplicate detections — most objects fall cleanly into one cell. But some objects, especially large ones or those near cell borders, get detected by multiple cells. NMS cleans these up. The algorithm, run for each class separately:
1. Take all boxes predicted for that class with class-specific confidence above some threshold.
2. Sort them by confidence, highest first.
3. Take the top box and add it to the final output.
4. Compute IOU between this top box and all remaining boxes.
5. Discard any box whose IOU with the top box exceeds a threshold (e.g. 0.5) — those are duplicates of the same object.
6. Repeat from step 3 with the next-highest remaining box, until none are left.
Keep the most confident box, suppress its high-overlap neighbors, repeat - duplicates gone.
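The six steps above translate almost line for line into code. This is a plain-Python sketch for one class, assuming boxes in `(x1, y1, x2, y2)` corner format and detections given as `(score, box)` pairs; production code would normally use a vectorized library routine instead.

```python
def iou(box_a, box_b):
    """Intersection over Union for (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(detections, iou_threshold=0.5, score_threshold=0.0):
    """Non-Maximum Suppression for ONE class.

    detections: list of (score, box). Returns the kept (score, box) pairs,
    highest score first.
    """
    # Steps 1-2: threshold by score, then sort with the highest score first.
    remaining = sorted(
        (d for d in detections if d[0] > score_threshold),
        key=lambda d: d[0],
        reverse=True,
    )
    kept = []
    while remaining:
        # Step 3: the best remaining box goes to the output.
        best = remaining.pop(0)
        kept.append(best)
        # Steps 4-5: drop every box that overlaps the best one too much.
        remaining = [d for d in remaining if iou(best[1], d[1]) <= iou_threshold]
        # Step 6: loop again with whatever is left.
    return kept


detections = [
    (0.9, (0, 0, 10, 10)),
    (0.8, (1, 1, 11, 11)),    # near-duplicate of the first box (IOU about 0.68)
    (0.7, (20, 20, 30, 30)),  # a separate object
]
kept = nms(detections)
```

The 0.8 box is suppressed as a duplicate of the 0.9 box, while the far-away 0.7 box survives because its IOU with the winner is zero.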
How important is NMS to YOLO? The paper makes a subtle point: NMS adds only 2–3% mAP to YOLO — far less than it adds to R-CNN or DPM. Why? Because YOLO's grid already enforces spatial diversity (each cell can only predict objects centered there), so most duplicates never arise in the first place. R-CNN's region proposals can overlap freely, so it leans heavily on NMS. YOLO's structure has NMS partially built in via the grid; NMS just polishes the edges.
Check your understanding
NMS dramatically helps R-CNN but only adds 2–3% mAP to YOLO. Why the difference?
Because YOLO's grid already prevents most duplicates: an object is assigned to the single cell holding its center, so the architecture enforces spatial diversity by design. R-CNN's ~2,000 region proposals can overlap freely and pile multiple boxes on the same object, so it depends on NMS to clean up the mess. In YOLO, NMS only mops up the few duplicates from large or border-straddling objects — the grid did most of the deduplication for free.
The loss function: how YOLO learns
The total loss is $\mathcal{L} = \mathcal{L}_{\text{cls}} + \mathcal{L}_{\text{loc}}$ (a classification part and a localization part), and the trick that makes the whole thing trainable is that everything is squared error, which turns detection into a regression problem. The full loss has five terms:
- Term 1 — center-coordinate loss
- Squared error on the (x, y) predictions. Counted only for the box responsible for an object (denoted $\mathbb{1}_{ij}^{\text{obj}}$). Weighted by $\lambda_{\text{coord}} = 5$.
- Term 2 — size loss
- Squared error on $\sqrt{w}$ and $\sqrt{h}$ (note the square roots). Same condition and weighting as Term 1. The square root is the fix for small-box sensitivity; more on that in a second.
- Term 3 — confidence loss for object cells
- The predicted confidence should match the IOU of the predicted box with the ground truth. Squared error, weight 1.
- Term 4 — confidence loss for no-object cells
- Confidence should be 0 here. Weighted down by $\lambda_{\text{noobj}} = 0.5$ so the many empty cells don't dominate the loss.
- Term 5 — classification loss
- Sum of squared errors over the class probabilities, counted only for cells that contain an object ($\mathbb{1}_{i}^{\text{obj}}$: per cell, not per box, because classification is per-cell).
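Written out in one piece, the five terms add up as follows. This is a reconstruction in the YOLO paper's notation, with indices written here as $i = 1..S^2$ over cells and $j = 1..B$ over box predictors, $\lambda_{\text{coord}} = 5$ and $\lambda_{\text{noobj}} = 0.5$. Hats mark one side of each prediction/target pair; because every term is a squared difference, which side carries the hat does not change the value.

```latex
\begin{aligned}
\mathcal{L} ={}& \lambda_{\text{coord}} \sum_{i=1}^{S^2} \sum_{j=1}^{B}
      \mathbb{1}_{ij}^{\text{obj}}
      \left[ (x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 \right] \\
 &+ \lambda_{\text{coord}} \sum_{i=1}^{S^2} \sum_{j=1}^{B}
      \mathbb{1}_{ij}^{\text{obj}}
      \left[ \left(\sqrt{w_i} - \sqrt{\hat{w}_i}\right)^2
           + \left(\sqrt{h_i} - \sqrt{\hat{h}_i}\right)^2 \right] \\
 &+ \sum_{i=1}^{S^2} \sum_{j=1}^{B}
      \mathbb{1}_{ij}^{\text{obj}} \left( C_i - \hat{C}_i \right)^2 \\
 &+ \lambda_{\text{noobj}} \sum_{i=1}^{S^2} \sum_{j=1}^{B}
      \mathbb{1}_{ij}^{\text{noobj}} \left( C_i - \hat{C}_i \right)^2 \\
 &+ \sum_{i=1}^{S^2} \mathbb{1}_{i}^{\text{obj}}
      \sum_{c \in \text{classes}} \left( p_i(c) - \hat{p}_i(c) \right)^2
\end{aligned}
```

Here $C_i$ is the box confidence (not the class count), $p_i(c)$ is the conditional probability of class $c$ in cell $i$, and $\mathbb{1}_{ij}^{\text{noobj}}$ is 1 exactly when predictor $j$ in cell $i$ is not responsible for any object. The five lines correspond to Terms 1 through 5 in order.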
A couple of those weights deserve a word. $\lambda_{\text{coord}} = 5$ up-weights localization so getting boxes right matters more than the raw class terms. $\lambda_{\text{noobj}} = 0.5$ down-weights the flood of empty cells so they don't drown out the signal from the few cells that actually contain objects. And the $\sqrt{w}, \sqrt{h}$ trick: a small absolute error in a small box hurts IOU far more than the same error in a large box. Because the square root is steeper near zero, the same absolute error produces a larger squared difference for a small box than for a large one, which nudges the penalty toward relative rather than absolute error.
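A quick numeric check shows what the square root buys. The widths below are hypothetical fractions of image width, chosen so that a small and a large box suffer the same absolute error.

```python
import math


def size_penalty(true_w, pred_w, use_sqrt):
    """Squared error on width, with or without the square-root transform."""
    if use_sqrt:
        return (math.sqrt(true_w) - math.sqrt(pred_w)) ** 2
    return (true_w - pred_w) ** 2


error = 0.05                        # same absolute width error for both boxes
small_true, large_true = 0.10, 0.80

# Plain squared error cannot tell the two cases apart.
plain_small = size_penalty(small_true, small_true + error, use_sqrt=False)
plain_large = size_penalty(large_true, large_true + error, use_sqrt=False)

# With square roots, the small box is penalized several times more.
sqrt_small = size_penalty(small_true, small_true + error, use_sqrt=True)
sqrt_large = size_penalty(large_true, large_true + error, use_sqrt=True)
```

With these numbers the plain penalties are equal, while the square-root penalty for the small box is more than five times the one for the large box. That is the direction we want, although, as the paper itself admits, it only partially addresses the mismatch.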
What "responsible" means
Each cell predicts B = 2 boxes, but at training time we want exactly one of them responsible for any given object. The rule (paper section 2.2): assign the predictor whose box currently has the highest IOU with the ground truth. So if cell (3, 4) holds a dog, both its box predictors make predictions, and whichever has the higher IOU with the true dog box is declared responsible. Only that responsible predictor incurs the coordinate and box-confidence losses for the object; the other gets a no-object signal.
This produces specialization: the two predictors in each cell learn to handle different kinds of boxes — one might gravitate toward tall, narrow boxes (people), the other toward wide, short ones (cars). The authors note this "improves overall recall," since different shapes get handled by different predictors.
Training hyperparameters, from the paper, for completeness: 135 epochs on PASCAL VOC 2007 + 2012; batch size 64, momentum 0.9, weight decay 0.0005; a learning-rate schedule that warms up from $10^{-3}$ to $10^{-2}$ over the first epochs (jumping straight to $10^{-2}$ would make training diverge), holds $10^{-2}$ for 75 epochs, drops to $10^{-3}$ for 30, then $10^{-4}$ for the final 30; dropout 0.5 after the first FC layer; and data augmentation with random scaling/translation up to 20% of image size plus random exposure and saturation jitter up to 1.5× in HSV space. The warm-up-then-decay schedule is a familiar pattern in modern training, and the augmentation, modest by today's standards, is still critical to avoid overfitting VOC's smallish training set.
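The schedule is easy to express as a function of the epoch. Two things in this sketch are assumptions, since the text only says "the first epochs": the warm-up length (`warmup_epochs=5` is a placeholder) and the linear shape of the ramp. The 75/30/30 phases and the rates from 1e-3 up to 1e-2 and back down to 1e-4 are the ones the paper states.

```python
def learning_rate(epoch, warmup_epochs=5):
    """Warm-up-then-step-decay schedule in the style described above.

    warmup_epochs and the linear ramp are hypothetical choices; the
    75/30/30 phases are counted after the warm-up ends.
    """
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return 1e-3 + frac * (1e-2 - 1e-3)   # ramp from 1e-3 toward 1e-2
    e = epoch - warmup_epochs
    if e < 75:
        return 1e-2                          # main phase
    if e < 75 + 30:
        return 1e-3                          # first decay
    return 1e-4                              # final phase


schedule = [learning_rate(epoch) for epoch in range(5 + 135)]
```

Plotting `schedule` would show a short ramp, a long plateau, and two step drops, which is the same overall shape used in much later training recipes.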
Check your understanding
Why does YOLO's loss use $\sqrt{w}$ and $\sqrt{h}$ instead of $w$ and $h$, and why is $\lambda_{\text{noobj}} = 0.5$?
The square root makes the size penalty fairer across object scales: a fixed pixel error in a small box damages IOU much more than the same error in a large box, and taking the square root compresses big values so equal relative errors are weighted more equally. $\lambda_{\text{noobj}} = 0.5$ down-weights the confidence loss from the many empty cells: most of the 49 cells contain no object, so without this their "confidence should be 0" signal would swamp the gradient from the few cells that actually matter.
YOLO's limitations
The paper is refreshingly honest about its model's weaknesses (section 2.4). Worth knowing, because every later YOLO version is largely about fixing these.
Strong spatial constraints. Each cell predicts only B = 2 boxes and only one class. So a cell containing two objects of different classes is in trouble — it can predict two boxes but only one class vector, so a person standing next to a bird in the same cell forces a choice. And groups of small objects (a flock of birds, many centers in one cell) simply can't all be predicted. This is structural; v2 onward loosened it dramatically with anchor boxes and higher S.
Poor generalization to unusual aspect ratios. YOLO learns box shapes from data, so it struggles with objects in configurations or aspect ratios it didn't see in training. There's no principled handling of out-of-distribution box shapes — it just relies on having seen enough examples.
Coarse features. The 448×448 input is downsampled hard before prediction, so each output position corresponds to a 64×64 image patch (448/7). Small objects may not have enough features left at that resolution to be detected well.
The loss doesn't match the goal. Sum-squared error treats localization and classification errors as comparable (even the $\lambda_{\text{coord}}$ trick only partly fixes this), and treats errors in large and small boxes similarly (the square-root trick helps but doesn't fully fix it). The actual goal is mean Average Precision (mAP), a complex ranking metric; squared error is a convenient proxy that usually does the right thing but not always.
Check your understanding
Why does the original YOLO struggle with a flock of small birds clustered together?
Two structural reasons. First, each grid cell predicts at most B = 2 boxes and only one shared class vector, so if many bird centers fall into the same cell, the cell literally cannot output a box for all of them. Second, YOLO's coarse features mean each cell corresponds to a 64×64 image patch, so small objects have very little feature detail by the time predictions are made. Later versions raise the grid resolution S and add anchor boxes to ease both problems.
Segmentation Models
Detection draws a box around each object. Segmentation goes finer: it figures out what each pixel belongs to. The output of a segmentation model is a mask the same spatial size as the input — for a 512×512 input, that's a 512×512 grid of labels, 262,144 predictions per image. Three flavors are worth knowing:
- Semantic segmentation: every pixel gets a class label, but no instance distinction. Three dogs all get labeled "dog" and merge into one blob of dog-pixels.
- Instance segmentation: every pixel gets a class label and an instance ID. Three dogs become three separate masks: "dog #1," "dog #2," "dog #3." Great for counting.
- Panoptic segmentation: the two combined. "Things" (countable: cars, people, dogs) get instance IDs; "stuff" (uncountable: sky, road, grass) gets only a class label. The most complete description of an image.
Why segmentation is architecturally harder
There's a tension here that classification simply doesn't have. A classification CNN aggressively downsamples — from 224×224 down to a 7×7 final feature map, a 32× reduction. That downsampling is good for classification: the deepest features have huge receptive fields and capture the whole object, and you don't need spatial precision because you only output one label.
Segmentation needs both deep semantics and full resolution at once. You need the deepest features to know what you're looking at, but you also need to output a prediction at every original pixel. A 7×7 mask for a 224×224 input is useless — that's one label per 32×32 patch. The classical CNN treats spatial resolution and semantic depth as a tradeoff: you get one or the other, not both. To segment, you need a way to recover the lost resolution while keeping the deep semantic information. Every modern segmentation architecture exists to solve exactly this.
The Fully Convolutional Network (FCN)
The first deep approach to segmentation was the Fully Convolutional Network. The idea: take a classification CNN and turn it into a segmentation model by removing the fully-connected layers and replacing them with convolutional ones that output a per-pixel prediction.
The key realization (from the FCN paper) is that a fully-connected layer is mathematically the same as a convolution with a kernel the size of its input — same weights, same arithmetic. So you can "convolutionalize" the classifier. Now its output goes from one vector per image to a grid of class scores — a small spatial map. Run a 224×224 image through a converted VGG-16 and you get roughly a 7×7 map where each spatial cell holds 21 class scores (for PASCAL VOC's 21 classes).
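The "same weights, same arithmetic" claim can be checked on a toy example. In this sketch the 2×2×2 input, the integer weights, and the two output units are all hypothetical, picked so the results compare exactly; real layers would also carry a bias term, which is omitted here.

```python
def fully_connected(flat_input, weights):
    """weights[o] holds one weight per input value; one score per output unit."""
    return [sum(w * v for w, v in zip(row, flat_input)) for row in weights]


def conv2d_valid(image, kernels):
    """Plain convolution, stride 1, no padding.

    image[r][c][ch] is the input; kernels[o][dr][dc][ch] is unit o's kernel.
    """
    k = len(kernels[0])
    out = []
    for r in range(len(image) - k + 1):
        out_row = []
        for c in range(len(image[0]) - k + 1):
            scores = []
            for kern in kernels:
                total = 0
                for dr in range(k):
                    for dc in range(k):
                        for ch, weight in enumerate(kern[dr][dc]):
                            total += weight * image[r + dr][c + dc][ch]
                scores.append(total)
            out_row.append(scores)
        out.append(out_row)
    return out


def flatten(nested):
    """Flatten [row][col][channel] lists in row-major order."""
    return [v for row in nested for cell in row for v in cell]


# Two 2x2 kernels over 2 channels; the FC layer reuses them, flattened.
kernels = [
    [[[1, 0], [0, 1]], [[2, -1], [1, 1]]],
    [[[0, 1], [1, 0]], [[-1, 2], [3, 0]]],
]
fc_weights = [flatten(kern) for kern in kernels]

# On an input the size of the kernel, conv and FC give the same two scores.
small = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
fc_out = fully_connected(flatten(small), fc_weights)
conv_small = conv2d_valid(small, kernels)        # a 1 x 1 grid of scores

# On a larger input, the same kernels yield a spatial grid of scores.
large = [[[r * 3 + c, 1] for c in range(3)] for r in range(3)]
conv_large = conv2d_valid(large, kernels)        # a 2 x 2 grid of scores
top_left = [row[0:2] for row in large[0:2]]
```

On the kernel-sized input the convolution produces a single cell equal to the FC output; on the 3×3 input it produces a 2×2 grid in which each cell is what the FC layer would have computed on that window. That grid is the small score map the paragraph above describes.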
But that's still a 7×7 grid, and you need 224×224 for a per-pixel mask. The fix is upsampling: blow the small score grid back up to the input size with a learnable upsampling operation (originally transposed convolution, which we'll detail soon). The full FCN pipeline: run the input through a CNN backbone to get a 7×7×21 score grid, upsample 32× back to 224×224×21, then take the argmax along the class dimension to get a 224×224 mask.
This worked — it was the first end-to-end neural network for semantic segmentation, and it crushed the hand-engineered approaches. But it had a problem: upsampling from 7×7 straight to 224×224 produces blurry, low-detail masks. The model knows what's in the image but loses precision about where the boundaries are. The reason is fundamental — by the time you reach the 7×7 map, the spatial detail is gone. You can upsample the resolution arithmetically, but you can't conjure back information that was thrown away. You'd need to combine the deep features (which know what) with shallow features (which still hold the spatial where). FCN partially patched this with skip connections from earlier layers (FCN-16s, FCN-8s) — small but real gains — but the architecture wasn't designed around the idea; it was bolted on. The model that was designed around it: U-Net.
Check your understanding
FCN can upsample its 7×7 score grid back to 224×224, so why are its masks still blurry?
Because upsampling restores resolution (number of pixels) but not information. By the time the network reaches the 7×7 feature map, the precise spatial detail — exactly where edges and boundaries sit — has been discarded through downsampling. Arithmetic upsampling can't invent that lost detail back. To get sharp masks you have to combine the deep, low-resolution "what" features with shallow, high-resolution "where" features — which is precisely what U-Net's skip connections are built to do.
U-Net
U-Net (Ronneberger, Fischer, Brox, 2015) came out the same year as FCN, originally for medical image segmentation, and it's now the most influential segmentation architecture there is — the foundation of countless models, including modern diffusion models. The reason it took over: it solved the resolution-versus-semantics problem with a clean, symmetric design.
As the name says, the architecture is shaped like a U. The left side is the encoder, the bottom is the deepest, most compressed point, the right side is the decoder, and arcing across the U are the skip connections — the special ingredient.
- Left side — Encoder
- A standard CNN: conv blocks alternating with downsampling. Spatial size halves at each step (224 → 112 → 56 → 28 → 14); channel count doubles (64 → 128 → 256 → 512 → 1024). This builds up semantic understanding.
- Bottom — the deepest point
- The most abstract representation: smallest spatial size, most channels. Here the network "knows" what's in the image.
- Right side — Decoder
- A mirror of the encoder: conv blocks alternating with upsampling. Spatial size doubles each step (14 → 28 → 56 → 112 → 224); channel count halves. This recovers spatial resolution.
- Skip connections
- At each scale level, features from the corresponding encoder layer are concatenated with the upsampled features in the decoder. This is what makes U-Net special.
Why do the skip connections matter so much? Walk through the problem. The encoder loses spatial detail as it downsamples: each pooling step throws away exact pixel positions. By the bottom of the U, you have rich semantics ("there's a dog here") but no precise edges ("the outline is at exactly these pixels"). The decoder upsamples back to full resolution but, working only from the coarse bottom, lacks that original detail. The skip connections deliver the original spatial detail directly across: when the decoder is at the 56×56 stage, it receives the encoder's 56×56 features, which never went through the downsample-and-upsample cycle and still know exactly where the edges are. Mathematically, at each decoder level $\ell$:
$$d_\ell = \text{Conv}\big(\text{Concat}\big(\text{Upsample}(d_{\ell+1}),\ e_\ell\big)\big)$$
where $e_\ell$ is the encoder's feature map at level $\ell$ and $d_{\ell+1}$ is the decoder's output from the level below.
The concatenation stacks the channels — if the upsampled features have 256 channels and the encoder features have 256, you get 512 after concatenation — and the following conv layer learns to merge the two sources: deep semantics from the upsampled path, spatial detail from the skip path. This is the architectural realization of "both at the same time": semantics flow up through the bottom and decoder, spatial detail flows across through the skips, and the decoder's conv layers figure out how to combine them.
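To see the bookkeeping, here is a shape-only trace through the U using the sizes quoted earlier (224 input, 64 base channels, four down/up steps). No real tensors are involved, and it assumes that each upsampling step halves the channel count and that the conv after concatenation brings the count back down to that halved value.

```python
def unet_shapes(input_size=224, base_channels=64, depth=4):
    """Trace (spatial size, channels) through a U-Net-style encoder/decoder.

    Returns (encoder, bottom, decoder) where decoder entries are
    (size, channels after concat, channels after the merging convs).
    """
    encoder = []
    size, ch = input_size, base_channels
    for _ in range(depth):
        encoder.append((size, ch))      # saved for the skip connection
        size, ch = size // 2, ch * 2    # downsample: halve size, double channels
    bottom = (size, ch)

    decoder = []
    for skip_size, skip_ch in reversed(encoder):
        size, ch = size * 2, ch // 2    # upsample: double size, halve channels
        assert size == skip_size        # skip and upsampled maps must line up
        concat_ch = ch + skip_ch        # concatenation stacks channels
        decoder.append((size, concat_ch, ch))
    return encoder, bottom, decoder


encoder, bottom, decoder = unet_shapes()
```

At the 56×56 decoder level the trace gives 256 upsampled channels plus 256 skip channels, so 512 after concatenation, matching the example in the paragraph above.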
Each U-Net block, broken down
Encoder block: Conv 3×3 → BatchNorm → ReLU, then Conv 3×3 → BatchNorm → ReLU, then MaxPool 2×2 (downsample to the next level).
Decoder block: Upsample 2× (transposed conv, or interpolation + conv), then Concatenate with the encoder skip features, then Conv 3×3 → BatchNorm → ReLU, then Conv 3×3 → BatchNorm → ReLU.
(The original U-Net used Conv → ReLU without BatchNorm, since BN was published the same year. Modern U-Nets usually include BatchNorm or GroupNorm.)
Upsampling: how to double the spatial size
The decoder needs to double spatial resolution at each step. Two main ways:
Transposed convolution (also called "deconv"). A learned upsampling; mathematically it's the gradient operation of a strided convolution. Each input pixel "spreads out" into a small patch of the output, with learnable weights deciding the spread pattern. Pro: learnable, adapts to data. Con: can produce checkerboard artifacts when the kernel size is not divisible by the stride, because neighboring output pixels then receive different numbers of overlapping contributions.
Interpolation + convolution. Use a fixed upsampling (nearest-neighbor or bilinear) to double the size, then apply a regular conv to refine. Pro: no checkerboard artifacts, simpler. Con: slightly less expressive. Most modern segmentation models prefer this.
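The checkerboard effect shows up even in one dimension. The sketch below feeds a perfectly flat signal through a hand-written 1D transposed convolution with all-ones kernels (an illustrative choice, not learned weights) and compares it with nearest-neighbor upsampling.

```python
def transposed_conv1d(signal, kernel, stride):
    """1D transposed convolution: each input value spreads a scaled kernel."""
    out = [0.0] * ((len(signal) - 1) * stride + len(kernel))
    for i, value in enumerate(signal):
        for j, weight in enumerate(kernel):
            out[i * stride + j] += value * weight
    return out


def nearest_upsample1d(signal, factor):
    """Fixed upsampling: repeat every value `factor` times."""
    return [value for value in signal for _ in range(factor)]


flat = [1.0] * 4

# Kernel size 3 with stride 2: neighboring patches overlap unevenly.
uneven = transposed_conv1d(flat, kernel=[1.0, 1.0, 1.0], stride=2)

# Kernel size 2 with stride 2: every output gets exactly one contribution.
even = transposed_conv1d(flat, kernel=[1.0, 1.0], stride=2)

# Nearest-neighbor upsampling keeps the flat signal flat.
smooth = nearest_upsample1d(flat, 2)
```

The flat input comes out of the size-3, stride-2 case as an alternating 1, 2, 1, 2 pattern in the interior, which is the 1D analogue of a checkerboard. A trained network can learn weights that compensate, but the uneven overlap is baked into the operation, which is one reason interpolation followed by a regular conv is a popular alternative.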
The output layer
The decoder ends at the original input resolution with some feature channels (typically 64 at the first level). The final operation is a 1×1 convolution mapping those features to C output channels, one per class, giving an output of shape (H, W, C). A 1×1 conv just remixes channels at each pixel without combining neighbors, so it maps each pixel's feature vector to a per-pixel class-score vector (the per-pixel logit map). Apply softmax along the channel dimension to get per-pixel class probabilities:
$$p_{i,c} = \frac{e^{z_{i,c}}}{\sum_{k=1}^{C} e^{z_{i,k}}}$$
where $z_{i,c}$ is the logit for class $c$ at pixel $i$, and the denominator sums over all classes $k$. To produce the final mask, take the argmax over channels at each pixel.
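In code, the output stage is a softmax over each pixel's logits followed by an argmax. This is a nested-list sketch with made-up logits for a 1×2 "image" and three classes; subtracting the max before exponentiating is a standard numerical-stability step that does not change the result.

```python
import math


def softmax(logits):
    """Softmax over one pixel's class logits."""
    m = max(logits)                       # shift for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_mask(logit_map):
    """logit_map[r][c] is a list of C logits; returns (probabilities, mask)."""
    probs = [[softmax(cell) for cell in row] for row in logit_map]
    mask = [[max(range(len(p)), key=lambda k: p[k]) for p in row]
            for row in probs]
    return probs, mask


# Hypothetical logits for a 1 x 2 image with 3 classes.
logit_map = [[[2.0, 1.0, 0.1], [0.0, 0.0, 3.0]]]
probs, mask = predict_mask(logit_map)
```

Since softmax preserves the ordering of the logits, taking the argmax of the raw logits would give the same mask; the probabilities matter for the loss during training rather than for the final labels.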
Check your understanding
In U-Net, what concrete information do the skip connections carry across the U, and what operation merges them with the decoder's features?
They carry high-resolution spatial detail — the encoder's same-resolution features that never went through the downsample-then-upsample round trip, so they still know exactly where edges and boundaries are. They're merged by concatenation (stacking channels: upsampled + encoder features) followed by a conv layer that learns to combine the deep "what" semantics from the decoder path with the sharp "where" detail from the skip. That fusion is why U-Net masks have crisp boundaries where FCN's are blurry.
Training U-Net
Training a U-Net is just like training a classifier, but per-pixel. The loss is per-pixel cross-entropy, meaning that for each pixel we take the cross-entropy between the predicted class distribution and the true label:
$$\mathcal{L}_{\text{CE}} = -\sum_{i} \sum_{c} y_{i,c} \log p_{i,c}$$
where $i$ runs over pixels, $c$ over classes, $y_{i,c}$ is the one-hot true label, and $p_{i,c}$ is the predicted probability. For binary segmentation this becomes binary cross-entropy.
The big practical wrinkle is class imbalance. In medical imaging especially, you often have a tiny object (a tumor a few hundred pixels wide) in a huge image. A lazy model that predicts "background" everywhere scores 99% pixel accuracy and is completely useless. Common fixes:
- Weighted cross-entropy: multiply each pixel's loss by a class-dependent weight inversely proportional to class frequency, so rare classes count more.
- Dice loss: directly optimizes the Dice coefficient (closely related to IOU):
$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2 \sum_i p_i y_i}{\sum_i p_i + \sum_i y_i}$$
where $p_i$ is the predicted probability and $y_i$ the true label at pixel $i$. Because this is a ratio of overlap to combined area, the size of the background doesn't drown it out, so it intrinsically handles imbalance. Often combined: total loss = Cross-Entropy + Dice.
- Focal loss: down-weights easy pixels (where the model is already confident) and up-weights hard ones. Useful for very rare classes.
- Data augmentation: critical, especially in medical settings with small datasets. Typical choices are random rotations, flips, elastic deformations (very useful for biological images, since cell shapes vary continuously), and brightness shifts. The original U-Net paper leaned heavily on elastic deformations.
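The imbalance trap and the Dice fix are easy to reproduce with a toy binary mask. The 10,000-pixel image with a 100-pixel tumor is a made-up example, and the small `eps` in the Dice loss is a common smoothing term added here to avoid dividing by zero; it is an assumption of this sketch rather than part of the definition.

```python
def pixel_accuracy(pred, truth):
    """Fraction of pixels whose predicted label matches the true label."""
    return sum(int(p == t) for p, t in zip(pred, truth)) / len(truth)


def dice_loss(probs, truth, eps=1e-7):
    """1 - Dice coefficient for a binary mask (eps is a smoothing term)."""
    intersection = sum(p * t for p, t in zip(probs, truth))
    return 1.0 - (2.0 * intersection + eps) / (sum(probs) + sum(truth) + eps)


# Hypothetical image: 100 tumor pixels (label 1) among 10,000.
truth = [1] * 100 + [0] * 9900

lazy = [0] * 10000        # predicts "background" everywhere
perfect = list(truth)     # predicts the tumor exactly

lazy_accuracy = pixel_accuracy(lazy, truth)
lazy_dice = dice_loss(lazy, truth)
perfect_dice = dice_loss(perfect, truth)
```

The lazy model reaches 99% pixel accuracy yet gets a Dice loss of essentially 1 (the worst possible), while the perfect mask gets a Dice loss of 0. The size of the background never enters the Dice ratio, which is why it is not fooled.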
Why U-Net won
A few reasons it's still the default segmentation architecture a decade on: it works with limited data (originally trained on ~30 labeled cell images, and sample-efficient thanks to strong inductive biases: locality, hierarchy, skip connections); it's architecturally simple (conv blocks, max pools, transposed convs, skips, all easy to implement and modify); it produces sharp masks (skip connections give precise boundaries, critical for medical imaging); it's modality-agnostic (2D images, 3D volumes via 3D U-Net, medical scans, satellite imagery, microscopy); and it's the backbone of many diffusion models, since the same U-shaped design that segments cells sits inside the original Stable Diffusion. Many widely used image generators run on a U-Net, though some newer ones swap in transformer backbones. Its longevity is striking: it's older than the transformer, and still everywhere.
Check your understanding
A tumor-segmentation model reports 99% pixel accuracy but finds no tumors. What went wrong, and which loss would help?
Class imbalance fooled the accuracy metric. The tumor is a tiny fraction of pixels, so a model that predicts "background" everywhere is 99% pixel-accurate yet useless. Dice loss (or focal loss, or class-weighted cross-entropy) fixes this: Dice measures the overlap between predicted and true tumor regions as a ratio, so it isn't dominated by the vast background, and focal loss down-weights the easy background pixels so the rare tumor pixels actually drive learning.
Mask R-CNN: instance segmentation
U-Net gives you semantic segmentation — every pixel labeled by class. But how do you tell three dogs apart? For instance segmentation, the dominant architecture is Mask R-CNN (He, Gkioxari, Dollár, Girshick, 2017). It extends Faster R-CNN (a two-stage detector) with a mask-predicting branch.
The core stance: instance segmentation is detection plus segmentation. The high-level steps: (1) detect each instance and produce a bounding box around it; (2) for each detected box, produce a binary mask of that single object inside the box. This decomposition is elegant — the detector handles "different instances are different objects," and the mask head handles "which exact pixels belong to this instance." Each is a well-studied subproblem, and combining them yields instance segmentation.
So Mask R-CNN's output for each detected region has three branches: a class label (one of C classes or background), a bounding-box refinement (fine-tuned coordinates), and a binary mask (a small mask within the box, one channel per class). The first two come straight from Faster R-CNN; the third is the new addition — a small fully-convolutional network that takes the region's feature map and outputs an m × m binary mask (typically m = 28).
Mask R-CNN keeps the class and box heads and adds a parallel mask head; per-region masks keep instances separate, a semantic mask merges them.
RoI Align — the key innovation
This is Mask R-CNN's most important technical contribution. A region proposal might land at fractional coordinates like (137.3, 248.7, 282.5, 451.1). To extract a fixed-size feature map for that region (say 7×7), you have to map those real-valued coordinates onto the discrete grid of the CNN's feature map.
RoI Pool (the old way, from Fast/Faster R-CNN): round the box coordinates to integer pixels, divide the rounded box into a 7×7 grid of sub-regions, and max-pool each. This works for classification — small misalignments don't matter when you only predict one label per region. But for mask prediction every pixel matters, and those rounding errors compound, leaving the final mask misaligned by a few pixels.
RoI Align (the Mask R-CNN way): don't round. Use bilinear interpolation to sample feature values at exact real-valued coordinates within the box. The 7×7 output grid has its corners at exactly the right floating-point positions, and each output cell pools values interpolated from neighboring feature-map cells. The math, for a sample point at floating-point $(x, y)$, looks at the 4 nearest integer feature positions and takes a distance-weighted sum:
$$f(x, y) = \sum_{i \in \{0,1\}} \sum_{j \in \{0,1\}} w_{ij}\, f(x_i, y_j), \qquad x_0 = \lfloor x \rfloor,\ x_1 = x_0 + 1,\ y_0 = \lfloor y \rfloor,\ y_1 = y_0 + 1$$
where $f(x_i, y_j)$ are the feature values at the four surrounding integer positions and the terms $w_{ij} = (1 - |x - x_i|)(1 - |y - y_j|)$ are the bilinear weights (closer positions count more). That's just standard bilinear interpolation; the insight is using it instead of rounding. The result is sub-pixel-accurate feature extraction, and mask quality jumps: the paper reported a relative improvement in mask accuracy of roughly 10% to 50% from this single change, with the larger gains under stricter localization metrics.
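A small sketch of the difference between rounding and interpolating, in plain Python. The 2×2 feature map and the helper name `bilinear_sample` are made up for illustration; a real RoI Align samples several such points per output cell and pools them.

```python
import math


def bilinear_sample(feature_map, x, y):
    """Sample a 2D feature map (a list of rows) at a real-valued (x, y).

    x indexes columns and y indexes rows; the point is assumed to lie
    inside the map.
    """
    h, w = len(feature_map), len(feature_map[0])
    x0, y0 = int(math.floor(x)), int(math.floor(y))
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)   # clamp at the border
    dx, dy = x - x0, y - y0
    # Interpolate along x on the two neighbouring rows, then along y.
    top = (1 - dx) * feature_map[y0][x0] + dx * feature_map[y0][x1]
    bottom = (1 - dx) * feature_map[y1][x0] + dx * feature_map[y1][x1]
    return (1 - dy) * top + dy * bottom


fmap = [[0.0, 10.0],
        [20.0, 30.0]]

# RoI Pool style: snap (x=0.75, y=0.25) to the nearest integer cell.
rounded = fmap[round(0.25)][round(0.75)]
# RoI Align style: blend the four neighbours by distance.
aligned = bilinear_sample(fmap, x=0.75, y=0.25)
print(rounded, aligned)
```

Rounding returns whichever single cell happens to be nearest, while interpolation returns a value that moves smoothly as the box coordinates move, which is the property the mask head relies on.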
The mask head and loss
Inside each detected region, the mask head is a small fully-convolutional network: take the RoI-aligned feature map (14×14 or 7×7), apply a few conv layers, upsample with a transposed conv to 28×28, then apply a 1×1 conv to produce C output channels, one per class, each a sigmoid binary mask. Critically, the model outputs C binary masks per region, one for each possible class, and at inference you simply take the mask for the class the classification head predicted. This decouples classification from mask prediction: the mask head doesn't have to decide "is this a dog or a cat," it just draws the right pixels.
The total loss combines all three branches:
$$L = L_{cls} + L_{box} + L_{mask}$$
The mask loss is per-pixel binary cross-entropy, applied only to the channel for the ground-truth class:
$$L_{mask} = -\frac{1}{m^2} \sum_{i=1}^{m^2} \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]$$
where $m^2$ is the number of mask pixels (784 for a 28×28 mask), $y_i$ is the true 0/1 mask value, and $p_i$ the predicted probability. Notice: sigmoid per pixel, not softmax. Each pixel independently answers "am I part of this object or not," with no competition between classes, which is exactly what allows the decoupling.
That decoupling is one of Mask R-CNN's most elegant ideas. In semantic segmentation (U-Net), you ask each pixel "which of the C classes are you?" — softmax forces classes to compete. In Mask R-CNN you ask each pixel "are you part of this object?" — binary, no competition. So the mask head's job is much simpler: it doesn't have to learn what a dog looks like versus a cat and distinguish them at the pixel level; it just learns to draw the boundary of whatever object is in this region, while the class head answers "what" independently. The result: cleaner gradient signal, better masks, lower data requirements.
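Here is a minimal sketch of that decoupled mask loss in plain Python. The function name `mask_loss`, the two-class setup, and the four-pixel "mask" are all assumptions chosen to keep the example tiny; the point is only that the channels for other classes never enter the loss.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mask_loss(mask_logits, true_mask, gt_class):
    """Mean per-pixel binary cross-entropy on the ground-truth class channel.

    mask_logits: one list of flattened per-pixel logits per class.
    true_mask:   flattened 0/1 values for this instance.
    """
    logits = mask_logits[gt_class]          # every other channel is ignored
    total = 0.0
    for z, y in zip(logits, true_mask):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(true_mask)


true_mask = [1, 1, 0, 0]
dog_channel = [5.0, 5.0, -5.0, -5.0]        # a good mask for class 0

# Same class-0 channel, wildly different class-1 channels.
loss_a = mask_loss([dog_channel, [9.0, 9.0, 9.0, 9.0]], true_mask, gt_class=0)
loss_b = mask_loss([dog_channel, [-9.0, 0.0, 3.0, -1.0]], true_mask, gt_class=0)
print(loss_a, loss_b)
```

Both losses are identical: whatever the class-1 channel draws has no effect when the region's true class is 0, so the mask head is never asked to make classes compete.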
Check your understanding
Mask R-CNN uses a sigmoid per pixel (one binary mask per class) rather than a softmax over classes like U-Net. Why is that "decoupling" helpful?
Because it separates "what is this object?" (handled by the dedicated classification branch) from "which pixels belong to it?" (handled by the mask head). With a per-class sigmoid, each pixel just answers a binary "am I part of this object?" with no competition between classes, so the mask head only has to learn to trace boundaries — not to distinguish dog-from-cat at the pixel level. That simpler, decoupled task gives a cleaner gradient, sharper masks, and lower data needs than forcing a softmax to do classification and segmentation at once.
SAM: The Segment Anything Model
In 2023, Meta released the Segment Anything Model (SAM), and segmentation got its foundation model. Before SAM, every segmentation task needed its own model trained on its own labeled dataset — one for tumors, one for roads, one for satellite imagery. After SAM, a single model could segment essentially anything in any image with a click, a box, or a rough mask — including objects it had never been explicitly trained on. It was trained on 11 million images and 1.1 billion masks (over 400× larger than the previous biggest segmentation dataset), and its architecture has three parts: a heavy image encoder, a lightweight prompt encoder, and a small mask decoder.
The big idea: promptable segmentation
The single most important idea in SAM is the shift from task-specific to promptable segmentation. A traditional model is trained for a fixed output — a U-Net trained on medical scans outputs tumor masks, a model trained on COCO outputs masks of 80 categories. The set of possible outputs is locked at training time. SAM flips this: it's trained to take a prompt — a point click, a bounding box, or a rough mask — and produce the segmentation that the prompt indicates. Whatever you point at, SAM segments. Which means you don't need labeled data for your specific task (SAM works zero-shot), you can segment things the model never saw in training, and one model serves countless downstream tasks just by changing the prompt.
If that sounds familiar, it should: it's the exact same conceptual move that GPT made in language. Instead of training a separate model for translation, summarization, and classification, you train one model that responds to prompts and let the prompt encode the task. SAM is the segmentation version of that shift.
Architecture: a deliberate asymmetry
SAM has three components, designed around one constraint: interactive use must feel almost instant. When you click, the mask should appear in milliseconds.
- Image encoder: a heavy Vision Transformer (ViT-H, 636 million parameters) that turns the image into a dense embedding. Runs once per image.
- Prompt encoder: a lightweight network that converts user prompts into embeddings. Runs once per prompt.
- Mask decoder: a small transformer-based module that combines the image and prompt embeddings to produce masks. Runs in milliseconds.
The asymmetry is the entire point. Encoding the image is expensive (the ViT-H takes hundreds of milliseconds on a modern GPU), but decoding from a prompt is cheap. So you encode an image once, then click on it many times, getting a fresh mask in real time for each click without ever recomputing the image embedding.
The heavy image encoder runs once per image; the light prompt encoder and mask decoder run per prompt, reusing the cached embedding.
The image encoder
The image encoder is SAM's eyes. It turns a 1024×1024 RGB image into a 64×64 grid of 256-dimensional feature vectors. It's a Vision Transformer (ViT-H in the largest variant), pretrained with Masked Autoencoding (MAE) before being adapted for segmentation. (We'll cover ViTs in full later in this chapter; here's just enough to follow SAM.)
Why a ViT and not a CNN? CNNs are excellent at local features through convolution, but they struggle with long-range dependencies because receptive fields grow only slowly with depth. Segmentation often needs global reasoning — picture segmenting a person partially hidden behind a tree. A CNN might segment the visible body parts separately, because the disconnected regions can't "talk" until very deep layers. A ViT's self-attention gives every patch immediate access to every other patch in a single layer — exactly what global reasoning needs.
Patch tokenization
The first step turns the 2D image into discrete tokens the transformer can process. SAM uses 16×16 non-overlapping patches. Operationally that's a single convolution with kernel size 16 and stride 16 — each output position corresponds to one 16×16 input patch. The shape transformation:
- Input: (3, 1024, 1024) — RGB pixels.
- After patch embedding: (1280, 64, 64) — since 1024/16 = 64 patches per side, each represented by a 1280-dim vector (for ViT-H).
- Rearranged for the transformer: a sequence of 64×64 = 4096 patch tokens.
So we go from 1024×1024 raw pixels to a 64×64 grid of patch tokens, each a rich learned feature vector: 16× fewer positions along each side, so 256× fewer spatial positions overall, packed into much richer per-location features.
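The shape bookkeeping is easy to check by hand, and a few lines of Python make it explicit. `patch_grid` is a hypothetical helper written for this example, not part of any SAM code.

```python
def patch_grid(image_size, patch_size, embed_dim):
    """Shape bookkeeping for a patch embedding whose stride equals its kernel size."""
    assert image_size % patch_size == 0, "image must tile evenly into patches"
    side = image_size // patch_size
    return {
        "grid": (embed_dim, side, side),
        "num_tokens": side * side,
        "position_reduction": (image_size * image_size) // (side * side),
    }


vit_h = patch_grid(image_size=1024, patch_size=16, embed_dim=1280)
print(vit_h)
```

For the ViT-H numbers in the text this gives a (1280, 64, 64) grid, 4096 tokens, and 256 input pixels folded into each token.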
Positional encodings
Patches alone have no order — the transformer can't tell where each came from. SAM adds learnable absolute positional embeddings: each of the 64×64 = 4096 patch positions gets its own learnable vector, added to the patch embedding. Unlike the sinusoidal encoding of the original transformer, these are learned — the network discovers whatever positional representation works best. (And unlike relative encodings, which we'll see in a moment, absolute embeddings tell each patch its own coordinate, not its relationship to others.)
Windowed and global attention
Here SAM does something clever. A standard ViT does full self-attention at every layer: every patch attends to every other. For 4096 tokens (the flattened 64×64 grid), that's $4096^2 \approx 16.8$ million attention scores per head per layer, multiplied across many heads and layers. Expensive. So SAM uses mostly windowed attention with periodic global attention:
- Most layers use windowed self-attention with window size 14. Attention happens only within each 14×14 window, not across the whole image. For a 64×64 feature map that's about $(64/14)^2 \approx 21$ windows, each doing attention over only 196 tokens instead of 4096. Complexity drops from $O(4096^2)$ to roughly $O(21 \times 196^2)$, about a 21× saving.
- A few designated layers (every 8th block in ViT-H) use full global self-attention, letting information mix across the whole image.
This is the same idea as Swin Transformer's hierarchical attention (which we'll meet later): cheap windowed layers capture local patterns, while rare, expensive global layers mix information globally. Far more efficient than full attention everywhere, with no real loss of global reasoning.
Cheap local windows most layers, occasional global mixing — global reasoning without paying O(N^2) everywhere.
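A quick back-of-the-envelope comparison of the two attention patterns, counting query-key pairs per head per layer. One assumption to flag: 64 is not a multiple of 14, so this sketch pads the grid up to 70×70 and tiles it with 5×5 = 25 full windows, which makes the count slightly higher than the idealized (64/14)² estimate.

```python
def attention_pairs(num_tokens):
    """Number of query-key pairs in full self-attention over num_tokens."""
    return num_tokens * num_tokens


grid_side, window = 64, 14

# Global attention: all 4096 tokens attend to each other.
global_cost = attention_pairs(grid_side * grid_side)

# Windowed attention: pad the grid to a multiple of the window (assumption),
# then run full attention inside each window independently.
windows_per_side = -(-grid_side // window)      # ceiling division
num_windows = windows_per_side ** 2
windowed_cost = num_windows * attention_pairs(window * window)

print(global_cost, num_windows, windowed_cost, global_cost / windowed_cost)
```

Even with the padding overhead, a windowed layer computes well over an order of magnitude fewer pairs than a global one, which is why only a handful of layers pay the global price.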
Relative positional embeddings inside attention
On top of the absolute positions, SAM adds relative positional embeddings inside each attention block. The framing is clean: absolute positions tell each patch "where I am," while relative positions tell pairs of patches "how I relate to you spatially." Both are useful — absolute lets the network reason about specific locations ("top-right corner"), relative lets it reason about relationships ("these two patches are vertically adjacent"). For segmentation, relative positions are arguably more useful, since what matters is which patches belong to the same object — a relationship, not an absolute spot.
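Recall the standard attention computation:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$
SAM adds position-dependent biases to the scores before the softmax:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + B\right)V, \qquad B_{ij} = b(\Delta h_{ij}, \Delta w_{ij})$$
where $\Delta h_{ij}$ and $\Delta w_{ij}$ are the vertical and horizontal distances between query position $i$ and key position $j$, and $b(\cdot, \cdot)$ are learned biases for those distances. The naive approach would learn a separate bias for every $(\Delta h, \Delta w)$ pair; for a 64×64 map each offset ranges over $2 \cdot 64 - 1 = 127$ values, so that's $127^2 = 16{,}129$ parameters per head. Expensive. SAM instead decomposes the bias into separate height and width components:
$$B_{ij} = b^{h}(\Delta h_{ij}) + b^{w}(\Delta w_{ij})$$
which needs only $2 \times 127 = 254$ parameters per head, a 98% reduction in positional parameters. The assumption is that horizontal and vertical relationships can be encoded independently, which is reasonable for natural images. Intuitively, the model can learn things like "patches in the same row are highly related" (a large bias at $\Delta h = 0$) and "patches close vertically are more related than distant ones" (a vertical bias decaying with $|\Delta h|$); the two combine into a position-aware attention adjustment.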
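The decomposition is easy to sketch as two small lookup tables. This follows the simplified picture in the text (one scalar bias per offset); the random values stand in for learned parameters, and the 4×4 grid is just small enough to print. The names `bias_h`, `bias_w`, and `relative_bias` are made up for this example.

```python
import random

side = 4                                   # toy grid; SAM's feature map is 64
num_offsets = 2 * side - 1                 # offsets run from -(side-1) to side-1
rng = random.Random(0)
bias_h = [rng.gauss(0, 1) for _ in range(num_offsets)]   # stand-in for learned values
bias_w = [rng.gauss(0, 1) for _ in range(num_offsets)]


def relative_bias(query, key):
    """Bias added to the attention score for a (row, col) query and key."""
    dh = query[0] - key[0]
    dw = query[1] - key[1]
    # Shift offsets so that -(side-1) maps to index 0.
    return bias_h[dh + side - 1] + bias_w[dw + side - 1]


def param_counts(grid_side):
    """Parameters per head: one joint 2D table vs. two 1D tables."""
    n = 2 * grid_side - 1
    return n * n, 2 * n


joint, decomposed = param_counts(64)
print(joint, decomposed, 1 - decomposed / joint)
```

Any two query-key pairs with the same offset share the same bias, regardless of where on the grid they sit, and the parameter count falls from 16,129 to 254 per head for a 64×64 map.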
The neck
After all the transformer blocks, the output is (64, 64, 1280) for ViT-H. But SAM's mask decoder expects exactly 256 channels — a standardized interface independent of which ViT size you use. The neck is a small two-layer convolutional projection that handles this: a 1×1 conv reduces channels 1280 → 256, then a 3×3 conv with LayerNorm does light spatial refinement at 256. After the neck, every SAM variant (Base with ViT-B at 768, Large with ViT-L at 1024, Huge with ViT-H at 1280) produces the same output shape: (256, 64, 64). This decouples the rest of the model from the encoder choice.
MAE pretraining
Before being trained on segmentation, the ViT-H encoder was pretrained with Masked Autoencoding (MAE): randomly mask out 75% of the input patches, then train an encoder-decoder to reconstruct the missing patches from the remaining 25%. Afterward the decoder is thrown away and only the encoder is kept. The encoder has learned to extract rich, generic visual features — capturing both local detail (to reconstruct fine textures) and global context (to figure out what should be where). This is a powerful initialization: the MAE-trained encoder already "knows how to see" before SAM training begins; it just needs to learn what to extract for promptable segmentation. The result is far stronger than training from scratch, especially given how data-hungry ViTs are.
Putting the encoder all together, the shape trail is: input (3, 1024, 1024) → patch embedding (1280, 64, 64) → add absolute positional embeddings → transformer blocks (mostly windowed, a few global, with relative position biases), still (1280, 64, 64) → neck (channel reduction + spatial refinement) → (256, 64, 64). That final 64×64 grid of 256-dim features is what the rest of the model sees.
Check your understanding
SAM's encoder is heavy (ViT-H, runs in hundreds of ms) but the model still feels instant when you click. How is that possible?
Because of the cost asymmetry between the components. The expensive image encoder runs once per image and produces a (256, 64, 64) embedding that's cached. Every subsequent click only runs the lightweight prompt encoder and the small mask decoder, which together take milliseconds — they reuse the already-computed image embedding rather than re-encoding the image. So you pay the heavy cost once and then get real-time masks for as many prompts as you want.
The prompt encoder
The prompt encoder converts user inputs into vector embeddings the mask decoder can attend to. SAM accepts three kinds of prompts: points (clicks, each labeled foreground or background), boxes (a rectangular region of interest), and masks (a rough input mask, often from a previous SAM output, to refine). Points and boxes are sparse prompts — small, encoded as a few vectors. Masks are dense prompts — they have spatial structure and get encoded as a feature map. That sparse-vs-dense split matters, and we'll see why at the end.
Point prompts
A point prompt is a coordinate plus a label: 1 for foreground ("part of the object"), 0 for background ("not part of it"). You click on what you want (positive points) and optionally click things to exclude (negative points). Encoding happens in three steps:
Step 1: coordinate normalization. Shift the pixel coordinates by 0.5 to align with pixel centers (avoiding a bias toward the top-left corner), then normalize to [0, 1] by dividing by the image size.
Step 2: Fourier positional encoding. SAM doesn't embed $(x, y)$ directly; that would be a 2-dimensional representation, far too small. Instead it uses Random Fourier Features to lift the coordinate into a high-dimensional vector. SAM has a fixed Gaussian random matrix $B \in \mathbb{R}^{2 \times d}$ generated at initialization (with $d = 128$, half of 256). For a normalized coordinate $p = (x, y)$, the encoding is:
$$\gamma(p) = \left[\, \sin(2\pi\, pB),\ \cos(2\pi\, pB) \,\right]$$
The symbols: $B$ is the fixed random Gaussian projection matrix; $pB$ produces a $d$-dimensional vector; applying $\sin$ and $\cos$ and concatenating gives a $2d = 256$-dimensional vector $\gamma(p)$. This is essentially the same idea as sinusoidal positional encoding in transformers, generalized to continuous 2D coordinates: the random matrix gives a projection that captures both fine and coarse spatial patterns at many frequencies. Nearby points get similar encodings (smoothness); distant points get very different ones (uniqueness).
Step 3: add a label embedding. The Fourier encoding is purely positional. To inject the label, SAM adds a learned label-specific vector: a foreground_embedding for positive points, a background_embedding for negative points, or a no_point_embedding for padding when no point is given. So each point becomes a 256-dim vector encoding both its position and its semantic role.
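The three steps fit in a few lines of plain Python. This is a sketch of the recipe as described here, not SAM's released code (which may differ in details such as how coordinates are rescaled before projection); the all-zeros `foreground` vector is a placeholder for a learned embedding.

```python
import math
import random

NUM_FREQS = 128                      # half of the 256-dim output
rng = random.Random(0)
# Fixed 2 x 128 Gaussian matrix, drawn once and never trained.
B = [[rng.gauss(0.0, 1.0) for _ in range(NUM_FREQS)] for _ in range(2)]


def encode_point(px, py, image_size, label_embedding):
    # Step 1: shift to the pixel centre and normalise to [0, 1].
    x = (px + 0.5) / image_size
    y = (py + 0.5) / image_size
    # Step 2: project (x, y) with B, then take sin and cos of 2*pi*projection.
    proj = [2 * math.pi * (x * B[0][k] + y * B[1][k]) for k in range(NUM_FREQS)]
    fourier = [math.sin(v) for v in proj] + [math.cos(v) for v in proj]
    # Step 3: add the label-specific vector.
    return [f + e for f, e in zip(fourier, label_embedding)]


foreground = [0.0] * 256             # placeholder for a learned embedding
vec = encode_point(512, 300, 1024, foreground)
print(len(vec))
```

Each click, whatever its pixel coordinates, comes out as a 256-dim vector whose first half holds the sines and second half the cosines of the same 128 random projections.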
Box prompts
A box is four numbers: top-left and bottom-right corners. SAM handles it elegantly: treat the box as two corner points and reuse the point machinery. Each corner is Fourier-encoded just like a point, then gets a learned corner-specific embedding added: a top_left_corner_embedding for one and a bottom_right_corner_embedding for the other. So SAM has four learned point-type embeddings in total (foreground, background, top-left corner, bottom-right corner), plus the padding embedding mentioned above. The model learns to read them differently: a top-left corner signals "the object's upper-left bound is here," while a foreground click signals "this exact spot is inside the object." A box thus produces exactly 2 sparse vectors (one per corner), while a single point produces 1; multiple prompts just stack into a longer sequence.
Mask prompts
A mask prompt is fundamentally different — it's a 2D image at near-full resolution (256×256), not a few sparse points, carrying dense pixel-level guidance. The challenge is to fold that dense information into the 64×64 feature grid that matches the image embedding. SAM does it with a small convolutional downsampling network: start with the (1, 256, 256) input mask, apply a 2×2 stride-2 conv → (mask_chans/4, 128, 128), another 2×2 stride-2 conv → (mask_chans, 64, 64), then a 1×1 conv → (256, 64, 64). The output matches the image embedding's shape, so SAM can add the mask embedding directly to the image embedding before the decoder runs — a dense prompt modifies the image features rather than entering through attention. When no mask is given, SAM uses a learned no_mask_embedding broadcast across the (256, 64, 64) grid, keeping shapes consistent.
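The shape trail through the downsampling network can be checked with the usual convolution output-size formula. The value `mask_chans = 16` is an assumption picked for illustration, and `conv_out` is a helper written for this example.

```python
def conv_out(size, kernel, stride):
    """Output side length of a convolution with no padding."""
    return (size - kernel) // stride + 1


mask_chans = 16                      # assumed value, for illustration only
# (out_channels, kernel, stride) for each layer described in the text.
layers = [(mask_chans // 4, 2, 2), (mask_chans, 2, 2), (256, 1, 1)]

trail = [(1, 256, 256)]              # the input mask prompt
for out_ch, kernel, stride in layers:
    _, h, w = trail[-1]
    trail.append((out_ch, conv_out(h, kernel, stride), conv_out(w, kernel, stride)))

print(trail)
```

The final shape matches the (256, 64, 64) image embedding, which is what makes the elementwise addition possible.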
Sparse vs dense, and why the split matters
So the prompt encoder produces two kinds of output: sparse embeddings (a few 256-dim vectors — one per point, two per box — that get concatenated as tokens and fed into the decoder's attention) and dense embeddings (a (256, 64, 64) tensor added directly to the image embedding). The distinction matches each prompt's spatial nature: points and boxes are inherently local ("this spot is special"), so they enter as attention tokens; masks are inherently global ("this whole region is special"), so they modify the image features wholesale. Encoding each appropriately lets the mask decoder use them naturally.
The mask decoder
The mask decoder is the heart of SAM, where most of the cleverness lives. It takes the image embedding and the prompt embeddings and produces high-quality masks in milliseconds. Four key ideas: learnable output tokens, two-way attention, hypernetwork-based mask generation, and multi-mask output for ambiguity.
The output token system
A naive decoder would directly output a 256×256 mask map — but that scales badly and doesn't generalize well to multiple mask hypotheses. SAM's approach is more elegant: introduce learnable output tokens that act as queries summarizing what mask the decoder will produce. Through training, these tokens learn to represent different mask interpretations and quality scores. SAM uses two kinds:
- IoU token: one token whose final value predicts the quality (estimated IoU) of each output mask.
- Mask tokens: four tokens, each producing one candidate mask (we'll see why four shortly).
So the decoder has 5 output tokens total — learnable 256-dim embeddings, initialized randomly and trained. They get concatenated with the sparse prompt embeddings to form the decoder's input sequence. And remember, before the decoder runs, the dense mask embedding has already been added to the image embedding. So the decoder sees a sequence of tokens — [IoU token, 4 mask tokens, sparse prompt embeddings], typically 5–10 tokens — plus an image embedding of shape (256, 64, 64), flattened during attention into 4096 image tokens.
Two-way attention
This is the decoder's most important innovation. A standard transformer decoder uses one-way attention: the decoder's queries attend to the encoder's keys, but the encoder doesn't attend back. SAM uses two-way attention — at each decoder layer, both the prompt tokens and the image features get updated based on each other. Why? Because segmentation needs mutual understanding: the prompts need image context (a point click is just a coordinate; to make a meaningful mask it needs to "see" what visual content is at that location), and the image features need prompt context (the features should highlight what the user is asking about — different prompts should activate different visual features).
The two-way attention block has four steps, in order:
1. Self-attention among the prompt and output tokens — the tokens "talk to each other" (the IoU token learns what the mask tokens predict, the mask tokens coordinate to produce different hypotheses, prompt tokens combine information).
2. Cross-attention: tokens attend to image features — each token is a query, the flattened image features are keys/values. This is where prompts gather visual information from the image.
3. MLP on tokens — a standard feed-forward block refines each token.
4. Cross-attention: image features attend to tokens — now the image features are queries and the tokens are keys/values. The "reverse" direction, where image features get updated based on the prompts.
After the block, both the tokens and the image features are updated and informed by each other. SAM stacks two such blocks back to back.
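The four steps can be sketched with a stripped-down attention function. Big caveats: this version has a single head, no learned projections, no layer norm, and a placeholder in place of the MLP, so it only illustrates the order of operations and who attends to whom. All names and sizes are made up for the example.

```python
import math
import random


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys_values):
    """Single-head attention without projections, plus a residual connection."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys_values]
        weights = softmax(scores)
        mixed = [sum(w * kv[j] for w, kv in zip(weights, keys_values)) for j in range(d)]
        out.append([a + b for a, b in zip(q, mixed)])
    return out


def mlp(tokens):
    """Placeholder for the feed-forward block: a residual ReLU with no weights."""
    return [[v + max(v, 0.0) for v in t] for t in tokens]


def two_way_block(tokens, image):
    tokens = attend(tokens, tokens)    # 1. self-attention among tokens
    tokens = attend(tokens, image)     # 2. tokens attend to image features
    tokens = mlp(tokens)               # 3. MLP on tokens
    image = attend(image, tokens)      # 4. image features attend to tokens
    return tokens, image


rng = random.Random(0)
dim = 8
tokens = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(6)]   # 5 output tokens + 1 click
image = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(16)]   # a 4x4 grid, flattened
original_image = image

for _ in range(2):                     # two blocks back to back
    tokens, image = two_way_block(tokens, image)
```

The thing to notice is step 4: in a one-way decoder `image` would come out unchanged, whereas here the image features that reach the mask-generation stage have already been reshaped by the prompt.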
Hypernetwork-based mask generation
After two-way attention finishes (two blocks plus a final token-to-image cross-attention), the decoder has refined both the tokens and the image features. Now, how do we actually produce a mask? The naive way is to have the decoder directly output a 256×256 map. SAM does something cleverer — a hypernetwork. Each mask token doesn't produce the mask itself; it produces the weights of a tiny filter that then gets applied to the image features. Step by step:
Step 1: upsample the image features. The decoder's image features are still 64×64. To produce a high-res mask, SAM uses two transposed convolutions with stride 2, taking the features from (256, 64, 64) up to (32, 256, 256). The channel count drops to 32: a smaller per-pixel feature, but at 4× the spatial resolution along each side.
Step 2: each mask token generates a filter. SAM has 4 mask tokens, each a 256-dim vector after the decoder. Each is passed through its own learnable MLP that outputs a 32-dim vector, the "filter weights" for that mask token.
Step 3: apply the filter as a dot product. The upsampled features have shape (32, 256, 256), that is, 256×256 positions, each a 32-dim feature. The filter is also 32-dim. The mask at each pixel is the dot product of that pixel's feature with the filter:
$$M(x, y) = \sum_{c=1}^{32} w_c\, F_c(x, y)$$
where $F_c(x, y)$ is the $c$-th feature channel at pixel $(x, y)$ and $w_c$ is the $c$-th filter weight. This yields one scalar per pixel, the mask logit. Do it for each of the 4 mask tokens and you get 4 mask maps of shape (256, 256). The hypernetwork design is far more efficient than directly outputting masks: the final 256 → 32 projection of the token-to-filter MLP has only 256 × 32 ≈ 8,000 weights per mask token, yet the same dot product produces a full 256×256 mask. And because each mask token can adapt its filter to the prompt, the same image features can be queried with different filters to produce different masks.
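A toy version of the dot-product step, shrunk to 2 channels on a 2×2 grid so the numbers can be checked by eye (SAM uses 32 channels on 256×256). The two filters are hand-picked stand-ins for what two mask tokens' MLPs would output.

```python
def mask_from_filter(features, filt):
    """Per-pixel dot product over channels.

    features[c][y][x] are the upsampled image features and filt[c] is one
    mask token's filter; the result is a grid of mask logits.
    """
    channels = len(filt)
    height, width = len(features[0]), len(features[0][0])
    return [[sum(filt[c] * features[c][y][x] for c in range(channels))
             for x in range(width)]
            for y in range(height)]


features = [
    [[1.0, 0.0], [0.0, 1.0]],    # channel 0
    [[0.0, 2.0], [2.0, 0.0]],    # channel 1
]
filter_a = [1.0, -1.0]           # would come from mask token A's MLP
filter_b = [0.0, 1.0]            # would come from mask token B's MLP

logits_a = mask_from_filter(features, filter_a)
logits_b = mask_from_filter(features, filter_b)
binary_a = [[v > 0 for v in row] for row in logits_a]
print(logits_a, logits_b, binary_a)
```

The same feature tensor yields two different masks purely because the filters differ, which is the whole trick: the expensive features are shared and only a tiny vector changes per hypothesis.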
Why four masks? Handling ambiguity
A single point click is often ambiguous — click on a person's shirt and do you mean the shirt, the torso, or the whole person? SAM sidesteps the ambiguity by predicting multiple masks (the 4 mask tokens) and letting the IoU token rank them. At training time, for an ambiguous prompt SAM only backpropagates through the best-matching of its 4 predictions (the one with the lowest mask loss against the ground truth), so only one mask token gets updated per ambiguous example. Over training this lets the four tokens specialize toward different interpretations (e.g. part vs whole-object), and at inference you can surface whichever the IoU head scores highest.
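The "only the best candidate gets the gradient" rule is just a minimum over per-candidate losses. In this sketch a Dice loss stands in for SAM's actual mask loss, and the six-pixel candidates are invented to mimic a part, a torso, a whole object, and a miss.

```python
def dice_loss(pred, target, eps=1e-6):
    inter = sum(p * t for p, t in zip(pred, target))
    return 1 - (2 * inter + eps) / (sum(pred) + sum(target) + eps)


def best_candidate(candidates, target):
    """Index and loss of the candidate mask that best matches the target."""
    losses = [dice_loss(c, target) for c in candidates]
    best = min(range(len(losses)), key=losses.__getitem__)
    return best, losses[best]


candidates = [
    [0, 1, 0, 0, 0, 0],   # a small part
    [0, 1, 1, 1, 0, 0],   # the torso
    [1, 1, 1, 1, 1, 1],   # the whole object
    [0, 0, 0, 0, 0, 1],   # an unrelated region
]
torso_label = [0, 1, 1, 1, 0, 0]

best_index, best_loss = best_candidate(candidates, torso_label)
# In training, only candidates[best_index] would receive a gradient for this example.
print(best_index, best_loss)
```

If the ground truth had been the whole object instead, a different candidate would win, and over many examples that is what pushes the tokens toward different interpretations.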
Full decoder forward pass, summarized:
1. Concatenate [IoU token, 4 mask tokens, sparse prompt embeddings] into the input token sequence.
2. Add the dense prompt to the image embedding (image_features = image_embedding + dense_prompt).
3. Run the two-way attention blocks (×2) so tokens and image features inform each other.
4. Do a final token-to-image cross-attention.
5. Separate out the IoU token and the 4 mask tokens.
6. Upsample the image features to (32, 256, 256).
7. Run hypernetwork mask generation: each mask token's MLP makes a 32-dim filter, dotted against the features to make a 256×256 mask.
8. Pass the IoU token through an MLP to predict each mask's quality.
Output: 4 mask logits (256×256 each) and 4 predicted IoU scores. Threshold the logits (at 0, which corresponds to probability 0.5) to get binary masks, and use the predicted IoU scores to choose which mask to surface.
Check your understanding
Why does SAM predict four masks per prompt, and what role does the IoU token play?
To handle ambiguity. A single click can legitimately mean different things (the shirt, the torso, the whole person), so SAM outputs four candidate masks and, during training, only backprops through the best-matching one per ambiguous example — which lets the four mask tokens specialize toward different valid interpretations (part vs whole, etc.). The IoU token predicts the quality (estimated IoU) of each candidate mask, so at inference SAM can rank the four and surface the one it believes is best.
Training SAM: the data engine
The architecture is only half the story. The other half is the data engine that produced 1.1 billion training masks. SAM was trained on SA-1B (Segment Anything 1 Billion): 11 million diverse, high-resolution images (typically 1500×2250) and 1.1 billion segmentation masks — about 100 per image. For context, before SAM the biggest segmentation dataset was COCO with ~2.5 million masks; SA-1B is 400× larger. No human team could label that many masks manually, so SAM was trained through a model-in-the-loop data engine that bootstrapped its way up in three stages:
- Stage 1, Assisted-Manual (4.3M masks): Annotators segmented objects with browser tools, helped by an early SAM. They clicked, SAM proposed masks, they refined. Annotation started at ~34 seconds per mask; as SAM improved on this data, it sped up to ~14 seconds per mask. Several iterations produced 4.3 million masks across 120,000 images.
- Stage 2, Semi-Automatic (5.9M masks): SAM now confidently masked "easy" objects on its own. It auto-detected prominent objects and annotators focused on adding the ones SAM missed, increasing diversity rather than re-covering obvious objects. Added 5.9 million masks across 180,000 images.
- Stage 3, Fully Automatic (1.1B masks): SAM was now strong enough to annotate without humans. A regular 32×32 grid of points was placed on each image, SAM was prompted at each point, and the resulting masks were filtered for quality using confidence thresholds and a stability check (a mask counts as stable if thresholding its probability map slightly below and slightly above 0.5 gives nearly the same mask). This generated ~1.1 billion masks across 11 million images.
This bootstrapped engine is one of the biggest reasons SAM succeeded — the model and the dataset improved each other in a virtuous cycle that no fixed labeling budget could have matched.
The model labels its own ever-growing dataset: each stage retrains SAM on more masks, so annotation gets faster and more automatic — ending at 1.1B masks, about 400x bigger than COCO.
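A sketch of the fully automatic stage's two ingredients: the point grid and the quality filter. The stability measure used here is the threshold-based one as I understand it from the SAM paper (compare the mask cut at 0.5 + δ with the mask cut at 0.5 − δ); the values of `delta`, `iou_thresh`, and `stability_thresh` are illustrative assumptions, as are the function names.

```python
def point_grid(points_per_side, width, height):
    """Regular grid of (x, y) prompts, one at the centre of each grid cell."""
    step_x, step_y = width / points_per_side, height / points_per_side
    return [((i + 0.5) * step_x, (j + 0.5) * step_y)
            for j in range(points_per_side)
            for i in range(points_per_side)]


def stability_score(probs, delta=0.05):
    """IoU between the mask thresholded at 0.5 + delta and at 0.5 - delta."""
    tight = sum(p > 0.5 + delta for p in probs)
    loose = sum(p > 0.5 - delta for p in probs)   # the tight mask is a subset of this
    return tight / loose if loose else 0.0


def keep_mask(predicted_iou, probs, iou_thresh=0.88, stability_thresh=0.95):
    """Keep a mask only if it is both confident and stable."""
    return predicted_iou >= iou_thresh and stability_score(probs) >= stability_thresh


grid = point_grid(32, width=1500, height=2250)

crisp = [0.99] * 20 + [0.01] * 80                 # sharp boundary: stable
mushy = [0.52] * 10 + [0.48] * 10 + [0.99] * 5    # many pixels hover near 0.5
print(len(grid), keep_mask(0.93, crisp), keep_mask(0.93, mushy))
```

A 32×32 grid gives 1024 prompts per image, and the filter throws away masks whose extent depends heavily on exactly where the threshold is drawn.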
SAM's training loss
SAM's loss has two parts: a mask loss and an IoU-prediction loss.
Mask loss: focal + dice. Plain binary cross-entropy has two well-known problems for segmentation: class imbalance (most pixels are background; predicting all-zeros scores low loss but is useless) and easy-negative dominance (most background pixels are trivially easy, but multiplied by millions they swamp the loss from hard boundary pixels). Focal loss fixes both by down-weighting easy pixels:
$$L_{focal} = -\alpha_t\, (1 - p_t)^{\gamma} \log(p_t)$$
where $p_t = p$ if $y = 1$ else $1 - p$, $\alpha_t$ is a class-balancing weight, and $\gamma$ (typically 2) is the "focusing parameter." The term $(1 - p_t)^{\gamma}$ is small when the model is already confident and large when it's wrong, so hard examples dominate the gradient. Dice loss directly optimizes overlap:
$$L_{dice} = 1 - \frac{2 \sum_i p_i y_i + \epsilon}{\sum_i p_i + \sum_i y_i + \epsilon}$$
where $p_i$ is the predicted probability, $y_i$ the true label, and $\epsilon$ a small constant for stability. Being a ratio of intersection to combined area, it's naturally robust to class imbalance. SAM combines them, weighting focal 20× more than dice:
$$L_{mask} = 20\, L_{focal} + L_{dice}$$
The 20:1 ratio is a tuned hyperparameter; the intuition is that focal provides fine per-pixel boundary signal while dice provides a global overlap signal, and both are needed.
IoU-prediction loss. SAM also trains the IoU token to predict each mask's quality, with a simple mean-squared error between predicted and actual IoU:
$$L_{IoU} = \frac{1}{N} \sum_{i=1}^{N} \left( \widehat{\text{IoU}}_i - \text{IoU}_i \right)^2$$
where $N$ is the number of predicted masks (typically 4), $\widehat{\text{IoU}}_i$ is the predicted IoU of mask $i$, and $\text{IoU}_i$ is the actual IoU of mask $i$ with the ground truth. The combined objective is:
$$L = L_{mask} + L_{IoU}$$
This trains SAM to do two things at once: produce accurate masks and honestly predict each mask's quality. Both are essential for interactive use, where the user only wants to see good masks. (Training details, as reported in the SAM paper: AdamW with weight decay, a 250-iteration linear learning-rate warmup followed by step-wise decay, no data augmentation because the dataset is so large and diverse that it isn't needed, a batch size of 256 images, and about 90K iterations, roughly two passes over SA-1B.)
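The whole objective fits in a few small functions. This is a plain-Python sketch over flattened probabilities, not SAM's training code: `alpha = 0.25` is a commonly used focal-loss default rather than a value given here, and averaging the focal term over pixels is likewise an assumption.

```python
import math


def focal_loss(probs, targets, alpha=0.25, gamma=2.0):
    """Mean focal loss; alpha and the mean reduction are assumptions."""
    total = 0.0
    for p, y in zip(probs, targets):
        p_t = p if y == 1 else 1 - p
        alpha_t = alpha if y == 1 else 1 - alpha
        total += -alpha_t * (1 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


def dice_loss(probs, targets, eps=1e-6):
    inter = sum(p * y for p, y in zip(probs, targets))
    return 1 - (2 * inter + eps) / (sum(probs) + sum(targets) + eps)


def sam_mask_loss(probs, targets):
    """Focal and dice combined 20:1."""
    return 20 * focal_loss(probs, targets) + dice_loss(probs, targets)


def iou_loss(pred_ious, actual_ious):
    """Mean-squared error between predicted and actual IoU."""
    return sum((a - b) ** 2 for a, b in zip(pred_ious, actual_ious)) / len(pred_ious)


targets = [1, 1, 0, 0]
confident = [0.95, 0.90, 0.10, 0.05]
unsure = [0.60, 0.55, 0.45, 0.40]

loss_confident = sam_mask_loss(confident, targets) + iou_loss([0.9], [0.85])
loss_unsure = sam_mask_loss(unsure, targets) + iou_loss([0.9], [0.40])
print(loss_confident, loss_unsure)
```

Comparing a single easy pixel (`focal_loss([0.99], [1])`) with a single hard one (`focal_loss([0.3], [1])`) shows the focusing effect directly: the easy pixel contributes several orders of magnitude less.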
Check your understanding
SAM's mask loss is $20\, L_{focal} + L_{dice}$. What does each piece contribute, and why is focal loss used over plain BCE?
Focal loss provides fine, per-pixel boundary signal and is used over plain binary cross-entropy because it down-weights the huge number of easy background pixels (via the $(1 - p_t)^{\gamma}$ factor), so hard boundary pixels actually drive the gradient instead of being swamped. Dice loss provides a global overlap signal that's robust to class imbalance (it's an intersection-over-combined-area ratio). Together, weighted 20:1 toward focal, they give both crisp local boundaries and good overall mask overlap, which neither alone delivers as well.
SAM 2: segmenting video
In 2024, Meta released SAM 2, extending SAM to video. The fundamental challenge with video is that objects move, deform, get occluded, and reappear — and a naive "run SAM on every frame" has no temporal consistency (a tracked dog might be segmented as the dog in frame 1 and the bush behind it in frame 2). SAM 2 solves this by adding a memory module that tracks each object's state across frames.
SAM 2 processes video as a stream, one frame at a time and in order, which matches how video is captured and lets it run in real time. It keeps SAM's overall design but swaps the image encoder for a faster Hiera transformer (instead of ViT-H) that encodes each frame, and it adds three memory components: memory attention that modifies the current frame's embedding based on memories of previous frames; a memory encoder that encodes the predicted mask for the current frame as a memory feature for future use; and a memory bank storing past frame embeddings and mask features for the tracked object. The mask decoder is essentially SAM's. The per-frame flow: encode the frame; run memory attention so the frame's features attend to the memory bank (picking up where the object was and what it looked like before); decode the (memory-modified) features plus any prompts into this frame's mask; then encode that mask and add it to the memory bank for future frames. Nicely, when SAM 2 is applied to a single image, the memory bank is empty and the memory components are simply bypassed, so the model behaves like SAM and one model handles both cases.
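The per-frame flow reads naturally as a loop. Everything here is a hypothetical skeleton: the four components are passed in as plain functions, the bounded `memory_size` is an assumption (the text doesn't say how the bank is capped), and the toy stand-ins at the bottom exist only so the loop can be run.

```python
from collections import deque


def track_object(frames, encode_frame, memory_attention, decode_mask,
                 encode_memory, prompts_by_frame=None, memory_size=6):
    """Streaming segmentation loop; all four components are placeholders."""
    prompts_by_frame = prompts_by_frame or {}
    memory_bank = deque(maxlen=memory_size)    # bounded bank: an assumption
    masks = []
    for index, frame in enumerate(frames):
        features = encode_frame(frame)
        if memory_bank:                        # empty bank: behave like plain SAM
            features = memory_attention(features, list(memory_bank))
        mask = decode_mask(features, prompts_by_frame.get(index))
        memory_bank.append(encode_memory(features, mask))
        masks.append(mask)
    return masks


# Toy stand-ins so the control flow can be exercised with numbers.
toy = dict(
    encode_frame=lambda frame: frame,
    memory_attention=lambda feats, memories: feats + len(memories),
    decode_mask=lambda feats, prompt: feats,
    encode_memory=lambda feats, mask: mask,
)

video_masks = track_object([1, 2, 3], **toy)
still_masks = track_object([1], **toy)
print(video_masks, still_masks)
```

With a single frame the memory branch never fires, which mirrors the still-image behaviour described above; with several frames each later frame sees a growing bank.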
Check your understanding
What does SAM 2 add to handle video, and what happens to that machinery when you give it a single still image?
It adds a memory module: a memory bank storing past frames' embeddings and mask features, memory attention that lets the current frame's features attend to those memories, and a memory encoder that writes each predicted mask back into the bank. This gives temporal consistency, so a tracked object stays the same object across motion and occlusion. On a single still image the memory bank is empty, so the memory components are bypassed and SAM 2 behaves exactly like the original SAM.
Vision Transformers
For nearly a decade after AlexNet, convolutional neural networks owned vision. Every state-of-the-art classifier, detector, and segmentation model was built on the same CNN scaffolding — local convolutions, pooling, hierarchical feature maps. The architecture's inductive biases (locality, translation invariance, spatial hierarchy) seemed not just convenient but necessary for vision.
Meanwhile, transformers were eating language. By 2020, every important language model was a transformer, with two beautiful properties: it scaled almost arbitrarily well with data and parameters, and it imposed very little structure on its input — it just learned which patterns mattered. So the natural question: could transformers work for vision too? The CNN's biases helped enormously with limited data — but with internet-scale image datasets, maybe those same biases had become a ceiling. Maybe the right move was to hand the transformer raw image patches and let it figure out everything else.
Inductive bias versus scale
The answer, delivered by Google's Vision Transformer (ViT) paper in 2020, was: yes — if you have enough data. With a few million images (a typical academic-scale dataset), ViTs underperformed CNNs of comparable size. But with 300 million images (JFT-300M, Google's internal dataset), ViTs outperformed the best CNNs — and the gap widened with more data.
First, a definition: inductive bias is the set of assumptions an architecture builds in about its problem, the prior knowledge it uses to generalize to inputs it hasn't seen. CNNs have strong inductive biases: translation invariance, locality, hierarchy. These are assumptions about how vision works, baked into the wiring. With limited data those priors are gold, steering learning toward good solutions and saving the network from having to discover "nearby pixels are correlated" from scratch. ViTs have weak inductive biases: a ViT assumes almost nothing about images. It has to learn that nearby patches relate, that translation invariance helps, that hierarchies are useful, and with little data it can't learn all that, so it loses. But with enough data the situation flips: the CNN's biases become a straitjacket, while the transformer can discover patterns a CNN cannot express in a single layer (long-range relationships, attention-based pooling), and it pulls ahead.
This is the exact same scaling story that had already played out in language — architectures with stronger biases win at small scale, architectures with weaker biases but more parameters win at large scale. The transformer turned out to be the right "weak bias" architecture for vision, just as it had been for language. (Remember the inductive-bias-versus-scale theme from the language chapters? Same lesson, new domain.)
Turning an image into a sequence of patches
Transformers operate on sequences, so to use one on an image you first turn the image into a sequence. ViT does this by chopping the image into a grid of fixed-size patches, flattening each patch into a vector, and treating each patch as a token.
Take a 224×224 RGB image. Divide it into non-overlapping 16×16 patches. That gives a 14×14 grid (224/16 = 14), or 196 patches total. Each patch contains 16×16×3 = 768 raw pixel values. For each patch:
1. Flatten the patch into a 768-dimensional vector (just unroll the 16×16×3 values).
2. Project it through a learnable linear layer to produce a $D$-dimensional patch embedding. For ViT-Base, $D = 768$ (which happens to equal the flattened patch size here, but $D$ is a free hyperparameter and could be any value).
After this step, the image has become a sequence of 196 token embeddings, each of dimension $D$: exactly the kind of input a transformer expects. Notice the parallel to SAM's image encoder: same patch-tokenization idea, because SAM's encoder is a ViT.
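The chop-and-flatten step can be sketched in a few lines of plain Python. The `patchify` helper below is a hypothetical name, and it works on nested lists rather than tensors purely to stay dependency-free; a real implementation would use a tensor reshape or a strided convolution, followed by the learned linear projection.

```python
def patchify(image, patch):
    """Split an H x W x C nested-list image into flattened patches, row-major."""
    height, width = len(image), len(image[0])
    assert height % patch == 0 and width % patch == 0, "image must tile evenly"
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            # Unroll the patch: rows, then pixels within a row, then channels.
            flat = [value
                    for row in image[top:top + patch]
                    for pixel in row[left:left + patch]
                    for value in pixel]
            patches.append(flat)
    return patches

# A blank 224 x 224 RGB image, just to check the shapes from the text.
image = [[[0, 0, 0] for _ in range(224)] for _ in range(224)]
patches = patchify(image, 16)
print(len(patches), len(patches[0]))
```

The two printed numbers are the patch count and the flattened patch length, matching the 14×14 grid and 16×16×3 values worked out above.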
The CLS token and positional encodings
Two more ingredients before the transformer can run.
The [CLS] token is borrowed directly from BERT. ViT prepends a special learnable token at the start of the sequence — an extra slot the transformer can use to aggregate global information. After the transformer runs, this token's final hidden state is what gets passed to the classification head. Through attention, every other token can write information into the [CLS] token's representation, so it ends up as a learned global summary of the whole image.
Positional encodings give the model a sense of where each token sits. Without them, "dog on top, sky on bottom" would look identical to "sky on top, dog on bottom", because attention alone has no notion of token order (shuffle the patches and the outputs are simply shuffled the same way). ViT uses learnable positional embeddings: one trainable $D$-dimensional vector per position, added to the patch embeddings. Position 0 (the [CLS] token) gets one positional vector, position 1 (the top-left patch) another, and so on through position 196. Here's the lovely part: ViT's experiments show the model learns 2D-aware positional embeddings from this 1D scheme. Patches that are neighbors in the original image end up with similar positional vectors, even though the network was never told about the 2D layout. The transformer figures out the topology from data.
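Both ingredients amount to one list operation each. A minimal sketch, assuming a toy width of 8 instead of 768 and random values standing in for parameters that would be learned in a real model:

```python
import random

def build_input_sequence(patch_embeddings, cls_token, pos_embeddings):
    """Prepend [CLS], then add one positional vector per position."""
    tokens = [cls_token] + patch_embeddings
    assert len(tokens) == len(pos_embeddings), "need one positional vector per token"
    return [[t + p for t, p in zip(token, pos)]
            for token, pos in zip(tokens, pos_embeddings)]

rng = random.Random(0)
D = 8                # toy width (assumption); ViT-Base uses 768
NUM_PATCHES = 196

patches = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(NUM_PATCHES)]
cls_token = [0.0] * D                      # learnable in a real model
pos_embeddings = [[rng.gauss(0, 0.02) for _ in range(D)]
                  for _ in range(NUM_PATCHES + 1)]   # positions 0..196

sequence = build_input_sequence(patches, cls_token, pos_embeddings)
print(len(sequence))   # 196 patches plus the [CLS] slot
```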
The transformer encoder stack
Now we have a sequence of 197 token embeddings. From here, ViT is a standard transformer encoder, the same encoder-only design as BERT, just operating on image tokens instead of word tokens. Each block does:
1. LayerNorm the input.
2. Multi-head self-attention — every token attends to every other token. This is where global reasoning happens.
3. Add the residual connection.
4. LayerNorm again.
5. MLP — a two-layer feed-forward network with GELU activation.
6. Add the residual connection.
If that block structure looks familiar, it should: it's the same pre-norm transformer block from the language chapters, residual connections and all (there's our recurring hero again, carrying gradients through the depth). ViT-Base stacks 12 such blocks, ViT-Large stacks 24, ViT-Huge stacks 32. After all the blocks you have an output sequence of 197 tokens, each a $D$-dimensional vector now enriched by attention with every other token. For classification, ViT takes the final hidden state of the [CLS] token alone and passes it through a small MLP head that outputs class logits. That's it: the other 196 tokens are discarded for classification (though they're useful for dense tasks like segmentation, which is exactly how SAM uses them).
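The six steps map onto code almost line for line. This is a structural sketch only: `attention` and `mlp` are passed in as plain callables (hypothetical placeholders for the real sub-layers), and the LayerNorm here omits the learnable scale and shift a real block would have.

```python
import math

def layer_norm(x, eps=1e-6):
    """Normalize one token vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(tokens, deltas):
    """Residual add, token by token."""
    return [[a + b for a, b in zip(t, d)] for t, d in zip(tokens, deltas)]

def pre_norm_block(tokens, attention, mlp):
    normed = [layer_norm(t) for t in tokens]      # 1. LayerNorm
    attended = attention(normed)                  # 2. self-attention across tokens
    tokens = add(tokens, attended)                # 3. residual add
    normed = [layer_norm(t) for t in tokens]      # 4. LayerNorm again
    updated = [mlp(t) for t in normed]            # 5. MLP, applied per token
    return add(tokens, updated)                   # 6. residual add
```

A useful sanity check on the residual path: if both sub-layers output zeros, the block returns its input untouched, which is why gradients have a clean route through a deep stack.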
Making the sizes concrete
The standard ViT sizes, straight from the paper:
| Variant | Layers | Hidden dim | Heads | MLP dim | Parameters |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 12 | 3072 | 86M |
| ViT-Large | 24 | 1024 | 16 | 4096 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 5120 | 632M |
Three things to notice. The MLP dim is 4× the hidden dim: the standard transformer expansion ratio, where each block's feed-forward network projects up 4×, applies the nonlinearity, then projects back down; most of a transformer's parameters live in these MLPs. The hidden dim is divisible by the number of heads: each attention head operates on $D/h$ dimensions (for ViT-Base, 768/12 = 64 per head), the standard transformer convention. And parameters scale quadratically with width: ViT-Huge has ~2.7× the layers (32 vs 12) and ~1.7× the hidden dim (1280 vs 768) of ViT-Base, but ~7× the parameters, because per-layer parameters grow as $D^2$ and (32/12) × (1280/768)² ≈ 7.4; that super-linear growth is why scaling transformers gets expensive fast.
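The parameter column can be roughly reproduced from the other columns. The sketch below counts only the transformer blocks (attention projections, MLP, and LayerNorm parameters, all with biases) and deliberately ignores the patch embedding, positional embeddings, and classification head, so it is an estimate under those assumptions rather than an exact count.

```python
def block_params(d, mlp_dim):
    """Parameters in one pre-norm transformer block of width d."""
    layer_norms = 2 * (2 * d)                 # two LayerNorms, scale + shift each
    qkv = 3 * (d * d + d)                     # query, key, value projections
    attn_out = d * d + d                      # output projection after the heads
    mlp = (d * mlp_dim + mlp_dim) + (mlp_dim * d + d)   # up- and down-projection
    return layer_norms + qkv + attn_out + mlp

VARIANTS = {
    "ViT-Base":  (12, 768, 3072),
    "ViT-Large": (24, 1024, 4096),
    "ViT-Huge":  (32, 1280, 5120),
}

for name, (layers, d, mlp_dim) in VARIANTS.items():
    total = layers * block_params(d, mlp_dim)
    print(f"{name}: {total / 1e6:.1f}M parameters in the blocks")
```

The block-only totals land within a couple of percent of the table's figures, which confirms that nearly all of the model lives in the stacked blocks, and that the Huge-to-Base ratio of about 7× comes from depth times width squared.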
Self-attention, and why it differs from convolution
The attention computation is identical to the transformer chapter, just over patch tokens. For a sequence of $N$ tokens ($N = 197$ for standard ViT):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The query, key, and value matrices are linear projections of the input token sequence:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V$$

where $X$ has shape $N \times D$ and each $W$ matrix has shape $D \times d_k$; $\sqrt{d_k}$ is the scaling factor that keeps the dot products from growing too large (and flattening the softmax gradient, the same reason it's there in language). With multi-head attention this runs $h$ times in parallel with different $W$ matrices per head, then the results are concatenated.
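For readers who prefer code to notation, here is single-head scaled dot-product self-attention in plain Python. It is a small-scale sketch (nested lists, no batching, no multi-head split), and the helper names are illustrative; the projection matrices are passed in as arguments rather than learned.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    peak = max(row)                              # subtract the max for stability
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(x, w_q, w_k, w_v):
    """x is N x D; each weight matrix is D x d_k. Returns (output, weights)."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(q[0])
    k_transposed = [list(col) for col in zip(*k)]
    scores = matmul(q, k_transposed)             # N x N: every token vs every token
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights

identity = [[1.0, 0.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]    # three toy tokens of width 2
output, weights = self_attention(tokens, identity, identity, identity)
```

The `weights` matrix is N×N, and each row sums to 1: that square matrix is the all-to-all pattern, and also the source of the quadratic cost.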
The key contrast with a CNN: the attention pattern is genuinely all-to-all — every one of the 197 tokens attends to every other one. A single conv layer only sees a 3×3 neighborhood, so a ViT layer has a "receptive field" of the entire image immediately, in layer one. This is precisely the architectural difference that makes ViT good at global reasoning (no waiting for depth to grow the receptive field) and bad at small-data sample efficiency (it has to learn locality from scratch instead of getting it for free).
Computational cost
Self-attention has quadratic cost in sequence length. For 197 tokens you compute a 197×197 attention matrix per head per layer — manageable. But what if you want higher-resolution input? Drop the patch size from 16 to 8 and you get 784 patches — 4× the tokens, 16× the attention cost. Patch size 4: 16× the tokens, 256× the attention cost. This quadratic scaling is why ViT uses 16×16 patches by default — smaller patches give finer spatial precision, but the cost explodes. (This is the same quadratic-attention wall from the language chapters, and the fixes rhyme too: variants like Swin use windowed attention, exactly as SAM's encoder did, to claw back efficiency.)
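The scaling arithmetic is quick to verify. This sketch counts patch tokens only (it leaves out the [CLS] token, as the multipliers above do) and uses the number of entries in the attention matrix as a proxy for cost; the function name is illustrative.

```python
def attention_cost(image_size, patch_size):
    """Return (patch tokens, attention-matrix entries) for a square image."""
    side = image_size // patch_size
    tokens = side * side
    return tokens, tokens * tokens

base_tokens, base_cost = attention_cost(224, 16)
for patch_size in (16, 8, 4):
    tokens, cost = attention_cost(224, patch_size)
    print(f"patch {patch_size:>2}: {tokens:>5} tokens, "
          f"{tokens // base_tokens}x tokens, {cost // base_cost}x attention cost")
```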
Check your understanding
A ViT layer's attention is all-to-all while a conv layer sees only a 3×3 patch. Name one advantage and one disadvantage this gives the ViT.
Advantage: global reasoning from the very first layer — every patch can directly attend to every other, so long-range relationships (a person and the tree occluding them, opposite corners of an object) are available immediately, without waiting many layers for a receptive field to grow. Disadvantage: poor small-data sample efficiency — because the ViT bakes in almost no assumptions (no built-in locality or translation invariance), it must learn those useful priors from data, which takes a lot of data; a CNN gets them for free and so wins when data is scarce. It also costs more: all-to-all attention is quadratic in the number of patches.
Making ViTs practical: DeiT, Swin, MAE, DINO
The original ViT only beat CNNs with Google's proprietary JFT-300M. Four follow-up ideas made ViTs practical for everyone else.
DeiT (Data-efficient Image Transformer, 2021) was the breakthrough that made ViTs trainable on ImageNet alone. Its key ingredients: much stronger data augmentation (RandAugment, mixup, cutmix, random erasing — artificially expanding the training set), knowledge distillation via a separate "distillation token" that learns from a CNN teacher's predictions (so the student ViT learns from both the labels and the CNN's soft predictions), and better hyperparameters (careful tuning of learning rate, weight decay, dropout, stochastic depth). After DeiT, a competitive ViT no longer needed a proprietary dataset.
Swin Transformer kept the transformer but reintroduced CNN-like locality and hierarchy: a hierarchical structure with multiple stages of decreasing spatial resolution and increasing channel depth (just like a CNN); shifted-window attention, where attention runs within local windows (typically 7×7 patches) and the windows are shifted between layers so information flows across window boundaries over several layers; and linear complexity in image size, because attention is confined to local windows. Swin marries the transformer's flexibility with the CNN's efficiency, and it became the dominant ViT-style backbone for detection and segmentation where high resolution matters. (You met this exact idea inside SAM's image encoder.)
Local windows for cheap attention, shifted each layer so information still spreads — linear cost, global reach over depth.
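Why windowing gives linear cost falls out of a short count. In this sketch, cost is again measured as attention-matrix entries, the grid is assumed square and evenly divisible by the window, and the 56×56 grid is an assumed example size chosen only because it tiles into 7×7 windows.

```python
def global_attention_cost(grid_side):
    """All-to-all attention over a grid_side x grid_side patch grid."""
    tokens = grid_side * grid_side
    return tokens * tokens

def windowed_attention_cost(grid_side, window_side=7):
    """Attention confined to non-overlapping window_side x window_side windows."""
    assert grid_side % window_side == 0, "grid must tile evenly into windows"
    tokens_per_window = window_side * window_side
    num_windows = (grid_side // window_side) ** 2
    return num_windows * tokens_per_window ** 2   # equals tokens * window_side^2

for side in (56, 112):
    print(side, global_attention_cost(side), windowed_attention_cost(side))
```

Doubling the grid side quadruples the token count; the windowed cost grows by the same 4×, while the global cost grows 16×.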
ViT's biggest impact may actually be on self-supervised learning — training without labels.
MAE (Masked Autoencoder, 2021) is the vision version of BERT. The procedure: randomly mask 75% of the patches; the encoder sees only the visible 25%; a small decoder receives the encoder's output plus mask tokens at the missing positions and reconstructs the original pixels. After training, the decoder is thrown away and the encoder becomes a feature extractor. The aggressive 75% masking is the key — with so much hidden, the encoder must learn rich representations to enable reconstruction, so it learns generic visual understanding before any labels are involved. (This is exactly the pretraining SAM's ViT-H encoder used.)
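The masking step itself is tiny. A minimal sketch, assuming patches are addressed by index and using a seeded shuffle to pick the visible subset (the function name and `seed` parameter are illustrative):

```python
import random

def random_masking(num_patches, mask_ratio=0.75, seed=0):
    """Split patch indices into a visible set (for the encoder) and a masked set."""
    rng = random.Random(seed)
    order = list(range(num_patches))
    rng.shuffle(order)
    num_keep = int(num_patches * (1 - mask_ratio))
    visible = sorted(order[:num_keep])     # the encoder only ever sees these
    masked = sorted(order[num_keep:])      # the decoder must reconstruct these
    return visible, masked

visible, masked = random_masking(196)
print(len(visible), len(masked))
```

Because the encoder runs on only a quarter of the tokens, pretraining is also much cheaper per image than running the encoder on the full sequence.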
DINO (self-Distillation with NO labels, 2021) uses self-distillation: two networks, a student and a teacher, look at different augmented views of the same image, and the student is trained to match the teacher's representations. The teacher is an exponential moving average of the student's weights — they're never trained as separate models. DINO produces remarkably semantic features without any supervision: visualize what its attention attends to and you see object boundaries, foreground/background separation, and focus on specific objects — all learned label-free. DINOv2 (Meta's follow-up) is now a general-purpose vision feature extractor across many downstream applications.
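The teacher update is a one-line exponential moving average. In this sketch weights are flat lists of floats, and the default momentum of 0.996 is an illustrative assumption (in practice the momentum is typically scheduled during training).

```python
def ema_update(teacher, student, momentum=0.996):
    """Move each teacher weight a small step toward the student's weight."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher, student)]

teacher = [0.0, 0.0]
student = [1.0, -1.0]
for _ in range(1000):          # the student is held fixed here for illustration
    teacher = ema_update(teacher, student)
print(teacher)
```

No gradients ever flow into the teacher; it simply trails the student as a smoothed copy, which is what "never trained as separate models" means in practice.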
These self-supervised approaches matter because they unlock the real scaling promise: you can train ViTs on billions of unlabeled internet images rather than the much smaller labeled datasets, producing far stronger pretrained features.
Check your understanding
MAE masks 75% of patches and reconstructs them; what is the encoder actually left with at the end, and why is such aggressive masking the point?
After training, the decoder is discarded and the encoder is kept as a general-purpose feature extractor. The aggressive 75% masking is what forces the encoder to learn rich, semantic representations: if only a little were hidden, the model could reconstruct from low-level local texture alone, but with three-quarters of the image missing it has to understand global structure and context to fill in the gaps — so it learns genuinely useful visual features, all without labels. It's the same pretraining recipe SAM used for its ViT-H encoder.
Vision-Language Models
A vision-language model (VLM) takes both images and text as input and generates text as output. "Describe this image." "What's the dog doing?" "Read the receipt and total the items." VLMs are how GPT-4V, Claude, and Gemini handle vision.
The architectural question is: how do you let a language transformer see images? You can't just stuff raw pixels into an LLM — LLMs operate on token sequences. You need to convert images into something that fits naturally alongside text tokens. The answer that won: use a ViT to encode the image into tokens, then mix those visual tokens into the LLM's token stream. The LLM treats visual tokens just like text tokens — it attends to them, processes them through its layers, and generates text conditioned on all of them.
This is the unified abstraction that makes VLMs work, and it's the punchline this whole chapter has been building toward: anything you can tokenize can become input to an LLM. Images become patch tokens. Audio becomes audio tokens. Video becomes spatiotemporal tokens. The transformer is modality-agnostic — once everything is tokens, attention does the rest. (This is the same "tokenize everything" theme from the language chapters, now extended past text.)
The standard VLM recipe
A modern VLM has four components:
1. Vision encoder — typically a ViT, often pretrained with CLIP or SigLIP. Takes an image and produces a sequence of patch tokens.
2. Projection layer / adapter — a small network mapping visual tokens from the encoder's dimension into the LLM's input space. Usually just an MLP.
3. Language model — a standard decoder-only LLM (Llama, Qwen, etc.). Sees a sequence of mixed tokens: text tokens + projected visual tokens.
4. Tokenizer — handles the text side as usual; image tokens are inserted into the sequence at the right positions.
At inference, the typical flow: the user provides an image and a prompt ("What is in this image?"); the vision encoder turns the image into patch tokens (e.g. 256 tokens from a 224×224 image at patch size 14); the projection layer maps each patch token into the LLM's hidden dimension; the prompt is tokenized normally; the final input sequence is [image_tokens][text_tokens], all in the same hidden dim; and the LLM generates its response autoregressively, attending to both visual and text tokens.
[Interactive diagram: the VLM flow. The image becomes patch tokens, gets projected, is concatenated with the question, and the LLM emits an answer token by token.]
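The input-assembly part of that flow can be sketched with toy sizes. Everything here is illustrative: the widths (3 for the vision encoder, 2 for the LLM), the fixed projection matrix, and the stand-in embeddings are assumptions chosen so the shapes are easy to follow, and the projection is a single linear map with no bias.

```python
def num_patch_tokens(image_size, patch_size):
    side = image_size // patch_size
    return side * side

def project(token, weight):
    """Map one vision token to the LLM width; weight is d_vision x d_llm."""
    return [sum(t * w for t, w in zip(token, column)) for column in zip(*weight)]

def build_llm_input(image_tokens, text_embeddings, weight):
    projected = [project(token, weight) for token in image_tokens]
    return projected + text_embeddings        # [image_tokens][text_tokens]

n_image = num_patch_tokens(224, 14)           # patch size 14 on a 224 x 224 image
image_tokens = [[1.0, 0.0, 2.0] for _ in range(n_image)]   # stand-in encoder output
text_embeddings = [[0.5, 0.5] for _ in range(5)]            # stand-in prompt tokens
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]               # toy 3 x 2 projection

sequence = build_llm_input(image_tokens, text_embeddings, weight)
print(n_image, len(sequence))
```

After projection, image and text tokens are indistinguishable in shape, which is all the LLM needs in order to attend over both.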
Why start from pretrained CLIP/SigLIP and a pretrained LLM rather than training everything from scratch? Data efficiency. Pretrained CLIP/SigLIP encoders have already learned semantic visual representations from billions of image-text pairs — their patch tokens are organized so semantically similar things have similar embeddings, and so captions and images live in compatible spaces. When you train a VLM you don't have billions of examples — you have maybe millions of (image, instruction, response) triples. Starting from a pretrained encoder means you're not learning vision from scratch, just teaching the LLM to interpret already-meaningful features. The same logic applies to the LLM: it already knows language, so the training task is mainly to teach it how to use the new visual modality.
Connector designs: how visual tokens get in
The "projection" or "adapter" between vision encoder and LLM is small but architecturally significant. Three main designs:
Simple linear projection (LLaVA-style). Just a linear layer (or MLP) from the encoder's hidden dim to the LLM's hidden dim. Each ViT patch token becomes one LLM token. Cheap, simple, surprisingly effective. The downside: a ViT producing 256 tokens spends 256 tokens of LLM context on every image, which gets expensive with many images or high-resolution inputs.
Q-Former (BLIP-2 style). A small transformer with a fixed number of learnable "query tokens" (e.g. 32) that learn what to extract from the encoder. The Q-Former cross-attends to the encoder's outputs and produces a fixed-size summary, so the LLM only ever sees those 32 tokens regardless of image resolution. Trades some flexibility for context efficiency.
Cross-attention layers (Flamingo style). Instead of treating image tokens as additional input tokens, add new cross-attention layers throughout the LLM that attend to the encoder's output. Visual features don't consume context tokens; they're accessed via dedicated attention. More compute, but a cleaner separation between modalities.
For most open-source VLMs today, simple linear projection (or a 2-layer MLP) is the default — LLaVA's success showed the simplest design works well enough that the sophistication of Q-Formers and cross-attention usually isn't necessary.
CLIP and SigLIP: the vision encoders
These are worth recapping in the VLM context, because the choice of encoder strongly shapes VLM quality.
CLIP (Contrastive Language-Image Pre-training, OpenAI 2021) was the breakthrough vision encoder explicitly trained to align with text. The procedure: collect ~400M image-text pairs from the internet; train an image encoder and a text encoder jointly; for each batch of $N$ pairs, compute an $N \times N$ similarity matrix between every image embedding and every text embedding; then use a contrastive loss that maximizes the diagonal (matched pairs) and minimizes the off-diagonal (mismatched pairs). The objective is a softmax across the batch:

$$\mathcal{L}_i = -\log \frac{\exp\big(\text{sim}(I_i, T_i)/\tau\big)}{\sum_{j=1}^{N} \exp\big(\text{sim}(I_i, T_j)/\tau\big)}$$

where $\text{sim}$ is cosine similarity, $\tau$ (tau) is a learned temperature scaling the sharpness of the distribution, $T_i$ is the caption matching image $I_i$, and $j$ ranges over all captions in the batch. (CLIP averages this image-to-text term with the mirrored text-to-image term.) The result: CLIP's image and text encoders produce embeddings in a shared semantic space. A photo of a dog and the caption "a photo of a dog" land near each other; a photo of a cat lands far from "a dog."
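Here is the contrastive objective as a dependency-free sketch. Embeddings are plain lists, the temperature is fixed at an illustrative 0.07 rather than learned, and the loss is averaged over both directions (image-to-text and text-to-image); the function names are assumptions, not any library's API.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def softmax_cross_entropy_on_diagonal(logits):
    """Mean of -log softmax(row)[i] over rows i: the match sits on the diagonal."""
    total = 0.0
    for i, row in enumerate(logits):
        peak = max(row)
        log_normalizer = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_normalizer - row[i]
    return total / len(logits)

def clip_loss(image_embs, text_embs, tau=0.07):
    # N x N similarity matrix, sharpened by the temperature.
    logits = [[cosine(img, txt) / tau for txt in text_embs] for img in image_embs]
    transposed = [list(col) for col in zip(*logits)]
    return 0.5 * (softmax_cross_entropy_on_diagonal(logits)
                  + softmax_cross_entropy_on_diagonal(transposed))

aligned = [[1.0, 0.0], [0.0, 1.0]]
print(clip_loss(aligned, aligned), clip_loss(aligned, aligned[::-1]))
```

When each image embedding points the same way as its own caption the loss is near zero; swap the captions and it becomes large. Note that every row's normalizer sums over the whole batch, which is the batch-wide coupling SigLIP later removes.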
SigLIP (Sigmoid Language-Image Pre-training, Google 2023) made one elegant change: replace the softmax-over-batch loss with a per-pair sigmoid loss:

$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N} \log \sigma\Big(z_{ij}\big(\text{sim}(I_i, T_j)/\tau + b\big)\Big)$$

where $\sigma$ is the sigmoid function and $z_{ij}$ is $+1$ for matched pairs ($i = j$) and $-1$ for unmatched pairs; $\tau$ is a learned temperature as in CLIP and $b$ is a learned bias. Each pair is treated independently as a binary "is this a match?" question. The benefits all flow from removing the batch-wide normalization: there's no softmax across the whole batch (which is what pushed CLIP toward enormous batches), so SigLIP trains effectively at modest batch sizes where CLIP needed 32K+ to compete, and the loss can be computed in chunks across devices because no pair's term depends on the rest of the batch. In 2024-2025, SigLIP encoders became a common default in open VLMs (PaliGemma and Idefics2, for example): a better encoder yields a consistently better downstream VLM.
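The per-pair sigmoid loss is just as short in code. As with any sketch, the specifics are assumptions: embeddings are plain lists, and the temperature (0.1) and bias (-10) are fixed illustrative values, whereas in training both are learned.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def softplus(x):
    """Numerically stable log(1 + exp(x)); note -log(sigmoid(x)) == softplus(-x)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def siglip_loss(image_embs, text_embs, tau=0.1, bias=-10.0):
    n = len(image_embs)
    total = 0.0
    for i, img in enumerate(image_embs):
        for j, txt in enumerate(text_embs):
            z = 1.0 if i == j else -1.0              # +1 matched, -1 unmatched
            logit = cosine(img, txt) / tau + bias
            total += softplus(-z * logit)            # independent binary decision
    return total / n

aligned = [[1.0, 0.0], [0.0, 1.0]]
print(siglip_loss(aligned, aligned), siglip_loss(aligned, aligned[::-1]))
```

Each term in the double loop depends on one image and one caption only; there is no normalizer that touches the rest of the batch, so the sum can be split across devices in any order.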
[Interactive diagram: how CLIP (row-normalized) and SigLIP (independent) score each image-text pair in the similarity matrix.]
VLM training strategy
A modern VLM is trained in stages, each unfreezing a bit more of the model:
Stage 1 — Pretrained components. Start with a pretrained vision encoder (CLIP/SigLIP) and a pretrained LLM (Llama, Qwen, etc.), both frozen.
Stage 2 — Adapter pretraining. Train only the adapter (the linear projection or Q-Former); the encoder and LLM stay frozen. The objective is image-text alignment on simple tasks like captioning. Cheap, and it teaches the adapter to map vision features into the LLM's space.
Stage 3 — Instruction tuning. Unfreeze the LLM (and sometimes the encoder) and fine-tune on instruction-following data: image-question-answer triples covering everything from "what's in this image" to "read this receipt" to "describe the scene." This teaches the model to use its multimodal understanding for real tasks.
Stage 4 — Optional preference alignment. RLHF, DPO, or similar to align outputs with human preferences — the same alignment step as text-only LLMs.
The data pipeline matters enormously: diverse, high-quality, well-grounded instruction-tuning data is what separates a usable VLM from a curiosity.
Freezing the expensive pretrained encoder and LLM while training only the small adapter is cheap; unfreezing the LLM later buys capability at higher cost. The vision encoder usually stays frozen throughout.
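The staging boils down to a freeze schedule, which can be written as a small lookup. This is a hypothetical configuration sketch, not any framework's API: the component and stage names are invented labels, keeping the adapter trainable in stage 3 is an assumption, and the optional encoder unfreezing and the preference-alignment stage are left out.

```python
COMPONENTS = {"vision_encoder", "adapter", "llm"}

# Which components receive gradient updates in each stage.
TRAINABLE = {
    "stage_1_pretrained_components": set(),             # nothing trained yet
    "stage_2_adapter_pretraining": {"adapter"},         # encoder and LLM frozen
    "stage_3_instruction_tuning": {"adapter", "llm"},   # LLM unfrozen
}

def frozen(stage):
    """Components that stay fixed during the given stage."""
    return COMPONENTS - TRAINABLE[stage]

for stage in TRAINABLE:
    print(stage, "trains", sorted(TRAINABLE[stage]), "freezes", sorted(frozen(stage)))
```

In a real training script, the same table would decide which parameter groups are handed to the optimizer at each stage.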
Multi-image and video
The same architecture extends naturally. Multi-image VLMs pass each image through the vision encoder independently, then concatenate all the image tokens into the sequence, so the LLM can reason across several images at once ("which of these dogs is biggest?"); position information indicates which image each token came from, often via special "image start"/"image end" tokens. Video VLMs take one of two approaches: frame sampling (treat the video as a sequence of frames, encode each independently, concatenate all the tokens; simple but context-hungry) or spatiotemporal encoders (a video-specific encoder like ViViT or VideoMAE that processes the whole clip at once, attending across both space and time; more efficient). Long-context frontier models (Gemini is the best-known example) can take in very long videos, which in practice means being economical with tokens: sampling frames sparsely and spending relatively few tokens on each one.
What VLMs can and can't do
Because the LLM is general-purpose, VLMs handle a wide range of tasks out of the box: visual question answering ("what color is the car?"), image captioning, OCR (reading text in images), chart/document understanding (tables, graphs, forms), visual reasoning (counting, comparing sizes, spatial relations), and instruction following with visual context ("edit the third item in this list").
But they have real weaknesses: fine-grained spatial reasoning is often poor ("exactly where is the cat's left paw?"); precise counting of many small objects is unreliable; reading rotated or unusual fonts can fail; and hallucination is a live risk — VLMs can confidently describe things that aren't in the image, especially when it's ambiguous or unusual. The pattern: VLMs are strong at semantic understanding (what's in the image, what's happening, what it means) and weaker at precise understanding (exact positions, counts, fine details). That makes sense given the training data — captions and natural-language descriptions are themselves usually semantic rather than precise. (And notice the contrast with SAM: when you need pixel-precise where, a promptable segmentation model is the right tool; when you need semantic what, the VLM shines. Different tools for the two halves of vision.)
Check your understanding
What single idea lets a language-only transformer suddenly handle images, and why do VLMs start from a pretrained CLIP/SigLIP encoder rather than a random one?
The single idea: turn the image into tokens (via a ViT) and project them into the LLM's hidden dimension so they sit in the same sequence as text tokens — then ordinary attention reasons over both. Because everything is tokens, the LLM is modality-agnostic. VLMs start from a pretrained CLIP/SigLIP encoder for data efficiency: those encoders already learned semantic, text-aligned visual features from billions of image-text pairs, whereas a VLM is fine-tuned on only millions of triples — far too few to learn vision from scratch. Starting pretrained means you only teach the LLM to interpret already-meaningful features.
Putting It All Together
Step back and look at the whole road, because — exactly like the language story — every stop on it was a fix for the thing before it.
We started by asking what an image even is (a tensor of numbers) and tried the obvious thing: feed pixels to an MLP. It blew up — too many parameters, no translation invariance — and that failure handed us the CNN, built around convolution's two gifts: locality and weight-shared translation invariance. Stack conv-pool blocks and a feature hierarchy emerges on its own — edges to textures to parts to objects — as receptive fields widen with depth. Three ideas turned CNNs from a curiosity into a dynasty: ReLU, BatchNorm, and residual connections — the last being the very same skip-connection trick that makes transformers trainable.
Then we put CNNs to work. YOLO reframed detection from a slow stack of classifiers (DPM, R-CNN) into one end-to-end regression over a grid — you only look once. Segmentation pushed to per-pixel labels, which forced a reckoning with the resolution-versus-semantics tension: FCN upsampled but came out blurry, and U-Net solved it cleanly with skip connections carrying sharp spatial detail across the U while semantics flowed through the deepest point. Mask R-CNN added instances by decoupling "what" (a class head) from "which pixels" (a per-class sigmoid mask). And SAM turned segmentation into a foundation model — promptable, trained on a billion bootstrapped masks, with a heavy ViT encoder, a featherweight prompt encoder, and a clever two-way-attention mask decoder — the same "one promptable model, the prompt encodes the task" move GPT made in language.
Finally we came full circle. ViT dropped convolution entirely, chopped the image into patches, and fed them to a plain transformer — winning once data was large enough, the same inductive-bias-versus-scale story as language. DeiT, Swin, MAE, and DINO made ViTs practical and label-free. CLIP and SigLIP aligned visual features with text. And VLMs tied the whole field to the language chapters with one punchline: tokenize the image, project it into the LLM's space, and let attention do the rest. Anything you can tokenize, the transformer can reason over.
The chapter is one connected journey: each milestone removes a limitation of the last, while skip/residual connections and attention plus “tokenize everything” keep reappearing as the load-bearing ideas.
Check your understanding
Two ideas recur across this entire chapter, linking it back to the language chapters. Name them, and give two places each shows up in vision.
First, skip / residual connections. They appear as ResNet's residual blocks (the third of the three big CNN ideas, letting gradients flow through deep stacks), as U-Net's skip connections (carrying sharp spatial detail across the U so masks aren't blurry), and again inside every ViT block's residual adds — the same trick that makes transformers trainable in the language chapters. Second, attention plus the "tokenize everything" idea. It shows up in SAM's ViT image encoder and two-way-attention decoder, in the Vision Transformer itself (all-to-all patch attention replacing convolution), and in VLMs (image patches tokenized and fed into an LLM alongside text, where attention reasons over both). The throughline of the whole chapter: turn a modality into tokens, and attention — with residual connections keeping the deep stack trainable — does the rest.