CONFIG_PATH = '../config.yml'
DATA_PATH = Path('../input')
Prepare patches
Load parameters from the config file.
config = yaml.safe_load(open(CONFIG_PATH))
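The contents of ../config.yml aren't shown here; judging from the keys read below and the values that appear in the outputs, it presumably looks something like this:

patch:
  in_ch: 3
  out_ch: 768
  size: 16
data:
  hw: [224, 224]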
dset = datasets.CIFAR10(DATA_PATH, download=True, train=True)
Files already downloaded and verified
images, targets = dset.data, dset.targets
len(images), len(targets)
(50000, 50000)
Prepare a small batch of images to test the image processing.
images.shape
(50000, 32, 32, 3)
Sample a few random indices and use them to select images from the dataset.
image_idx = np.random.randint(low=0, high=len(images), size=3)
# corresponding labels
targets = [targets[t] for t in image_idx]
targets
[3, 2, 4]
What are the classes we are dealing with?
# dict2obj lets you dotify dicts or nested dicts
clsidx = dict2obj(dset.class_to_idx)
clsidx.bird
2
# filter the dict based on a function that checks k and v
filter_dict(clsidx, lambda k,v: v in targets)
{'bird': 2, 'cat': 3, 'deer': 4}
# filters based on just the values
filter_values(clsidx, lambda v: v in targets)
{'bird': 2, 'cat': 3, 'deer': 4}
targets[0]
3
clsidx
{ 'airplane': 0,
'automobile': 1,
'bird': 2,
'cat': 3,
'deer': 4,
'dog': 5,
'frog': 6,
'horse': 7,
'ship': 8,
'truck': 9}
= config["patch"]["in_ch"]
in_ch = config["patch"]["out_ch"] out_ch
# size of each small patch
patch_size = config['patch']['size']
patch_size
16
images.shape[1:]
(32, 32, 3)
images = torch.Tensor(images[image_idx])
images = images/255.
images.shape
torch.Size([3, 32, 32, 3])
Resize the images to \(224\times 224\) to match the ViT paper.
hw = config['data']['hw']
augs = T.Resize(hw)
augs
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
images = augs(images.permute(0, 3, 1, 2))
images.shape
torch.Size([3, 3, 224, 224])
n_channels, height, width = images.shape[1:]
print(f"image height: {height}, width: {width}, channels: {n_channels}")
assert height==width
image height: 224, width: 224, channels: 3
The number of patches is what the original Transformer paper calls the sequence length.
n_patches = (height*width)/(patch_size**2)
print(f"number of {patch_size}x{patch_size} patches in an image of shape {images.shape[1:]}: {n_patches}")
number of 16x16 patches in an image of shape torch.Size([3, 224, 224]): 196.0
shape_sequence = (n_patches, (patch_size**2)*in_ch)
print(f"shape of flattened 2D sequence: {shape_sequence}")
shape of flattened 2D sequence: (196.0, 768)
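A quick sanity check on the arithmetic: each \(16\times 16\) patch of a 3-channel image flattens to \(16\cdot 16\cdot 3 = 768\) values, and a \(224\times 224\) image holds \((224\cdot 224)/16^2 = 196\) such patches.

assert (height*width)/(patch_size**2) == 196
assert (patch_size**2)*in_ch == 768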
Prepare flattened 2D sequence
Display a sample image with title.
idx = 2
plt.figure(figsize=(2, 2))
plt.imshow(images[idx].permute(1, 2, 0))
plt.axis('off')
label = filter_values(clsidx, lambda v: v == targets[idx])
plt.title(label=label)
Text(0.5, 1.0, "{'deer': 4}")
Use a convolutional layer to split the image into patches and project each patch to the embedding dimension.
conv2d = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=patch_size, stride=patch_size)
PyTorch requires images in BCHW format.
images.shape
torch.Size([3, 3, 224, 224])
patched_image = conv2d(images)
patched_image.shape
torch.Size([3, 768, 14, 14])
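The \(14\times 14\) spatial grid follows from the standard convolution output-size formula with no padding: \((224-16)/16 + 1 = 14\) per side, and \(14\cdot 14 = 196\) patches, matching n_patches above.

grid = (height - patch_size)//patch_size + 1
assert grid == 14 and grid*grid == int(n_patches)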
# flatten to 2D
patched_image = patched_image.flatten(start_dim=2, end_dim=-1)
patched_image.shape
# permute to match the expected shape
patched_image = patched_image.permute(0, 2, 1)
patched_image.shape
torch.Size([3, 196, 768])
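The same (B, N, \(P^2\cdot C\)) layout can also be produced directly from the pixels with an einops rearrange (a sketch; here \(P^2\cdot C = 768\) happens to equal out_ch, but unlike the convolution this applies no learned projection):

from einops import rearrange
raw_patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
raw_patches.shape  # torch.Size([3, 196, 768])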
Prepare class embedding
The goal is to prepend a new item to the flattened sequence that encodes the class. To create the class embedding we also need to know the embedding dimension, since the class token must have the same dimension \(D\) as each patch embedding.
bs, seq_len, embed_dim = patched_image.shape
print(f"shape of embed dim (D): {embed_dim}")
print(f"shape of sequence length (N): {seq_len}")
shape of embed dim (D): 768
shape of sequence length (N): 196
# for all patches in the image prepend a learnable class token
# the token is shared across the whole batch, not one per batch item
class_token = nn.Parameter(torch.ones(1, 1, embed_dim), requires_grad=True)
class_token.shape
torch.Size([1, 1, 768])
class_token.shape, patched_image.shape
(torch.Size([1, 1, 768]), torch.Size([3, 196, 768]))
# class tokens are shared among batch items
class_tokens = repeat(class_token, '1 1 d -> bs 1 d', bs=bs)
class_tokens.shape
torch.Size([3, 1, 768])
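An equivalent way to broadcast the shared token is Tensor.expand, which returns a view instead of copying; that is enough here since the concat below copies anyway.

class_tokens = class_token.expand(bs, -1, -1)  # same shape, no memory copy
class_tokens.shape  # torch.Size([3, 1, 768])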
torch.concat([class_tokens, patched_image], dim=1)
tensor([[[ 1.0000e+00, 1.0000e+00, 1.0000e+00, ..., 1.0000e+00,
1.0000e+00, 1.0000e+00],
[-1.9330e-01, 5.3230e-01, 4.1072e-01, ..., -7.1345e-02,
-3.6855e-02, -5.7116e-03],
[-2.3731e-01, 6.1622e-01, 4.8126e-01, ..., -3.7821e-02,
-2.9133e-02, -4.7400e-02],
...,
[-8.0137e-02, 1.7678e-01, 1.2428e-01, ..., -3.5896e-02,
-3.9498e-02, 9.8821e-03],
[-5.6294e-02, 1.7154e-01, 9.5655e-02, ..., -4.0395e-02,
-2.8825e-02, -8.3997e-04],
[-5.4865e-02, 1.5447e-01, 8.8190e-02, ..., -2.9878e-02,
-3.1163e-02, -7.8464e-04]],
[[ 1.0000e+00, 1.0000e+00, 1.0000e+00, ..., 1.0000e+00,
1.0000e+00, 1.0000e+00],
[-2.1967e-02, 1.6781e-01, 1.1162e-01, ..., -1.2469e-01,
-8.1662e-02, 1.2122e-01],
[-1.9772e-02, 1.7202e-01, 1.1616e-01, ..., -1.3429e-01,
-8.4445e-02, 1.2937e-01],
...,
[-2.6714e-01, 4.9820e-01, 4.5177e-01, ..., 8.2063e-02,
5.1341e-02, -2.3969e-01],
[-9.8749e-02, 1.2397e-01, 9.3057e-02, ..., 3.5830e-02,
-2.1807e-03, -5.0337e-02],
[-9.7703e-02, 2.9660e-01, 1.7762e-01, ..., 4.5148e-02,
-2.6980e-03, -9.4421e-02]],
[[ 1.0000e+00, 1.0000e+00, 1.0000e+00, ..., 1.0000e+00,
1.0000e+00, 1.0000e+00],
[-1.3643e-02, 7.4828e-02, 1.8845e-02, ..., 3.0374e-03,
-2.4055e-02, -1.7881e-02],
[-6.4193e-02, 1.5682e-01, 1.2541e-01, ..., -4.4255e-03,
-1.8938e-02, -4.9337e-02],
...,
[-4.5829e-02, 1.6316e-01, 1.0169e-01, ..., -7.5816e-03,
-1.7019e-02, -3.2718e-02],
[-9.3686e-02, 2.6215e-01, 1.8407e-01, ..., 1.8344e-02,
-8.4813e-03, -5.2504e-02],
[-9.8182e-02, 2.9226e-01, 1.7385e-01, ..., 2.4724e-02,
1.0079e-02, -8.5595e-02]]], grad_fn=<CatBackward0>)
patched_image = torch.concat([class_tokens, patched_image], dim=1)
patched_image.shape
torch.Size([3, 197, 768])
Prepare positional embedding
The positional embedding is the same for all images in the batch. Add a positional embedding for each patch (and one for the class token).
# +1 for the extra class token added above
pos_token = nn.Parameter(torch.ones(1, int(seq_len)+1, embed_dim), requires_grad=True)
pos_token.shape
torch.Size([1, 197, 768])
Add the positional embedding to create the input \(z_0\).
patched_image += pos_token
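This reproduces equation 1 of the ViT paper, with the learned projection \(E\) played by the convolution above:

\[ z_0 = [x_{\mathrm{class}};\, x_p^1 E;\, x_p^2 E;\, \ldots;\, x_p^N E] + E_{pos}, \quad E \in \mathbb{R}^{(P^2 \cdot C) \times D},\; E_{pos} \in \mathbb{R}^{(N+1) \times D} \]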
PatchEmbedding Module
A PyTorch module that does all of the above.
PatchEmbedding
PatchEmbedding (config, channel_first=True)
PatchEmbedding(config, channel_first=True)(images).shape
torch.Size([3, 197, 768])
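The implementation of PatchEmbedding isn't reproduced in this rendering. A minimal sketch consistent with the steps above (the attribute names and the handling of channel_first are assumptions):

class PatchEmbedding(nn.Module):
    "Image -> sequence of patch embeddings with class token and positional embedding."
    def __init__(self, config, channel_first=True):
        super().__init__()
        in_ch, out_ch = config['patch']['in_ch'], config['patch']['out_ch']
        patch_size = config['patch']['size']
        h, w = config['data']['hw']
        n_patches = (h*w)//patch_size**2
        self.channel_first = channel_first
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size, stride=patch_size)
        self.class_token = nn.Parameter(torch.ones(1, 1, out_ch))          # assumed init, as above
        self.pos_token = nn.Parameter(torch.ones(1, n_patches+1, out_ch))  # +1 for the class token

    def forward(self, x):
        if not self.channel_first: x = x.permute(0, 3, 1, 2)     # BHWC -> BCHW
        x = self.proj(x).flatten(start_dim=2).permute(0, 2, 1)   # (B, N, D)
        cls = self.class_token.expand(x.shape[0], -1, -1)        # shared across the batch
        x = torch.concat([cls, x], dim=1)                        # prepend class token
        return x + self.pos_token                                # add positional embedding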