Project 4 — Train a CNN image classifier

Language isn't the only modality. Convolutional neural networks are the workhorse of computer vision — medical imaging, self-driving perception, quality inspection, face recognition. This project trains one from scratch, end to end, on synthetic images (no downloads), and watches it climb from coin-flip to 100% accuracy in a few seconds. You'll see a CNN learn the filters that detect visual patterns — exactly how real vision models learn edges and textures.

Full code: code/projects/cnn.py (PyTorch, CPU-friendly).

The task

Classify 12×12 grayscale images: class 0 has a horizontal bar, class 1 has a vertical bar, both buried in noise. This is trivial for a human, but hard for a model that treats the image as a flat list of pixels with no notion of which pixels are neighbors. Exploiting that spatial structure is exactly what a CNN is built to do.

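import torch

H, W = 12, 12                                  # image height and width

def make_images(n):
    X = 0.1 * torch.randn(n, 1, H, W)          # background noise
    y = torch.randint(0, 2, (n,))              # 0 = horizontal bar, 1 = vertical bar
    for i in range(n):
        if y[i] == 0:
            X[i, 0, torch.randint(0, H, (1,)).item(), :] += 1.0   # horizontal bar on a random row
        else:
            X[i, 0, :, torch.randint(0, W, (1,)).item()] += 1.0   # vertical bar on a random column
    return X, y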

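The bar lands on a random row or column each time, so the model can't memorize a position; it has to learn the concept of horizontal vs. vertical. That is what convolution is built for (Chapter 13): the same filter slides across every position, so a pattern is detected wherever it appears (translation equivariance), and pooling then makes the decision tolerant to where that was.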
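To see that property in isolation, the sketch below applies one random, untrained 3×3 filter to two noise-free bar images that differ only by a 4-row shift. The feature map shifts by the same 4 rows, and a global max over it comes out identical. The global max (adaptive_max_pool2d down to 1×1) is used here only for illustration; the project's CNN uses 2×2 max pooling, which gives partial rather than full position tolerance. The bar_image helper is hypothetical, written just for this demo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)  # one random, untrained filter

def bar_image(row):
    """Hypothetical helper: a clean 12x12 image with a horizontal bar on `row` (no noise)."""
    x = torch.zeros(1, 1, 12, 12)
    x[0, 0, row, :] = 1.0
    return x

with torch.no_grad():
    a = conv(bar_image(3))    # response to a bar on row 3
    b = conv(bar_image(7))    # response to the same bar moved down 4 rows

# Equivariance: shifting the input shifts the feature map by the same amount.
shifted_same = torch.allclose(torch.roll(a, shifts=4, dims=2), b, atol=1e-6)
# Invariance: a global max over the feature map ignores where the bar was.
same_max = torch.allclose(F.adaptive_max_pool2d(a, 1), F.adaptive_max_pool2d(b, 1), atol=1e-6)
print(shifted_same, same_max)   # True True
```

The roll comparison is exact here because the bar's response occupies only rows 2 to 4 (or 6 to 8), so nothing wraps around the image edge.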

The architecture

A CNN stacks convolution → activation → pooling to build up from pixels to concepts (Chapter 13):

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super().__init__()                                        # required before assigning layers
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)    # 8 learned 3×3 filters
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)   # 16 filters over the 8 feature maps
        self.pool = nn.MaxPool2d(2)                               # downsample 12→6→3
        self.fc = nn.Linear(16 * 3 * 3, 2)                        # classify

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))    # detect features, then shrink
        x = self.pool(F.relu(self.conv2(x)))
        return self.fc(x.flatten(1))

  • Conv2d — the learnable kernels/filters from Chapter 5. The first layer's 8 filters learn to respond to simple patterns (horizontal vs. vertical edges); training discovers them — you don't hand-design them.
  • ReLU — the non-linearity (Chapter 11).
  • MaxPool — keeps the strongest response in each 2×2 region, shrinking the image and adding position-tolerance.
  • Linear head — flattens the final feature map and classifies (Chapter 6).
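A quick way to check the 12→6→3 arithmetic and the 16 * 3 * 3 input size of the head is to push a dummy batch through the layers and print each shape. This sketch rebuilds the layers standalone (random, untrained weights) so it runs on its own; the batch size of 4 is arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Layers configured as in the CNN class, applied one step at a time.
conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2)
fc = nn.Linear(16 * 3 * 3, 2)

steps = [
    ("conv1 + relu", lambda t: F.relu(conv1(t))),   # padding=1 keeps 12x12
    ("pool", pool),                                  # 12 -> 6
    ("conv2 + relu", lambda t: F.relu(conv2(t))),
    ("pool", pool),                                  # 6 -> 3
    ("flatten", lambda t: t.flatten(1)),             # 16 * 3 * 3 = 144 features
    ("fc", fc),                                      # 2 logits, one per class
]

x = torch.randn(4, 1, 12, 12)                        # a batch of 4 one-channel 12x12 images
shapes = []
for name, step in steps:
    x = step(x)
    shapes.append(tuple(x.shape))
    print(f"{name:<13} {tuple(x.shape)}")
# Shapes: (4, 8, 12, 12), (4, 8, 6, 6), (4, 16, 6, 6), (4, 16, 3, 3), (4, 144), (4, 2)
```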

Training it

Training uses a standard mini-batch loop (Chapter 16), with model.train() during the weight updates and model.eval() when measuring test accuracy (Chapter 12).
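The loop itself isn't listed here, so what follows is a minimal sketch of what such a loop could look like, not the contents of cnn.py. The dataset sizes (2,000 train / 400 test), batch size 64, Adam with lr=1e-3, and 6 epochs are all assumptions, so the numbers it prints will differ from the output shown below. It rebuilds the same layers with nn.Sequential so it runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
H = W = 12

def make_images(n):
    """Noisy 12x12 images; label 0 = horizontal bar, 1 = vertical bar."""
    X = 0.1 * torch.randn(n, 1, H, W)
    y = torch.randint(0, 2, (n,))
    pos = torch.randint(0, H, (n,)).tolist()          # random row/column per image (H == W)
    for i, label in enumerate(y.tolist()):
        if label == 0:
            X[i, 0, pos[i], :] += 1.0                 # horizontal bar
        else:
            X[i, 0, :, pos[i]] += 1.0                 # vertical bar
    return X, y

# Same layers as the CNN class, written as nn.Sequential for brevity.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(), nn.Linear(16 * 3 * 3, 2),
)
n_params = sum(p.numel() for p in model.parameters())
print(f"params={n_params}")

X_train, y_train = make_images(2000)                  # assumed dataset sizes
X_test, y_test = make_images(400)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer and learning rate
batch_size, losses, accs = 64, [], []

for epoch in range(6):
    model.train()                                     # training mode
    perm = torch.randperm(len(X_train))               # reshuffle every epoch
    running = 0.0
    for start in range(0, len(X_train), batch_size):
        idx = perm[start:start + batch_size]
        loss = F.cross_entropy(model(X_train[idx]), y_train[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
        running += loss.item() * len(idx)
    losses.append(running / len(X_train))             # mean training loss this epoch

    model.eval()                                      # evaluation mode for testing
    with torch.no_grad():
        acc = (model(X_test).argmax(dim=1) == y_test).float().mean().item()
    accs.append(acc)
    print(f"epoch {epoch}  loss {losses[-1]:.3f}  test-acc {acc:.3f}")
```

This model has no dropout or batch norm, so train/eval mode doesn't change its outputs; the calls are kept because they become essential as soon as such layers are added.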

$ python cnn.py

Output:

params=1538
epoch 0  loss 0.671  test-acc 0.545
epoch 1  loss 0.616  test-acc 0.775
epoch 2  loss 0.520  test-acc 0.970
epoch 3  loss 0.381  test-acc 0.985
epoch 4  loss 0.233  test-acc 1.000
epoch 5  loss 0.127  test-acc 1.000

A 1,538-parameter CNN started at 54.5% (barely better than guessing) and reached 100% test accuracy by epoch 4 — having learned convolutional filters that detect horizontal vs. vertical structure, evaluated on held-out images it never saw (Chapter 9). That's a complete, honest computer-vision training run.
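The 1,538 figure can be checked by hand. Each Conv2d holds out_channels · (in_channels · 3 · 3) weights plus one bias per output channel, the Linear head holds 144 · 2 weights plus 2 biases, and ReLU and max pooling have no parameters:

```latex
\begin{aligned}
\text{conv1} &: 8 \cdot (1 \cdot 3 \cdot 3) + 8 = 80 \\
\text{conv2} &: 16 \cdot (8 \cdot 3 \cdot 3) + 16 = 1168 \\
\text{fc}    &: 2 \cdot (16 \cdot 3 \cdot 3) + 2 = 290 \\
\text{total} &: 80 + 1168 + 290 = 1538
\end{aligned}
```

Most of the budget sits in conv2, because every one of its 16 filters looks at all 8 feature maps produced by conv1.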

Scaling to real images

The same building blocks carry over to real vision; you grow the network and feed it real data:

  • Real datasets — MNIST (digits), CIFAR-10 (objects), ImageNet (1000 classes). One line with torchvision.datasets.
  • Deeper networks — ResNet, EfficientNet, U-Net: the same conv/pool/activation pattern, dozens of layers, residual connections (Chapter 11).
  • Data augmentation — random flips/crops/rotations to fight overfitting (Chapter 9).
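ResNet's key idea fits in a few lines. Below is a simplified sketch of its "basic block": two 3×3 convolutions with batch normalization, then the block's input added back before the final ReLU. It assumes the input and output have the same channel count and spatial size; real ResNets add a strided 1×1 projection on the skip path when those change. ResidualBlock is a name chosen for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 convs whose output is added back to the block's input (skip connection)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)        # skip connection: add the unchanged input back

block = ResidualBlock(16)
x = torch.randn(4, 16, 6, 6)          # e.g. a batch of 16-channel 6x6 feature maps
y = block(x)
print(y.shape)                        # torch.Size([4, 16, 6, 6])
```

Because the skip path is an identity, gradients can flow straight through the addition, which is the usual explanation for why such networks can stack many layers without training stalling.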

You usually don't train from scratch — transfer learning

For real vision tasks, you rarely start from random weights. You take a CNN pretrained on ImageNet (which already learned edges, textures, shapes) and fine-tune it on your task — transfer learning, the vision cousin of Chapter 32's fine-tuning:

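# pip install torchvision
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
NUM_CLASSES = 5                                           # placeholder: set to your task's class count
model = resnet18(weights=ResNet18_Weights.DEFAULT)        # pretrained on ImageNet (weights download on first use)
for p in model.parameters():
    p.requires_grad = False                               # freeze the pretrained body
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)   # swap the head; the new layer is trainable
opt = torch.optim.Adam(model.fc.parameters(), lr=1e-3)    # train only the head on your (small) dataset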

This is how you build a strong image classifier with a few hundred examples instead of a million: stand on the shoulders of a model that already learned to see.

Don't be confused: CNN vs. Vision Transformer (ViT). CNNs were the default for a decade; Vision Transformers now match or beat them at scale by applying the transformer to image patches. CNNs remain more data-efficient and are still everywhere; ViTs win with huge data. Know both exist.
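The "image patches" step of a ViT can be sketched in plain PyTorch: cut the image into non-overlapping patches, embed each one as a vector, and let self-attention mix them. A Conv2d whose kernel_size and stride both equal the patch size is a common way to do the cut-and-embed in one operation. The sizes here (4×4 patches, 32-dim embeddings, 4 heads) are illustrative assumptions, and a real ViT also adds positional embeddings and usually a class token, both omitted.

```python
import torch
import torch.nn as nn

patch, dim = 4, 32                                                # illustrative sizes
to_patches = nn.Conv2d(1, dim, kernel_size=patch, stride=patch)   # one embedding per 4x4 patch
encoder = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)

x = torch.randn(2, 1, 12, 12)                            # a batch of two 12x12 images
tokens = to_patches(x).flatten(2).transpose(1, 2)        # (2, 32, 3, 3) -> (2, 9, 32)
out = encoder(tokens)                                    # self-attention across the 9 patches
print(tokens.shape, out.shape)
```

Compare this with the CNN above: the convolution there slides a small filter over every position, while here every patch can attend to every other patch from the first layer on, which is part of why ViTs need more data to learn what a CNN gets built in.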

Make it production

The takeaway

A CNN learns convolutional filters that detect visual patterns wherever they appear; you trained one from 54.5% to 100% test accuracy in seconds, then saw how transfer learning (a pretrained ResNet plus a new head) builds strong classifiers from little data. The same conv/pool/activation pattern, scaled up, underpins much of computer vision. One project left, and it's the most modern: generating new data with diffusion. 👉