A class for creating a custom Vision Transformer (ViT) model for visual recognition

class VisRecTrans

VisRecTrans(model_name, num_classes, pretrained=True)

Class for setting up a vision transformer for visual recognition. By default, it returns a pretrained custom ViT model for the given model_name and num_classes. If pretrained is set to False, the model's parameters are randomly initialized instead.

VisRecTrans.create_model

VisRecTrans.create_model()

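Method for creating the model. It returns the model, an nn.Sequential, built for the model_name and num_classes passed to the constructor.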

VisRecTrans.initialize

VisRecTrans.initialize(model)

Method for initializing the given model. The position embedding and the class token are initialized from a truncated normal distribution, and the head of the model is initialized using He initialization.
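The initialization scheme described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: TinyViTParts and initialize_sketch are hypothetical names, the module holds only the three parameter groups the description mentions, the standard deviation of 0.02 is an assumed value (a common choice for ViT embeddings), and the normal (rather than uniform) variant of He initialization is also an assumption.

```python
import torch
from torch import nn


class TinyViTParts(nn.Module):
    """Hypothetical stand-in holding only the parameters the description mentions."""

    def __init__(self, num_patches: int = 196, dim: int = 384, num_classes: int = 10):
        super().__init__()
        # One learnable position vector per patch, plus one for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Classification head mapping the class-token features to class scores
        self.head = nn.Linear(dim, num_classes)


def initialize_sketch(model: TinyViTParts, std: float = 0.02) -> None:
    # Truncated normal for the position embedding and class token
    # (std=0.02 is assumed; PyTorch's default cut-off bounds are a=-2.0, b=2.0)
    nn.init.trunc_normal_(model.pos_embed, std=std)
    nn.init.trunc_normal_(model.cls_token, std=std)
    # He (Kaiming) normal initialization for the head weights: std = sqrt(2 / fan_in)
    nn.init.kaiming_normal_(model.head.weight, nonlinearity="relu")


# Usage: build the stand-in and initialize it
torch.manual_seed(0)
model = TinyViTParts()
initialize_sketch(model)
```

With fan_in = 384, the head weights end up with a standard deviation close to sqrt(2 / 384) ≈ 0.072, while the embeddings stay tightly clustered around zero.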

VisRecTrans.get_callback

VisRecTrans.get_callback()

Method for getting the callback that trains the embedding block of the model. Using the callback returned by this method is highly recommended when training a ViT model.

Let's check that the class works:

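from torch import nn
# Build a small ViT for 10 classes with randomly initialized weights
vis_rec_ob = VisRecTrans('vit_small_patch16_224', 10, pretrained=False)
model_test = vis_rec_ob.create_model()
vis_rec_ob.initialize(model_test)
assert isinstance(model_test, nn.Sequential)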

As we can see, the model is a sequential list of layers, so it can be used with fastai's Learner class like any other model.
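Because the returned model is an ordinary nn.Sequential, a standard PyTorch training step (the loop that fastai's Learner runs internally) can drive it. The sketch below is illustrative only: it uses a small hypothetical stand-in Sequential and random tensors in place of a real VisRecTrans model and dataset, so it runs without downloading weights.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the model returned by create_model():
# an nn.Sequential whose last layer produces 10 class scores
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# Random images and labels standing in for a real dataset
xb = torch.randn(8, 3, 8, 8)
yb = torch.randint(0, 10, (8,))

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# One optimization step: forward pass, loss, backward pass, parameter update
model.train()
loss_before = loss_fn(model(xb), yb)
opt.zero_grad()
loss_before.backward()
opt.step()

# Re-evaluate the same batch after the update
with torch.no_grad():
    loss_after = loss_fn(model(xb), yb)

# Layers of a Sequential can be indexed, e.g. to reach the classification head
head = model[-1]
```

Indexing into the Sequential is also what makes it straightforward to split such a model into parameter groups, for example to treat the head differently from the body.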

The list of models supported by the VisRecTrans class:

VisRecTrans.models_list
['vit_large_patch16_224',
 'vit_large_patch16_224_in21k',
 'vit_huge_patch14_224_in21k',
 'vit_small_patch16_224',
 'vit_small_patch16_224_in21k']
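The names in this list appear to follow timm's ViT naming convention: model size, patch size, input resolution, and an optional in21k suffix for weights pretrained on ImageNet-21k. Assuming that convention, a small helper (hypothetical, not part of VisRecTrans) can decode a name and compute how many patch tokens the transformer sees, (resolution / patch)².

```python
import re
from typing import NamedTuple

# Names copied from VisRecTrans.models_list above
MODELS = [
    "vit_large_patch16_224",
    "vit_large_patch16_224_in21k",
    "vit_huge_patch14_224_in21k",
    "vit_small_patch16_224",
    "vit_small_patch16_224_in21k",
]

# Assumed naming pattern: vit_<size>_patch<P>_<resolution>[_in21k]
NAME_RE = re.compile(
    r"^vit_(?P<size>[a-z]+)_patch(?P<patch>\d+)_(?P<res>\d+)(?P<in21k>_in21k)?$"
)


class ViTSpec(NamedTuple):
    size: str
    patch: int
    resolution: int
    in21k: bool

    @property
    def num_patches(self) -> int:
        # The image is cut into a (resolution / patch) x (resolution / patch) grid
        return (self.resolution // self.patch) ** 2


def parse_vit_name(name: str) -> ViTSpec:
    """Decode a model name under the assumed naming pattern."""
    m = NAME_RE.match(name)
    if m is None:
        raise ValueError(f"unrecognised model name: {name!r}")
    return ViTSpec(m["size"], int(m["patch"]), int(m["res"]), m["in21k"] is not None)


for name in MODELS:
    spec = parse_vit_name(name)
    print(f"{name}: size={spec.size}, patches={spec.num_patches}, in21k={spec.in21k}")
```

For example, vit_small_patch16_224 splits a 224 × 224 image into 14 × 14 = 196 patches, while vit_huge_patch14_224_in21k uses 16 × 16 = 256 smaller patches, so its sequence is longer.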