In this tutorial, we pre-train a ViT model with SimCLR and then fine-tune it for a downstream visual recognition task. For a detailed and thorough explanation, we recommend going through the tutorial notebooks of this library first, and then working through this notebook.

Open In Colab

Imports

from fastai.vision.all import *
from transcv.visrectrans import *
torch.backends.cudnn.benchmark = True
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.simclr import *

Datasets & Dataloaders for self-supervised pre-training

def get_dls(size, bs, workers=None):
    """Build the ImageWang `DataLoaders` used for self-supervised pre-training.

    size:    side length handed to `RandomResizedCrop`; values up to 160
             select the 160px variant of the dataset, larger values the
             full-size one.
    bs:      batch size.
    workers: forwarded as `num_workers` to the dataloaders.

    Every image file found under the dataset root is used, with a random
    10% held out for validation. Targets come from `parent_label`.
    """
    dataset_url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    root = untar_data(dataset_url)
    items = get_image_files(root)

    # min_scale=1. means the crop always covers the full scale range upper bound.
    x_tfms = [PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)]
    y_tfms = [parent_label, Categorize()]

    splits = RandomSplitter(valid_pct=0.1)(items)
    dsets = Datasets(items, tfms=[x_tfms, y_tfms], splits=splits)

    return dsets.dataloaders(bs=bs, num_workers=workers, after_batch=[IntToFloatTensor])
# bs: batch size; resize: crop size given to the dataloaders;
# size: image size handed to the SimCLR augmentation pipelines further down.
bs = 8
resize = 256
size = 224
dls = get_dls(size=resize, bs=bs)
100.00% [2900353024/2900347689 01:35<00:00]

ViT model

Get the ViT model using the VisRecTrans class:

# Build the ViT through transcv's VisRecTrans wrapper, then initialise its weights.
# NOTE(review): the positional arguments `0, False` are presumably the number of
# classes and a "pretrained" flag -- confirm against the transcv API.
vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
model = vis_rec_ob.create_model()
vis_rec_ob.initialize(model)
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:388: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:426: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")

We need a module that extracts the class token from the feature representation computed by the ViT:

class Encoder_transcv(Module):
    """Select the token at index 0 of dim 1 (the class token) from the ViT output."""

    def __init__(self):
        # No state of its own: the module only slices its input.
        pass

    def forward(self, x):
        # Keep every batch item, keep only position 0 along dim 1.
        return x[:, 0]
# Encoder = the first three stages of the ViT followed by class-token extraction.
encoder = nn.Sequential(*(model[i] for i in range(3)), Encoder_transcv())
def create_simclr_model(encoder, hidden_size=256, projection_size=128, bn=False, nlayers=2, img_size=224):
    """Create a SimCLR model: `encoder` followed by an MLP projection head.

    encoder:         module mapping an image batch to a 2-D (batch, features) tensor.
    hidden_size:     hidden width of the projection MLP.
    projection_size: output width of the projection MLP.
    bn, nlayers:     forwarded to `create_mlp_module`.
    img_size:        side length of the square dummy input used to probe the
                     encoder's output width. Defaults to 224, the value that
                     used to be hard-coded, so existing callers are unaffected.

    Returns `nn.Sequential(encoder, projector)`.
    """
    n_in = in_channels(encoder)
    # Dummy forward pass, without autograd, only to discover the feature width
    # that the projection head has to accept.
    with torch.no_grad():
        representation = encoder(torch.randn((2, n_in, img_size, img_size)))
    projector = create_mlp_module(representation.size(1), hidden_size, projection_size, bn=bn, nlayers=nlayers)
    apply_init(projector)
    return nn.Sequential(encoder, projector)
# Wrap the encoder with a projection head and attach the SimCLR callback,
# which is given the augmentation pipelines built for images of side `size`.
model = create_simclr_model(encoder)
aug_pipelines = get_simclr_aug_pipelines(size, rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False)
cbs = [SimCLR(aug_pipelines)]
learn = Learner(dls, model, cbs=cbs)

# Fire the `before_batch` event on a single batch by hand and display 5 samples
# through the SimCLR callback.
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.sim_clr.show(n=5);
learn.to_fp16();

# `epochs` is used only to label the checkpoint file name further down, so it has
# to match the number of epochs actually trained (1 in this notebook). It was 100,
# which labelled a 1-epoch checkpoint as `epc100`. Raise it together with the
# `fit_flat_cos` epoch count when running a longer pre-training.
# `lr` is not read by the training call, which passes its own learning rate.
lr, wd, epochs = 1e-2, 1e-2, 1
learn.lr_find()
SuggestedLRs(valley=0.0002290867705596611)

Since the SimCLR model has a different structure, we need a custom embed callback:

class TrainEmbedCallback(Callback):
    """Toggle the module at `learn.model[0][1]` between training and validation.

    Before each training phase the module's `training` flag is set and its
    parameters are made trainable; before each validation phase both are
    switched off again.

    The validation hook used to be called `before_validation`, which is not an
    event fastai dispatches (the event is `before_validate`), so it never ran.
    The old name is kept as an alias for any code that calls it directly.

    NOTE(review): with the hook now firing, the parameters of that module stay
    at `requires_grad=False` after the last validation pass until
    `before_train` runs again -- confirm nothing relies on them in between.
    """

    def _set_embed_mode(self, flag):
        # Single place that flips both the mode flag and gradient tracking.
        embed = self.learn.model[0][1]
        embed.training = flag
        embed.requires_grad_(flag)

    def before_train(self):
        self._set_embed_mode(True)

    def before_validate(self):
        self._set_embed_mode(False)

    # Backward-compatible alias for the original (never dispatched) hook name.
    before_validation = before_validate
# Print the layer-by-layer summary of the SimCLR model (output shapes, parameter counts).
learn.summary()
Sequential (Input shape: 8)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     8 x 384 x 14 x 14   
Conv2d                                    295296     True      
Identity                                                       
EmbedBlock                                                     
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
Identity                                                       
Encoder_transcv                                                
____________________________________________________________________________
                     8 x 256             
Linear                                    98560      True      
ReLU                                                           
____________________________________________________________________________
                     8 x 128             
Linear                                    32896      True      
____________________________________________________________________________

Total params: 21,721,088
Total trainable params: 21,721,088
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7f2235174830>
Loss function: <bound method SimCLR.lf of SimCLR>

Model unfrozen

Callbacks:
  - TrainEvalCallback
  - SimCLR
  - MixedPrecision
  - Recorder
  - ProgressCallback

Self-supervised pre-training

Here, we pre-train the ViT model for just 1 epoch, because even a single epoch takes a long time (about 46 minutes in the run below). It is recommended to pre-train the ViT model for 100 epochs (according to this tutorial).

# Make every parameter group trainable, then run a single pre-training epoch
# with TrainEmbedCallback passed as a callback for this fit.
learn.unfreeze()
learn.fit_flat_cos(n_epoch=1, lr=1e-4, wd=wd, pct_start=0.5, cbs=[TrainEmbedCallback])
epoch train_loss valid_loss time
0 1.516793 1.509312 46:02

Save the parameters of the encoder:

# Checkpoint of the learner's whole model, saved under `save_name`.
save_name = f'simclr_iwang_sz{size}_epc{epochs}'
learn.save(save_name)
# Encoder-only weights (`learn.model[0]`), written next to that checkpoint;
# `encoder_path` is reused when the fine-tuning learner is created.
encoder_path = learn.path/learn.model_dir/f'{save_name}_encoder.pth'
torch.save(learn.model[0].state_dict(), encoder_path)
# Notebook display of the batch size and image size carried over to fine-tuning.
bs, size
(8, 224)

Datasets & Dataloaders for fine-tuning

def get_dls(size, bs, workers=None):
    """Build the ImageWang `DataLoaders` used for supervised fine-tuning.

    size:    side length handed to `RandomResizedCrop`; values up to 160
             select the 160px variant of the dataset, larger values the
             full-size one.
    bs:      batch size.
    workers: forwarded as `num_workers` to the dataloaders.

    Only the `train` and `val` folders are read; files under `val` form the
    validation set. Batches are normalised with the ImageNet statistics.
    """
    dataset_url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    root = untar_data(dataset_url)
    items = get_image_files(root, folders=['train', 'val'])
    train_val_idxs = GrandparentSplitter(valid_name='val')(items)

    # Item-level augmentation: random crop (at least 35% of the image) and a 50% flip.
    x_tfms = [PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    y_tfms = [parent_label, Categorize()]

    dsets = Datasets(items, tfms=[x_tfms, y_tfms], splits=train_val_idxs)

    gpu_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    return dsets.dataloaders(bs=bs, num_workers=workers, after_batch=gpu_tfms)

Getting the learner for fine-tuning

def split_func(m):
    """Return two parameter groups: the parameters of `m[0]` and those of `m[1]`."""
    parts = L(m[0], m[1])
    return parts.map(params)

def create_learner(size=size, encoder_path="models/swav_iwang_sz128_epc100_encoder.pth"):
    """Build a fine-tuning `Learner`: pre-trained ViT encoder plus a fresh classifier head.

    size:         image size passed to `get_dls` (the batch size is read from
                  the module-level `bs`).
    encoder_path: path to an encoder `state_dict` saved with `torch.save`.
                  NOTE(review): the default names a SwAV checkpoint that this
                  notebook never writes; the call in this notebook passes the
                  path explicitly.

    Returns a `Learner` whose model is `nn.Sequential(encoder, classifier)`,
    split into those two parameter groups by `split_func`.
    """
    dls = get_dls(size, bs=bs)

    # Load onto the CPU: the encoder built below is probed with a CPU tensor, so it
    # lives on the CPU here, and this lets a checkpoint saved from a GPU run load on
    # a CPU-only machine.
    pretrained_encoder = torch.load(encoder_path, map_location='cpu')

    vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
    model = vis_rec_ob.create_model()
    vis_rec_ob.initialize(model)
    encoder = nn.Sequential(model[0], model[1], model[2], Encoder_transcv())
    encoder.load_state_dict(pretrained_encoder)

    # Probe the encoder's feature width; no autograd graph is needed for this.
    with torch.no_grad():
        nf = encoder(torch.randn(2, 3, 224, 224)).size(-1)

    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    learn = Learner(dls, model, splitter=split_func,
                    metrics=top_k_accuracy, loss_func=LabelSmoothingCrossEntropy())
    return learn
# Build the fine-tuning learner from the encoder weights saved after
# pre-training, make all parameter groups trainable, and run the LR finder.
learn = create_learner(size=size, encoder_path=encoder_path)
learn.unfreeze()
learn.lr_find()
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:388: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:426: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
SuggestedLRs(valley=0.0008317637839354575)

Fine-tuning

# Five one-cycle epochs at a peak learning rate of 1e-4, with TrainEmbedCallback
# passed as a callback for this fit.
learn.fit_one_cycle(n_epoch=5, lr_max=1e-4, cbs=TrainEmbedCallback)
epoch train_loss valid_loss top_k_accuracy time
0 2.593372 3.329553 0.185798 06:50
1 2.247928 3.347796 0.293713 06:49
2 1.952943 3.466967 0.238483 06:50
3 1.840052 3.220397 0.405447 06:49
4 1.861218 3.188895 0.435225 06:50

We can see that the ViT model reaches a low top_k_accuracy after fine-tuning. This is because the model was pre-trained for just 1 epoch.