Project 4 — Train a CNN image classifier
Language isn't the only modality. Convolutional neural networks are the workhorse of computer vision: medical imaging, self-driving perception, quality inspection, face recognition. This project trains one from scratch, end to end, on synthetic images (no downloads), and you'll watch it climb from coin-flip accuracy to 100% in a few seconds. Along the way the CNN learns filters that detect visual patterns, the same mechanism real vision models use to learn edges and textures.
Full code: code/projects/cnn.py (PyTorch, CPU-friendly).
The task
Classify 12×12 grayscale images: class 0 has a horizontal bar, class 1 has a vertical bar, both buried in noise. Trivial for a human, impossible for a model that can't perceive spatial structure — which is exactly what a CNN is built to do.
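import torch
H = W = 12  # image height and width (12×12)
def make_images(n):
    X = 0.1 * torch.randn(n, 1, H, W)  # background noise, shape (n, 1, H, W)
    y = torch.randint(0, 2, (n,))      # random labels: 0 = horizontal, 1 = vertical
    for i in range(n):
        if y[i] == 0:
            X[i, 0, torch.randint(0, H, (1,)).item(), :] += 1.0  # horizontal bar on a random row
        else:
            X[i, 0, :, torch.randint(0, W, (1,)).item()] += 1.0  # vertical bar on a random column
    return X, y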
The bar is on a random row/column each time, so the model can't memorize a position — it must learn the concept of horizontal vs. vertical. That's the whole point of a CNN's translation invariance (Chapter 13): detect a pattern wherever it appears.
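Strictly speaking, a convolution is translation-equivariant: shift the input and its feature map shifts by the same amount. Pooling then turns that into tolerance to position. A minimal sketch of both properties, using a fixed hand-made horizontal-line kernel as a stand-in for a learned filter (an illustration only; the project's filters are learned) and a global max in place of the project's 2×2 pooling:

```python
import torch
import torch.nn.functional as F

H = W = 12
# Hand-made horizontal-line detector: an illustrative stand-in for a learned 3×3 filter.
kernel = torch.tensor([[-1.0, -1.0, -1.0],
                       [ 2.0,  2.0,  2.0],
                       [-1.0, -1.0, -1.0]]).view(1, 1, 3, 3)

def bar_image(row):
    """A noise-free 12×12 image with a horizontal bar on the given row."""
    x = torch.zeros(1, 1, H, W)
    x[0, 0, row, :] = 1.0
    return x

a = F.conv2d(bar_image(3), kernel, padding=1)  # feature map for a bar on row 3
b = F.conv2d(bar_image(7), kernel, padding=1)  # the same bar, 4 rows lower

# Equivariance: the feature map moves down 4 rows along with the bar.
print(torch.equal(torch.roll(a, shifts=4, dims=2), b))  # True
# Invariance after a global max: the strongest response is identical wherever the bar sits.
print(a.max().item(), b.max().item())  # 6.0 6.0
```

torch.roll is safe here only because the rows it wraps around are all zero; right at the image border, zero padding makes the equivariance approximate rather than exact.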
The architecture
A CNN stacks convolution → activation → pooling to build up from pixels to concepts (Chapter 13):
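import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module before assigning layers
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)   # 8 learned 3×3 filters
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)  # 16 filters over the 8 feature maps
        self.pool = nn.MaxPool2d(2)                              # halve height and width: 12→6→3
        self.fc = nn.Linear(16 * 3 * 3, 2)                       # classify: 144 features → 2 classes
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # detect features, then shrink: (n, 1, 12, 12) → (n, 8, 6, 6)
        x = self.pool(F.relu(self.conv2(x)))  # (n, 8, 6, 6) → (n, 16, 3, 3)
        return self.fc(x.flatten(1))          # (n, 144) → class scores (n, 2)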
- Conv2d: the learnable kernels/filters from Chapter 5. The first layer's 8 filters learn to respond to simple patterns (horizontal vs. vertical edges). Training discovers them; you don't hand-design them.
- ReLU: the non-linearity (Chapter 11).
- MaxPool: keeps the strongest response in each 2×2 region, halving the height and width and adding tolerance to small shifts in position.
- Linear head: flattens the final 16×3×3 feature map into 144 numbers and maps them to 2 class scores (Chapter 6).
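To check the shapes in the code comments and the params=1538 figure in the training log below, here is a minimal sketch that rebuilds the same layers as an nn.Sequential (an equivalent form, used only so the block runs on its own) and traces one dummy image through them:

```python
import torch
import torch.nn as nn

# Same layers as CNN, in nn.Sequential form so this block is self-contained.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),                      # same as x.flatten(1)
    nn.Linear(16 * 3 * 3, 2),
)

x = torch.zeros(1, 1, 12, 12)          # one dummy 12×12 grayscale image
for layer in model:
    x = layer(x)
    n_params = sum(p.numel() for p in layer.parameters())
    print(f"{type(layer).__name__:<10} out={tuple(x.shape)} params={n_params}")

total = sum(p.numel() for p in model.parameters())
print("total params:", total)          # total params: 1538
```

The per-layer counts are 8×(1×3×3)+8 = 80, 16×(8×3×3)+16 = 1,168 and 2×144+2 = 290, which sum to 1,538. The (1, 144) output of Flatten is why the head is nn.Linear(16 * 3 * 3, 2).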
Training it
Training is a standard mini-batch loop (Chapter 16) that switches between model.train() and model.eval() (Chapter 12):
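The loop itself isn't reproduced in this section. A minimal sketch of what such a loop can look like, assuming Adam with lr=3e-3, batches of 64, 2,000 training and 400 test images, and a fixed seed (all assumptions, not necessarily what cnn.py uses, so your numbers will differ from the log below). The model is the nn.Sequential equivalent of CNN, and make_images is rewritten in a vectorized form:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed for a repeatable run

def make_images(n, size=12):
    """Vectorized variant of make_images: noise plus one bar per image."""
    X = 0.1 * torch.randn(n, 1, size, size)                 # background noise
    y = torch.randint(0, 2, (n,))                           # 0 = horizontal, 1 = vertical
    pos = torch.randint(0, size, (n, 1, 1))                 # random row or column per image
    idx = torch.arange(size)
    row_bar = (idx.view(1, size, 1) == pos).float()         # (n, size, 1): 1 on the chosen row
    col_bar = (idx.view(1, 1, size) == pos).float()         # (n, 1, size): 1 on the chosen column
    X[:, 0] += torch.where(y.view(n, 1, 1) == 0, row_bar, col_bar)  # broadcasts to (n, size, size)
    return X, y

model = nn.Sequential(  # equivalent to the CNN class
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(), nn.Linear(16 * 3 * 3, 2),
)

X_train, y_train = make_images(2000)  # assumed dataset sizes
X_test, y_test = make_images(400)     # held-out images, never trained on
opt = torch.optim.Adam(model.parameters(), lr=3e-3)  # assumed optimizer and learning rate
batch_size, history = 64, []

for epoch in range(6):
    model.train()                                  # training mode (a no-op here: no dropout/batchnorm)
    perm = torch.randperm(len(X_train))            # reshuffle each epoch
    running = 0.0
    for start in range(0, len(X_train), batch_size):
        idx = perm[start:start + batch_size]
        loss = F.cross_entropy(model(X_train[idx]), y_train[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
        running += loss.item() * len(idx)          # sum of per-example losses
    model.eval()                                   # evaluation mode
    with torch.no_grad():                          # no gradients needed for testing
        acc = (model(X_test).argmax(dim=1) == y_test).float().mean().item()
    history.append((running / len(X_train), acc))
    print(f"epoch {epoch} loss {history[-1][0]:.3f} test-acc {acc:.3f}")
```

The mode switches don't change this particular model's behavior, but keeping them makes the loop safe to reuse once dropout or batch normalization is added.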
$ python cnn.py
Output:
params=1538
epoch 0 loss 0.671 test-acc 0.545
epoch 1 loss 0.616 test-acc 0.775
epoch 2 loss 0.520 test-acc 0.970
epoch 3 loss 0.381 test-acc 0.985
epoch 4 loss 0.233 test-acc 1.000
epoch 5 loss 0.127 test-acc 1.000
A 1,538-parameter CNN started at 54.5% (barely better than guessing) and reached 100% test accuracy by epoch 4 — having learned convolutional filters that detect horizontal vs. vertical structure, evaluated on held-out images it never saw (Chapter 9). That's a complete, honest computer-vision training run.
Scaling to real images
The architecture is identical for real vision; you just grow it and feed it real data:
- Real datasets: MNIST (digits), CIFAR-10 (objects), ImageNet (1000 classes). MNIST and CIFAR-10 each load in one line with torchvision.datasets.
- Deeper networks: ResNet, EfficientNet, U-Net. The same conv/pool/activation pattern, dozens of layers deep, with residual connections (Chapter 11).
- Data augmentation: random flips/crops/rotations to fight overfitting (Chapter 9).
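torchvision.transforms ships ready-made augmentations (for example RandomHorizontalFlip and RandomCrop). For this project's bar images, a pure-PyTorch sketch with a hypothetical augment helper makes the key constraint visible: every transform must preserve the label. Flips and wrap-around shifts keep a horizontal bar horizontal, but a 90° rotation (torch.rot90) would turn it into a vertical bar and silently mislabel the example.

```python
import torch

def augment(X: torch.Tensor, max_shift: int = 2) -> torch.Tensor:
    """Hypothetical label-preserving augmentation for a batch X of shape (n, 1, H, W)."""
    out = X.clone()
    for i in range(out.shape[0]):
        if torch.rand(()) < 0.5:
            out[i] = out[i].flip(-1)                  # mirror left-right
        if torch.rand(()) < 0.5:
            out[i] = out[i].flip(-2)                  # mirror top-bottom
        dy, dx = torch.randint(-max_shift, max_shift + 1, (2,)).tolist()
        out[i] = torch.roll(out[i], shifts=(dy, dx), dims=(1, 2))  # wrap-around shift keeps the whole bar
    return out

x = torch.zeros(4, 1, 12, 12)
x[:, 0, 5, :] = 1.0                                   # four horizontal-bar images, no noise
aug = augment(x)
print((aug[:, 0].sum(dim=2) == 12).sum(dim=1))       # tensor([1, 1, 1, 1]): still one full row each
```

The wrap-around shift is a deliberate choice for this dataset: a zero-padded random crop could push a bar on the border row out of the frame, leaving an image with a label but no bar.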
You usually don't train from scratch — transfer learning
For real vision tasks, you rarely start from random weights. You take a CNN pretrained on ImageNet (which already learned edges, textures, shapes) and fine-tune it on your task — transfer learning, the vision cousin of Chapter 32's fine-tuning:
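# pip install torchvision
import torch
from torch import nn
from torchvision.models import resnet18, ResNet18_Weights
NUM_CLASSES = 10  # example value: set this to the number of classes in your task
model = resnet18(weights=ResNet18_Weights.DEFAULT)  # pretrained on ImageNet
for p in model.parameters():
    p.requires_grad = False  # freeze the pretrained body
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)  # swap the head; new layers are trainable by default
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)  # train only the head: fast and data-efficient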
This is how you build a strong image classifier with a few hundred examples instead of a million: stand on the shoulders of a model that already learned to see.
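To see the freeze-and-swap pattern run without downloading ResNet weights, here is a sketch that treats the project's small conv stack as a stand-in "pretrained" body. Its weights are random, so this demonstrates only the mechanics (which parameters train), not the accuracy benefit of pretraining; the 3-class head is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in "pretrained" body: the project's conv stack (random weights in this sketch).
body = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
)
for p in body.parameters():
    p.requires_grad = False                       # freeze the body
head = nn.Linear(16 * 3 * 3, 3)                   # new head for a hypothetical 3-class task
model = nn.Sequential(body, head)

trainable = [p for p in model.parameters() if p.requires_grad]
print("trainable params:", sum(p.numel() for p in trainable))  # trainable params: 435

opt = torch.optim.SGD(trainable, lr=0.1)          # the optimizer only sees the head
before = [p.clone() for p in body.parameters()]
x, y = torch.randn(8, 1, 12, 12), torch.randint(0, 3, (8,))  # dummy batch
loss = F.cross_entropy(model(x), y)
opt.zero_grad()
loss.backward()
opt.step()
print("body unchanged:", all(torch.equal(b0, b1) for b0, b1 in zip(before, body.parameters())))  # body unchanged: True
```

Only 3×144+3 = 435 of the parameters train, which is why fine-tuning just the head needs so little data and compute.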
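Don't be confused: CNN vs. Vision Transformer (ViT). CNNs were the default for a decade; Vision Transformers now match or beat them at scale by splitting an image into patches and running a transformer over the sequence of patches. CNNs' built-in assumptions (small local filters shared across the whole image) make them more data-efficient, and they are still everywhere; ViTs tend to win when the pretraining data is huge. Know both exist.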
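The patch idea fits in a few lines: a Conv2d whose kernel size equals its stride projects each non-overlapping patch to a vector, and those vectors become the transformer's tokens. This is a minimal sketch assuming 4×4 patches on the project's 12×12 images and a 32-dimensional embedding; a real ViT also adds position embeddings and a class token, which are omitted here.

```python
import torch
import torch.nn as nn

patch, dim = 4, 32  # assumed patch size and embedding width
# Kernel size == stride: every non-overlapping 4×4 patch is projected by the same linear map to `dim` numbers.
to_tokens = nn.Conv2d(1, dim, kernel_size=patch, stride=patch)

x = torch.randn(2, 1, 12, 12)                      # a batch of two 12×12 grayscale images
tokens = to_tokens(x).flatten(2).transpose(1, 2)   # (2, dim, 3, 3) → (2, 9, dim): 9 patch tokens per image
print(tokens.shape)                                # torch.Size([2, 9, 32])

layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, dim_feedforward=64, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
out = encoder(tokens)                              # self-attention lets every patch attend to every other
print(out.shape)                                   # torch.Size([2, 9, 32])
```

Unlike a 3×3 convolution, self-attention relates distant patches from the first layer on, but it has to learn spatial structure that a CNN gets for free, which is one reason ViTs need more data.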
Make it production
- Serve it behind an API; preprocess images identically in training and serving (resize, normalize) to avoid training/serving skew (the tools book).
- Monitor for drift — new camera, new lighting → input distribution shifts (Chapter 12 / the tools book's monitoring chapter).
- Optimize with ONNX/quantization for fast inference on edge devices (the tools book's ONNX chapter).
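A minimal sketch of the first two points, assuming placeholder normalization constants and a hypothetical mean_shift helper: one preprocess function imported by both the training script and the API, plus a crude drift signal. Real monitoring would compare full input distributions with a proper statistical test; this only tracks the mean.

```python
import torch
import torch.nn.functional as F

# Placeholder normalization constants (hypothetical values): compute them once
# on your training set and ship them with the model.
TRAIN_MEAN, TRAIN_STD = 0.5, 0.25

def preprocess(img: torch.Tensor) -> torch.Tensor:
    """Raw uint8 image (H, W) → normalized float tensor (1, 1, 12, 12).
    Import this one function in both the training script and the serving API."""
    x = img.float().div(255.0)[None, None]                       # (1, 1, H, W), values in [0, 1]
    x = F.interpolate(x, size=(12, 12), mode="bilinear", align_corners=False)  # resize to the model's input size
    return (x - TRAIN_MEAN) / TRAIN_STD                          # normalize with training statistics

def mean_shift(batch: torch.Tensor) -> float:
    """Crude drift signal (hypothetical helper). If TRAIN_MEAN/TRAIN_STD are the real training
    statistics, preprocessed training data has mean near 0 and std near 1, so this is the
    batch's mean shift in training-std units."""
    return batch.mean().item()

frame = torch.full((28, 28), 255, dtype=torch.uint8)   # e.g. an overexposed frame from a new camera
print(preprocess(frame).shape)                         # torch.Size([1, 1, 12, 12])
print(round(mean_shift(preprocess(frame)), 3))         # 2.0
```

A shift of 2 training standard deviations is the kind of signal that should trigger an alert before accuracy quietly drops.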
The takeaway
A CNN learns convolutional filters that detect visual patterns wherever they appear; you trained one from 54.5% to 100% test accuracy in seconds, then saw how transfer learning (a pretrained ResNet + a new head) builds strong classifiers from little data. The same conv/pool/activation pattern, scaled up, powers much of computer vision. One project left, and it's the most modern: generating new data with diffusion. 👉