PyTorch reimplementation of a ViT-B/16-style classifier trained from scratch on ImageNet-style data, plus a figure script that reproduces several visualizations from the original Vision Transformer paper.
- Best reported ImageNet top-1 accuracy from vit.py: 75.14%
- Main training entrypoint: vit.py
- Main visualization entrypoint: vit_figures/vit_visualizations.py

Set up the environment and install dependencies:

```bash
python -m venv .venv
# Windows
.venv\Scripts\activate
# macOS / Linux
source .venv/bin/activate
pip install -r requirements.txt
```

Reproduce the figures with:

```bash
python vit_figures/vit_visualizations.py
```

This uses the images in sample_images/, downloads pretrained checkpoints when needed, and writes outputs into vit_figures/.
vit.py expects an ImageFolder-style dataset. By default it looks for:

```
data/imagenet/train
data/imagenet/val
```
You can override the paths with:
- IMAGENET_TRAIN_DIR
- IMAGENET_VAL_DIR
- VIT_CHECKPOINT_PATH
- VIT_TENSORBOARD_DIR
Run training with:
```bash
python vit.py
```

Outputs are written to outputs/checkpoints/ and outputs/runs/ by default.
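For reference, a minimal sketch of how these environment overrides can be resolved in Python. The defaults mirror the paths above; the checkpoint filename is illustrative, and the actual handling lives in vit.py:

```python
import os

# Resolve dataset and output paths from the environment, falling back to the
# repository defaults when a variable is unset. The checkpoint filename below
# is a hypothetical example, not necessarily what vit.py uses.
TRAIN_DIR = os.environ.get("IMAGENET_TRAIN_DIR", "data/imagenet/train")
VAL_DIR = os.environ.get("IMAGENET_VAL_DIR", "data/imagenet/val")
CHECKPOINT_PATH = os.environ.get("VIT_CHECKPOINT_PATH", "outputs/checkpoints/vit.pt")
TENSORBOARD_DIR = os.environ.get("VIT_TENSORBOARD_DIR", "outputs/runs")
```

The same variables can also be set inline for a single run, e.g. `IMAGENET_TRAIN_DIR=/data/train IMAGENET_VAL_DIR=/data/val python vit.py`.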
vit.py is not only a minimal architecture demo. It contains the full training recipe that produced the reported 75.14% top-1 result:
- ViT-B/16-style patch embedding, class token, learned positional embeddings, and transformer encoder blocks
- RandAugment, Mixup, CutMix, and Random Erasing for stronger data augmentation
- label smoothing and stochastic depth for regularization
- AdamW with warmup + cosine decay
- automatic mixed precision (AMP), gradient accumulation, and gradient clipping for stable, practical training
- checkpoint resume and early stopping for long runs
The repository keeps the whole model and training loop in one file on purpose so the architecture and optimization choices are easy to inspect.
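To make the first bullet concrete, here is a minimal, hedged sketch of a ViT-B/16-style input stem (16x16 patches, 768-dim embeddings, a class token, and learned positional embeddings). Class and attribute names are illustrative and not copied from vit.py:

```python
import torch
import torch.nn as nn

class PatchEmbeddingStem(nn.Module):
    """Illustrative ViT-B/16-style stem; not the exact module used in vit.py."""

    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # A strided conv splits the image into non-overlapping patches and
        # linearly projects each patch to embed_dim.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)                      # (B, embed_dim, 14, 14) for 224px input
        x = x.flatten(2).transpose(1, 2)      # (B, 196, embed_dim)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)        # prepend the class token
        return x + self.pos_embed             # add learned positional embeddings
```

The transformer encoder blocks then operate on the resulting (B, 197, 768) token sequence.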
vit_figures/vit_visualizations.py reproduces several classic ViT analyses, including embedding filters, positional similarity, mean attention distance, and attention rollout.
This figure makes the model behavior easier to read than raw predictions alone. The attention rollout highlights where the output token is focusing in the image, and in these examples the strongest responses stay on the main object while the background is suppressed.
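For readers who want the mechanics, the snippet below is a hedged sketch of attention rollout, assuming the per-layer attention maps (shape (batch, heads, tokens, tokens), class token at index 0) have already been collected; the implementation in vit_figures/vit_visualizations.py may differ in detail:

```python
import torch

def attention_rollout(attentions):
    """Sketch of attention rollout over a list of per-layer attention maps."""
    rollout = None
    for attn in attentions:
        # Average over heads, add the identity to account for residual
        # connections, then renormalize each row.
        attn = attn.mean(dim=1)
        eye = torch.eye(attn.size(-1), device=attn.device)
        attn = attn + eye
        attn = attn / attn.sum(dim=-1, keepdim=True)
        # Multiply attention maps layer by layer.
        rollout = attn if rollout is None else attn @ rollout
    # Row 0 is the class token; columns 1: are the image patches.
    return rollout[:, 0, 1:]
```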
This view shows what the patch embedding layer has learned at the input stage. It gives a quick intuition for the low-level textures and directional patterns the model uses before self-attention mixes global information.
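One common way to produce this kind of view, sketched below under the assumption that the patch projection is a Conv2d with an (embed_dim, 3, 16, 16) weight, is to plot the leading principal components of the flattened filters. The function name and component count are illustrative:

```python
import torch

def embedding_filter_components(proj_weight, n_components=28):
    """Hedged sketch: principal components of the patch-embedding filters."""
    # proj_weight: (embed_dim, 3, 16, 16) weight of the patch projection conv.
    filters = proj_weight.flatten(1)                      # (embed_dim, 3*16*16)
    filters = filters - filters.mean(dim=0, keepdim=True)
    # PCA via SVD: rows of vh are principal directions in patch-pixel space.
    _, _, vh = torch.linalg.svd(filters, full_matrices=False)
    return vh[:n_components].reshape(n_components, *proj_weight.shape[1:])
```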
This heatmap visualizes similarity between learned positional embeddings. Nearby regions tend to stay more related than distant ones, which makes the spatial inductive pattern in the learned representation easier to see.
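As a sketch of how such a heatmap can be computed, assuming a (1, 1 + N, dim) positional embedding with the class token at index 0 and a 14x14 patch grid (illustrative, not the exact code in the figure script):

```python
import torch
import torch.nn.functional as F

def position_similarity(pos_embed, grid_size=14):
    """Cosine similarity between learned positional embeddings (sketch)."""
    patches = pos_embed[0, 1:]                 # drop the class token: (196, dim)
    patches = F.normalize(patches, dim=-1)
    sim = patches @ patches.T                  # (196, 196) cosine similarities
    # Reshape so each patch has its own grid of similarities to every other patch.
    return sim.reshape(grid_size, grid_size, grid_size, grid_size)
```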
This plot summarizes how far each attention head tends to look across layers. It is a compact way to see that different heads operate at different spatial ranges instead of all behaving the same way.
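A hedged sketch of the per-head statistic for a single layer, assuming a (heads, tokens, tokens) attention map with the class token at index 0 and 16-pixel patches on a 14x14 grid:

```python
import torch

def mean_attention_distance(attn, grid_size=14, patch_size=16):
    """Sketch: attention-weighted mean distance per head for one layer."""
    attn = attn[:, 1:, 1:]                            # drop the class token
    attn = attn / attn.sum(dim=-1, keepdim=True)      # renormalize over patches
    ys, xs = torch.meshgrid(
        torch.arange(grid_size), torch.arange(grid_size), indexing="ij"
    )
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float() * patch_size
    # Pairwise Euclidean distance between patch locations, in pixels.
    dist = torch.cdist(coords, coords)                # (196, 196)
    # Attention-weighted distance, averaged over query positions: (heads,)
    return (attn * dist).sum(dim=-1).mean(dim=-1)
```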
License: MIT



