Simple, easy-to-understand PyTorch implementation of the Vision Transformer (ViT) from scratch for small datasets like MNIST, FashionMNIST, SVHN and CIFAR10, with detailed steps.

s-chh/PyTorch-Scratch-Vision-Transformer-ViT

Vision Transformer from Scratch in PyTorch

Simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (see model.py).

This repo uses a smaller ViT, with a smaller patch size, suited to small-scale datasets such as MNIST and CIFAR10.

Key Points:

  • The ViT used here is a scaled-down version of the original ViT architecture from "An Image is Worth 16x16 Words".
  • Has only about 200k-800k parameters, depending on the embedding dimension (the original ViT-Base has 86 million).
  • Works with small datasets by using a smaller patch size of 4 (the original uses 16), so a 28 × 28 image still yields 7 × 7 = 49 patches.
  • Supported datasets: MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100.
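Why patch size 4? The sequence length of a ViT is the number of non-overlapping patches, and the image side must divide evenly by the patch size. The helper below is hypothetical (it is not part of the repo) and just makes that arithmetic concrete for the input sizes used here.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches, i.e. tokens before any [CLS] token.

    Hypothetical helper for illustration; not taken from the repo.
    """
    if image_size % patch_size != 0:
        raise ValueError(
            f"image size {image_size} is not divisible by patch size {patch_size}"
        )
    side = image_size // patch_size  # patches per row (and per column)
    return side * side


print(num_patches(28, 4))   # MNIST / FashionMNIST -> 49
print(num_patches(32, 4))   # SVHN / CIFAR -> 64
print(num_patches(32, 16))  # original ViT patch size on a 32 x 32 image -> 4
```

With the original 16 × 16 patches, a 32 × 32 image collapses to only 4 tokens, and a 28 × 28 image does not split evenly at all, so a patch size of 4 keeps a usable sequence length on these small images.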



Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | python main.py --dataset mnist --epochs 100 | 99.5 |
| Fashion MNIST | python main.py --dataset fmnist | 92.3 |
| SVHN | python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128 | 96.2 |
| CIFAR10 | python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128 | 86.3 (82.5 without RandAug) |
| CIFAR100 | python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128 | 59.6 (55.8 without RandAug) |



Transformer Config:

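| Config | MNIST and FMNIST | SVHN and CIFAR |
|---|---|---|
| Input Size (C × H × W) | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length (patches) | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num of Layers | 6 | 6 |
| Num of Heads | 4 | 4 |
| Forward Multiplier (MLP hidden size / embedding size) | 2 | 2 |
| Dropout | 0.1 | 0.1 |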
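The configuration above maps onto a compact model. The following is a minimal PyTorch sketch, not the repo's model.py: it assumes a standard pre-norm encoder layer (LayerNorm before attention and before the MLP), GELU in the MLP, a learned position embedding, a [CLS] token, and a single linear classification head. The class names PatchEmbedding, EncoderBlock, and TinyViT are hypothetical, and the defaults follow the MNIST/FMNIST column.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Split the image into patches, project each to embed_dim, add [CLS] + positions."""

    def __init__(self, in_channels=1, image_size=28, patch_size=4, embed_dim=64):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2  # 49 for 28 x 28 with patch size 4
        # kernel_size == stride == patch_size: a linear projection applied per patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        # Assumed learned position embedding: one vector per patch plus one for [CLS].
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim) * 0.02)

    def forward(self, x):                          # x: (B, C, H, W)
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([cls, x], dim=1) + self.pos_embed  # (B, N + 1, D)


class EncoderBlock(nn.Module):
    """Pre-norm encoder layer: x + MHSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, embed_dim=64, n_heads=4, forward_mul=2, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.norm1 = nn.LayerNorm(embed_dim)
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # Q, K, V in one projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(                        # hidden size = forward_mul * D
            nn.Linear(embed_dim, forward_mul * embed_dim),
            nn.GELU(),
            nn.Linear(forward_mul * embed_dim, embed_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def attention(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)       # each: (B, heads, N, head_dim)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        weights = self.dropout(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(B, N, D)  # merge heads
        return self.out_proj(out)

    def forward(self, x):
        x = x + self.dropout(self.attention(self.norm1(x)))
        return x + self.dropout(self.mlp(self.norm2(x)))


class TinyViT(nn.Module):
    def __init__(self, in_channels=1, image_size=28, patch_size=4, embed_dim=64,
                 n_layers=6, n_heads=4, forward_mul=2, dropout=0.1, n_classes=10):
        super().__init__()
        self.embed = PatchEmbedding(in_channels, image_size, patch_size, embed_dim)
        self.blocks = nn.Sequential(
            *[EncoderBlock(embed_dim, n_heads, forward_mul, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)  # assumed single linear head

    def forward(self, x):
        x = self.blocks(self.embed(x))
        return self.head(self.norm(x[:, 0]))       # classify from the [CLS] token


model = TinyViT()  # MNIST / FMNIST column
print(model(torch.randn(8, 1, 28, 28)).shape)       # torch.Size([8, 10])
print(sum(p.numel() for p in model.parameters()))   # 205962
```

Under these assumptions the sketch has 205,962 parameters with the MNIST defaults and 811,146 with `in_channels=3, image_size=32, embed_dim=128`, in the same range as the 210k and 820k in the table; the small gap presumably comes from implementation details such as the classification head.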

To train a Vision Transformer with a different type of position embedding, check out Positional Embeddings for Vision Transformers.
