In this tutorial, we pre-train a Vision Transformer (ViT) model with SwAV and then fine-tune it on a downstream visual recognition task. For a detailed and thorough explanation, it is recommended to first go through this library's tutorial notebooks, and then this notebook.

Open In Colab

Imports

from fastai.vision.all import *
from transcv.visrectrans import *
torch.backends.cudnn.benchmark = True
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *

Datasets & Dataloaders for self-supervised pre-training

def get_dls(size, bs, workers=None):
    """Build Imagewang DataLoaders for self-supervised pre-training.

    The 160px copy of the dataset is downloaded when ``size <= 160``, the
    full-size copy otherwise. Every image file under the extracted folder is
    used; a random 10% is held out for validation. Labels are taken from each
    file's parent folder name.

    Args:
        size: side length passed to the per-item ``RandomResizedCrop``.
        bs: batch size.
        workers: ``num_workers`` for the DataLoaders (``None`` keeps fastai's default).
    """
    url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    root = untar_data(url)
    items = get_image_files(root)

    # min_scale=1. keeps the sampled crop area at the full image area.
    x_pipe = [PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)]
    y_pipe = [parent_label, Categorize()]
    split_idxs = RandomSplitter(valid_pct=0.1)(items)

    dsets = Datasets(items, tfms=[x_pipe, y_pipe], splits=split_idxs)
    return dsets.dataloaders(bs=bs, num_workers=workers, after_batch=[IntToFloatTensor])
# Batch size, side length items are cropped to when loaded for pre-training,
# and the crop size used by the SwAV augmentation pipelines / fine-tuning.
bs = 8
resize = 256
size = 224

Set the queue size; it needs to be a multiple of the batch size.

# SwAV queue size: 16 times the batch size, so it stays a multiple of `bs`.
K = bs * 16

ViT model

Create the ViT model using the `VisRecTrans` class:

# Build the ViT through transcv's VisRecTrans wrapper. The model name follows
# timm's convention (ViT-Small, 16px patches, 224px input); the summary further
# below shows 384-wide tokens and a 14x14 patch grid, consistent with that.
# NOTE(review): the meaning of the positional args (0, False) is not visible here;
# presumably number of classes (0 -> no head, matching the "zero-element tensor"
# warnings) and a "pretrained" flag — confirm against transcv.
vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
model = vis_rec_ob.create_model()
vis_rec_ob.initialize(model)
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:388: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:426: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
class Encoder_transcv(Module):
    """Reduce a token sequence to its first token.

    fastai's ``Module`` calls ``nn.Module.__init__`` itself, so no constructor
    is needed here. Assumes the input is (batch, tokens, dim) — the summary
    shows 197 tokens of width 384 — and returns ``x[:, 0]``, presumably the
    ViT class token.
    """

    def forward(self, x):
        first_token = x[:, 0]
        return first_token
# SwAV backbone: the first three children of the VisRecTrans model followed by
# Encoder_transcv, which keeps only token 0 of the output sequence.
# NOTE(review): which ViT stages model[0], model[1], model[2] are is not visible here.
encoder = nn.Sequential(model[0], model[1], model[2], Encoder_transcv())
def create_swav_model(encoder, hidden_size=256, projection_size=128, n_protos=3000, bn=True, nlayers=2,
                      img_size=224):
    """Create a SwAV model around ``encoder``.

    The encoder's output width is found by passing a dummy batch of two
    ``n_in x img_size x img_size`` random images through it. An MLP projector
    maps that width to ``projection_size``, and a bias-free linear layer holds
    ``n_protos`` prototypes whose rows are L2-normalised at creation.

    Args:
        encoder: backbone; its output is assumed to be (batch, features).
        hidden_size, projection_size, bn, nlayers: forwarded to ``create_mlp_module``.
        n_protos: number of prototypes.
        img_size: side length of the dummy probe images. The default of 224 is
            what the ``vit_small_patch16_224`` encoder in this notebook expects.

    Returns:
        ``SwAVModel(encoder, projector, prototypes)``.
    """
    n_in = in_channels(encoder)
    # Probe without gradients and in eval mode, so the random batch cannot update
    # normalisation running statistics. Restore the caller's mode afterwards; note
    # that train(flag) applies one mode to every submodule.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            representation = encoder(torch.randn((2, n_in, img_size, img_size)))
    finally:
        encoder.train(was_training)

    projector = create_mlp_module(representation.size(1), hidden_size, projection_size, bn=bn, nlayers=nlayers)
    prototypes = nn.Linear(projection_size, n_protos, bias=False)
    apply_init(projector)
    with torch.no_grad():
        # F.normalize defaults to p=2 over dim=1, i.e. each prototype row gets unit norm.
        w = prototypes.weight.data.clone()
        prototypes.weight.copy_(F.normalize(w))
    return SwAVModel(encoder, projector, prototypes)

Initialize the Dataloaders using the function above.

# Pre-training DataLoaders: items are cropped to `resize` px when loaded; the
# SwAV augmentation pipelines created below produce the `size` px crops.
dls = get_dls(resize, bs)
100.00% [2900353024/2900347689 01:21<00:00]
# Wrap the encoder with the projector MLP and the prototype layer.
model = create_swav_model(encoder)
# Two augmentation pipelines: 2 crops from the first and 6 from the second, all at
# `size` px (presumably because the ViT is built for a fixed 224px input, so the
# "small" crops are not downscaled). min_scales/max_scales presumably bound the
# crop-area fraction per pipeline — confirm against get_swav_aug_pipelines.
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
                                       crop_sizes=[size,int(size)], 
                                       min_scales=[0.25,0.2],
                                       max_scales=[1.0,0.35],
                                       rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False) 
# SwAV callback. NOTE(review): crop_assgn_ids=[0,1] presumably selects the crops used
# for cluster assignment, queue_start_pct=0.5 presumably enables the length-K queue
# halfway through training, and temp=0.1 is presumably the softmax temperature —
# confirm against self_supervised.vision.swav.SWAV.
cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]
learn = Learner(dls, model, cbs=cbs)
# Preview augmented crops: take one batch, install it as the learner's current
# batch, then fire the `before_batch` event so the SWAV callback augments it.
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);
# Mixed-precision training (adds the MixedPrecision callback seen in the summary).
learn.to_fp16();
learn.summary()
SwAVModel (Input shape: 8)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     8 x 384 x 14 x 14   
Conv2d                                    295296     True      
Identity                                                       
EmbedBlock                                                     
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1152      
Linear                                    443520     True      
Dropout                                                        
Linear                                    147840     True      
Dropout                                                        
Identity                                                       
LayerNorm                                 768        True      
____________________________________________________________________________
                     8 x 197 x 1536      
Linear                                    591360     True      
GELU                                                           
____________________________________________________________________________
                     8 x 197 x 384       
Linear                                    590208     True      
Dropout                                                        
LayerNorm                                 768        True      
Identity                                                       
Encoder_transcv                                                
____________________________________________________________________________
                     8 x 256             
Linear                                    98560      True      
BatchNorm1d                               512        True      
ReLU                                                           
____________________________________________________________________________
                     8 x 128             
Linear                                    32896      True      
____________________________________________________________________________
                     8 x 3000            
Linear                                    384000     True      
____________________________________________________________________________

Total params: 22,105,600
Total trainable params: 22,105,600
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7ff4a686ab90>
Loss function: <bound method SWAV.lf of SWAV>

Callbacks:
  - TrainEvalCallback
  - SWAV
  - MixedPrecision
  - Recorder
  - ProgressCallback
# NOTE(review): `lr` and `wd` are not used by any later cell (both fit calls pass a
# hard-coded 1e-3 and no weight decay), and `epochs`=100 is the recommended
# schedule, not the single epoch actually run below.
lr,wd,epochs=1e-2,1e-2,100
# Learning-rate finder for the SwAV learner.
learn.lr_find()
SuggestedLRs(valley=0.001737800776027143)

Create a custom callback that switches the encoder between its training and validation settings:

class TrainEmbedCallback(Callback):
    """Switch the SwAV encoder between training and validation settings.

    Before each training phase the encoder is put in train mode and its
    parameters are made trainable. Before each validation phase it is put in
    eval mode and its parameters are frozen. Because validation follows training
    in every epoch, the encoder is left with ``requires_grad=False`` once fitting ends.

    ``train()``/``eval()`` are used instead of assigning ``.training`` directly.
    Assigning the attribute only flips the flag on the top-level module, whereas
    ``train()``/``eval()`` propagate the mode to every submodule (dropout etc.).
    """

    def before_train(self):
        encoder = self.learn.model.encoder
        encoder.train()
        encoder.requires_grad_(True)

    def before_validation(self):
        encoder = self.learn.model.encoder
        encoder.eval()
        encoder.requires_grad_(False)

Self-supervised pre-training

Here, we pre-train the ViT model for just 1 epoch, because even a single epoch takes a long time (about 2 hours and 44 minutes in the run below). It is recommended to pre-train the ViT model for 100 epochs (according to this tutorial).

# Pre-train for a single epoch with a one-cycle schedule (peak lr 1e-3).
# Callback classes are passed uncalled; fastai instantiates them. With its
# defaults, GradientAccumulation accumulates gradients over several small
# batches (bs=8) before each optimizer step.
learn.unfreeze()
learn.fit_one_cycle(1, 1e-3, pct_start=0.5, cbs = [TrainEmbedCallback, GradientAccumulation])
epoch train_loss valid_loss time
0 7.998878 7.999933 2:43:50

Save the encoder parameters:

# Name the checkpoint after the number of epochs actually run (`learn.n_epoch`,
# set by the last fit call) rather than the planned `epochs`. Otherwise a 1-epoch
# run would be saved as "epc100".
save_name = f'swav_iwang_sz{size}_epc{learn.n_epoch}'
learn.save(save_name)  # model and optimizer state of the whole SwAV learner
# Encoder weights alone; this file is what the fine-tuning learner loads.
encoder_path = learn.path/learn.model_dir/f'{save_name}_encoder.pth'
torch.save(learn.model.encoder.state_dict(), encoder_path)
learn.recorder.plot_loss()

Datasets & Dataloaders for fine-tuning

# Display the batch size and crop size carried over to fine-tuning.
bs, size
(8, 224)
def get_dls(size, bs, workers=None):
    """Build labelled Imagewang DataLoaders for supervised fine-tuning.

    Only the ``train`` and ``val`` folders are read, and files under ``val``
    form the validation set.
    - Items: a ``RandomResizedCrop`` (min_scale=0.35) and a ``FlipItem`` with p=0.5.
    - Batches: converted to float and normalised with ImageNet statistics.
    """
    url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    root = untar_data(url)
    items = get_image_files(root, folders=['train', 'val'])
    split_idxs = GrandparentSplitter(valid_name='val')(items)

    x_pipe = [PILImage.create, ToTensor,
              RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    y_pipe = [parent_label, Categorize()]
    dsets = Datasets(items, tfms=[x_pipe, y_pipe], splits=split_idxs)

    after_batch = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
    return dsets.dataloaders(bs=bs, num_workers=workers, after_batch=after_batch)

Create the learner for fine-tuning

def split_func(m):
    "Split `m` into two parameter groups: the encoder `m[0]` and the classifier `m[1]`."
    groups = L(m[0], m[1])
    return groups.map(params)

def create_learner(encoder_path, size=size, arch='xresnet34'):
    """Build a fine-tuning Learner from a saved SwAV encoder.

    Rebuilds the same ViT encoder used for pre-training, loads the weights
    stored at ``encoder_path``, and attaches a classification head sized from
    the data.

    Args:
        encoder_path: path of a ``state_dict`` saved from the SwAV encoder.
        size: image size for the DataLoaders.
        arch: unused; kept only so existing callers keep working.

    Returns:
        A ``Learner`` over ``nn.Sequential(encoder, classifier)`` with parameter
        groups from ``split_func``, ``top_k_accuracy`` as the metric and
        label-smoothing cross-entropy as the loss.
    """
    dls = get_dls(size, bs=bs)
    # Load onto the CPU. The tensors are copied into the freshly built model
    # anyway, and this also works when the checkpoint was written on a GPU machine.
    pretrained_encoder = torch.load(encoder_path, map_location='cpu')
    vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
    model = vis_rec_ob.create_model()
    vis_rec_ob.initialize(model)
    encoder = nn.Sequential(model[0], model[1], model[2], Encoder_transcv())
    encoder.load_state_dict(pretrained_encoder)
    # Probe the feature width at the encoder's native 224px input. no_grad avoids
    # building an autograd graph for this dummy pass.
    with torch.no_grad():
        nf = encoder(torch.randn(2, 3, 224, 224)).size(-1)
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    return Learner(dls, model, splitter=split_func,
                   metrics=top_k_accuracy, loss_func=LabelSmoothingCrossEntropy())
# Build the fine-tuning learner from the saved encoder, make every parameter
# group trainable, and run the learning-rate finder.
learn = create_learner(encoder_path, size)
learn.unfreeze()
learn.lr_find()
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:388: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
/usr/local/lib/python3.7/dist-packages/torch/nn/init.py:426: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
SuggestedLRs(valley=0.0005754399462603033)
class TrainEmbedFinetuneCallback(Callback):
    """Switch ``model[0][1]`` of the fine-tuning model between train and eval settings.

    ``model[0]`` is the encoder ``nn.Sequential`` built in ``create_learner``, so
    ``model[0][1]`` is the second child taken from the VisRecTrans model.
    NOTE(review): why only this stage is toggled is not visible here.

    - Before each training phase: the stage is put in train mode and made trainable.
    - Before each validation phase: it is put in eval mode and frozen.

    ``train()``/``eval()`` propagate the mode to every submodule, unlike assigning
    ``.training``, which only changes the top-level flag.
    """

    def before_train(self):
        stage = self.learn.model[0][1]
        stage.train()
        stage.requires_grad_(True)

    def before_validation(self):
        stage = self.learn.model[0][1]
        stage.eval()
        stage.requires_grad_(False)

Fine-tuning

# Fine-tune for 5 epochs with a one-cycle schedule (peak lr 1e-3). GradientClip
# (fastai defaults) clips the gradient norm before each optimizer step.
learn.fit_one_cycle(5, 1e-3, cbs = [TrainEmbedFinetuneCallback, GradientClip])
epoch train_loss valid_loss top_k_accuracy time
0 2.692342 5.109475 0.010690 07:34
1 2.605555 4.924060 0.001018 07:30
2 2.442838 4.045727 0.001018 07:30
3 2.251731 3.829627 0.022143 07:30
4 2.148555 3.695153 0.043523 07:30

We can see that the fine-tuned ViT model has a very low top-k accuracy (about 4.4% after 5 epochs). This is most likely because the model was pre-trained for just 1 epoch.