ViT model

Putting together patch embeddings and transformer encoder
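import numpy as np
import torch
import yaml
from pathlib import Path
from torchvision import datasets, transforms as T

CONFIG_PATH = '../config.yml'
DATA_PATH = Path('../input')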

Load parameters from the config file.

with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
dset = datasets.CIFAR10(DATA_PATH, download=True)
images, targets = dset.data, dset.targets
len(images), len(targets)
(50000, 50000)

Prepare a small batch of images to test the image processing.

images.shape
(50000, 32, 32, 3)
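CIFAR-10 stores the training images as a single `uint8` array in channels-last order `(N, H, W, C)`, while PyTorch resize and convolution ops expect float tensors in channels-first order `(N, C, H, W)`. A NumPy-only sketch of that conversion on a random stand-in batch (the random data is purely illustrative):

```python
import numpy as np

# Random stand-in for a slice of dset.data: uint8 pixels, channels-last (N, H, W, C).
batch = np.random.default_rng(0).integers(0, 256, size=(3, 32, 32, 3), dtype=np.uint8)

x = batch.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
x = x.transpose(0, 3, 1, 2)           # channels-last -> channels-first (N, C, H, W)

print(x.shape, x.dtype)  # (3, 3, 32, 32) float32
```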

Randomly sample a few indices and select the corresponding images (and labels) as a small test batch.

image_idx = np.random.randint(low=0, high=len(images), size=3)
# corresponding labels
targets = [targets[t] for t in image_idx]
[3, 6, 2]
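`np.random.randint` draws with replacement, so the same image can occasionally appear twice in the batch. If distinct images are preferred, NumPy's `Generator.choice` with `replace=False` is one option. A minimal sketch, with an arbitrary seed for reproducibility and a hypothetical variable name `distinct_idx` so it does not overwrite `image_idx`:

```python
import numpy as np

n_images, batch_size = 50000, 3  # CIFAR-10 training set size, test batch size

rng = np.random.default_rng(seed=42)  # seeded generator for reproducible sampling
distinct_idx = rng.choice(n_images, size=batch_size, replace=False)  # no repeated indices

print(distinct_idx)
```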
n_classes = config["model"]["n_classes"]
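`config.yml` itself is not shown. Judging only from the keys this notebook reads (`model.n_classes` here and `data.hw` for resizing), `yaml.safe_load` returns nested dicts along the lines of the sketch below. It is hypothetical: the values 10 and 224 are inferred from the printed output shapes, and `example_config` and `get_key` are illustrative names, not part of the project.

```python
# Hypothetical shape of the parsed config.yml; only the keys read in this notebook are shown.
example_config = {
    "model": {"n_classes": 10},  # inferred from the (3, 10) model output
    "data": {"hw": 224},         # inferred from the resized (3, 3, 224, 224) batch
}


def get_key(cfg: dict, dotted: str):
    """Fetch a nested value such as 'model.n_classes', raising a readable KeyError."""
    node = cfg
    for part in dotted.split("."):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"missing config key {dotted!r} (failed at {part!r})")
        node = node[part]
    return node


print(get_key(example_config, "model.n_classes"), get_key(example_config, "data.hw"))  # 10 224
```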

Putting together PatchEmbedding and TransformerEncoder

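images = torch.from_numpy(images[image_idx]).float()  # (3, 32, 32, 3), channels-last
images = images / 255.                                 # scale pixel values to [0, 1]
hw = config['data']['hw']
augs = T.Resize(hw)                                    # upsample 32x32 -> hw x hw

images = augs(images.permute(0, 3, 1, 2))              # NHWC -> NCHW, then resize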
torch.Size([3, 3, 224, 224])
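The patch size is not shown in this notebook. Assuming 16×16 patches (consistent with the `torch.Size([3, 196, 768])` printed after the forward pass), a 224×224 image yields (224/16)² = 14² = 196 patches, and each flattened RGB patch holds 16·16·3 = 768 values. A quick check of that arithmetic, where `patch_size = 16` is the assumption:

```python
# Assumed patch size; the actual value lives in config.yml and is not shown here.
img_size, patch_size, channels = 224, 16, 3

n_patches = (img_size // patch_size) ** 2       # 14 * 14 patches per image
patch_dim = patch_size * patch_size * channels  # length of one flattened RGB patch

print(n_patches, patch_dim)  # 196 768
```

In ViT-Base/16 the hidden size is also 768, so the patch projection maps 768 values to 768 features; with other patch sizes or model widths the two numbers differ.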

VisionTransformer (config)

vit = VisionTransformer(config)
outs = vit(images)
torch.Size([3, 10])
torch.Size([3, 196, 768])
torch.Size([3, 768])
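The internals of `VisionTransformer` are not shown here, so the printed shapes can only be read by analogy with the original ViT design: split the image into patches, project each patch to the hidden size, prepend a [CLS] token and add position embeddings, run the encoder, and classify from the [CLS] output. The NumPy sketch below tracks the shapes under those assumptions (patch size 16, hidden size 768, 10 classes, random weights, the encoder replaced by an identity, LayerNorm omitted). Whether the `[3, 768]` tensor above is the [CLS] token or a mean over tokens is not visible from this output.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, H, W = 3, 3, 224, 224
P, D, n_cls = 16, 768, 10  # assumed patch size, hidden size, number of classes

demo_images = rng.random((B, C, H, W), dtype=np.float32)

# 1. Patchify: (B, C, H, W) -> (B, N, P*P*C), with N = (H/P) * (W/P) = 196
x = demo_images.reshape(B, C, H // P, P, W // P, P)
x = x.transpose(0, 2, 4, 3, 5, 1)  # (B, H/P, W/P, P, P, C)
patches = x.reshape(B, (H // P) * (W // P), P * P * C)

# 2. Linear patch embedding: (B, N, P*P*C) -> (B, N, D)
W_embed = rng.standard_normal((P * P * C, D), dtype=np.float32) * 0.02
tokens = patches @ W_embed

# 3. Prepend a [CLS] token (zeros stand in for a learned vector), add position embeddings
cls_token = np.zeros((B, 1, D), dtype=np.float32)
pos_embed = rng.standard_normal((1, tokens.shape[1] + 1, D), dtype=np.float32) * 0.02
seq = np.concatenate([cls_token, tokens], axis=1) + pos_embed  # (B, N + 1, D)

# 4. The transformer encoder would run here; identity, for shape tracking only
encoded = seq

# 5. Classify from the [CLS] position: (B, D) -> (B, n_cls)
W_head = rng.standard_normal((D, n_cls), dtype=np.float32) * 0.02
logits = encoded[:, 0] @ W_head

print(patches.shape, tokens.shape, encoded[:, 0].shape, logits.shape)
# (3, 196, 768) (3, 196, 768) (3, 768) (3, 10)
```

Step 2 is the same linear map that ViT implementations often express as a `Conv2d` with `kernel_size` and `stride` both equal to the patch size.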