In this tutorial, we create a Vision Transformer (ViT) and fine-tune it for visual recognition.

Imports

from fastai.vision.all import *
from transcv.visrectrans import VisRecTrans

Datasets & Dataloaders

path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224), bs = 8)
dls.show_batch(max_n=9, figsize=(6,7))
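
The label function relies on a naming convention of the Oxford-IIIT Pet dataset: image files of cat breeds start with an uppercase letter, while image files of dog breeds start with a lowercase letter. A minimal sketch of that rule follows; the function name `is_cat` and the file names are illustrative assumptions, not read from disk.

```python
# Same rule as the lambda passed to ImageDataLoaders.from_name_func:
# True means "cat", False means "dog".
def is_cat(filename: str) -> bool:
    return filename[0].isupper()

# Illustrative file names that follow the dataset's naming convention.
samples = ["Abyssinian_1.jpg", "beagle_32.jpg", "Maine_Coon_7.jpg", "pug_101.jpg"]
labels = {name: is_cat(name) for name in samples}

for name, label in labels.items():
    print(f"{name:20s} -> {'cat' if label else 'dog'}")
```

Because the labels are only True or False, the task is binary, which is consistent with the 8 x 2 output of the final Linear layer in the model summary. A named function like this one can also be pickled, which a lambda cannot, so it may be the safer choice if the Learner is exported later.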

ViT model

We create a ViT-Large model using the VisRecTrans class.

vis_rec_ob = VisRecTrans('vit_large_patch16_224', dls.c, True)
model = vis_rec_ob.create_model()
vis_rec_ob.initialize(model)
embed_callback = vis_rec_ob.get_callback()

We need a custom split function so that the pre-trained body can be frozen while the newly initialised head stays trainable.

# Four parameter groups; learn.freeze() freezes all but the last one.
def ViT_split(m): return L(m[0], m[1], m[2], m[3]).map(params)
learn = Learner(dls, model, splitter=ViT_split, loss_func=CrossEntropyLossFlat(), metrics=accuracy).to_fp16()
learn.freeze()
learn.summary()
Sequential (Input shape: 8)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     8 x 1024 x 14 x 14  
Conv2d                                    787456     False     
Identity                                                       
EmbedBlock                                                     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 3072      
Linear                                    3148800    False     
Dropout                                                        
Linear                                    1049600    False     
Dropout                                                        
Identity                                                       
LayerNorm                                 2048       True      
____________________________________________________________________________
                     8 x 197 x 4096      
Linear                                    4198400    False     
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 1024      
Linear                                    4195328    False     
Dropout                                                        
LayerNorm                                 2048       True      
Identity                                                       
____________________________________________________________________________
                     8 x 2               
Linear                                    2050       True      
____________________________________________________________________________

Total params: 303,100,930
Total trainable params: 102,402
Total non-trainable params: 302,998,528

Optimizer used: <function Adam at 0x7f36cbd18b00>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #3

Callbacks:
  - TrainEvalCallback
  - MixedPrecision
  - Recorder
  - ProgressCallback
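
The totals reported in the summary can be cross-checked from its per-layer rows. The sketch below copies the parameter counts from the table and assumes the standard ViT-Large depth of 24 transformer blocks, which matches the 24 repeated groups in the output. The trainable total is the 2-class head plus the 49 LayerNorm layers. The LayerNorm layers probably stay trainable inside the frozen body because of fastai's default `train_bn=True`; that explanation is an assumption here, not something the summary states.

```python
# Per-layer parameter counts copied from the learn.summary() output above.
patch_embed = 787_456    # Conv2d patch embedding
qkv         = 3_148_800  # Linear, 1024 -> 3072
attn_proj   = 1_049_600  # Linear, 1024 -> 1024
mlp_fc1     = 4_198_400  # Linear, 1024 -> 4096
mlp_fc2     = 4_195_328  # Linear, 4096 -> 1024
layer_norm  = 2_048      # LayerNorm weight + bias over 1024 features
head        = 2_050      # Linear, 1024 -> 2

n_blocks = 24                     # assumed ViT-Large depth
n_layer_norms = 2 * n_blocks + 1  # two per block plus the final norm

block_linear = qkv + attn_proj + mlp_fc1 + mlp_fc2
frozen    = patch_embed + n_blocks * block_linear
trainable = head + n_layer_norms * layer_norm
total     = frozen + trainable

print(f"trainable:     {trainable:,}")
print(f"non-trainable: {frozen:,}")
print(f"total:         {total:,}")
```

The three values agree with the "Total trainable params", "Total non-trainable params" and "Total params" lines of the summary, so fewer than 0.04% of the listed parameters are updated while the body is frozen.
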
learn.lr_find()
SuggestedLRs(valley=0.0003981071640737355)

Training & Evaluation

NB: It is recommended to always pass embed_callback when training a ViT model.

learn.fit_one_cycle(1, 2e-4, cbs = [embed_callback, GradientClip])
epoch train_loss valid_loss accuracy time
0 0.595548 0.593401 0.694858 34:27
learn.save('stage-1')
Path('/root/.fastai/data/oxford-iiit-pet/images/models/stage-1.pth')
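
The learning rate passed to `fit_one_cycle` is a peak value, not a constant. As a rough sketch, and assuming fastai's default settings (`div=25`, `div_final=1e5`, `pct_start=0.25`), the rate warms up along a cosine curve from `lr_max/div` to `lr_max` during the first quarter of training and then anneals along a second cosine curve to `lr_max/div_final`. The helper names below are hypothetical and the momentum schedule is left out.

```python
import math

def cos_anneal(start: float, end: float, pos: float) -> float:
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] (assumed defaults)."""
    if pos < pct_start:
        # Warm-up phase: lr_max/div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max/div_final
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

lr_max = 2e-4  # value passed to fit_one_cycle in the frozen stage
for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"progress {pos:.2f}: lr = {one_cycle_lr(pos, lr_max):.2e}")
```

Under these assumptions the frozen stage starts at 8e-6, peaks at 2e-4 a quarter of the way through the epoch, and ends close to 2e-9.
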
learn.unfreeze()
learn.lr_find()
SuggestedLRs(valley=3.630780702224001e-05)
learn.fit_one_cycle(1, 1e-5, cbs = [embed_callback, GradientClip])
epoch train_loss valid_loss accuracy time
0 0.561189 0.559480 0.700271 46:17
learn.save('stage-2')
Path('/root/.fastai/data/oxford-iiit-pet/images/models/stage-2.pth')