In this tutorial, we pre-train a ViT model using SwAV and then fine-tune it for a downstream visual recognition task. For a detailed and thorough explanation, it is recommended to go through the tutorial notebooks of this library first, and then this notebook.
from fastai.vision.all import *
from transcv.visrectrans import *
torch.backends.cudnn.benchmark = True
from self_supervised.augmentations import *
from self_supervised.layers import *
from self_supervised.vision.swav import *
def get_dls(size, bs, workers=None):
    """Build the ImageWang `DataLoaders` used for self-supervised pre-training.

    size: side length handed to `RandomResizedCrop`; values up to 160 select
        the `IMAGEWANG_160` archive, larger values select `IMAGEWANG`.
    bs: batch size.
    workers: forwarded as `num_workers` to `Datasets.dataloaders`.
    """
    url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    image_files = get_image_files(untar_data(url))
    # Input pipeline: open the image, convert to a tensor, crop to `size`.
    x_pipeline = [PILImage.create, ToTensor, RandomResizedCrop(size, min_scale=1.)]
    # Target pipeline: label from the parent folder name, then categorize.
    y_pipeline = [parent_label, Categorize()]
    # Hold out a random 10% of the files for validation.
    train_valid_idxs = RandomSplitter(valid_pct=0.1)(image_files)
    datasets = Datasets(image_files, tfms=[x_pipeline, y_pipeline], splits=train_valid_idxs)
    return datasets.dataloaders(bs=bs, num_workers=workers, after_batch=[IntToFloatTensor])
# Batch size, dataloader image size, and augmentation crop size.
bs = 8
resize = 256
size = 224
Set the queue size; it needs to be a multiple of the batch size.
# Queue size: sixteen batches' worth of samples.
K = 16 * bs
ViT model
Getting the ViT model using the `VisRecTrans` class:
# Build the ViT through transcv's VisRecTrans helper.
# NOTE(review): the positional arguments (0, False) are presumably the number
# of classes and a "pretrained" flag — confirm against the transcv API.
vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
model = vis_rec_ob.create_model()
# Initialise the freshly created model through the helper.
vis_rec_ob.initialize(model)
class Encoder_transcv(Module):
    "Keep only the entry at index 0 along dim 1 of the incoming tensor."

    def __init__(self):
        # Stateless layer: nothing to set up.
        pass

    def forward(self, x):
        # NOTE(review): index 0 is presumably the ViT class token — confirm
        # against the transcv model layout.
        return x[:, 0]
# Encoder = first three stages of the ViT followed by the index-0 token selector.
encoder = nn.Sequential(*(model[idx] for idx in range(3)), Encoder_transcv())
def create_swav_model(encoder, hidden_size=256, projection_size=128, n_protos=3000, bn=True, nlayers=2):
    "Create SwAV model"
    # Probe the encoder with a dummy 224x224 batch to discover its output width.
    dummy_batch = torch.randn((2, in_channels(encoder), 224, 224))
    with torch.no_grad():
        feat_dim = encoder(dummy_batch).size(1)
    projector = create_mlp_module(feat_dim, hidden_size, projection_size, bn=bn, nlayers=nlayers)
    prototypes = nn.Linear(projection_size, n_protos, bias=False)
    apply_init(projector)
    # Replace the prototype weights by their normalised version, in place and
    # without recording the operation in autograd.
    with torch.no_grad():
        prototypes.weight.copy_(F.normalize(prototypes.weight.data.clone()))
    return SwAVModel(encoder, projector, prototypes)
Initialize the Dataloaders using the function above.
# Data is loaded at `resize` resolution; the SwAV augmentations below crop to `size`.
dls = get_dls(resize, bs)
# Wrap the ViT encoder with a projector and a prototype layer.
model = create_swav_model(encoder)
# Two crop groups (2 crops and 6 crops) with their own scale ranges; both
# groups use the same crop size (`int(size)` equals `size` here).
aug_pipelines = get_swav_aug_pipelines(num_crops=[2,6],
crop_sizes=[size,int(size)],
min_scales=[0.25,0.2],
max_scales=[1.0,0.35],
rotate=True, rotate_deg=10, jitter=True, bw=True, blur=False)
# K is the queue size computed earlier from the batch size.
cbs=[SWAV(aug_pipelines, crop_assgn_ids=[0,1], K=K, queue_start_pct=0.5, temp=0.1)]
learn = Learner(dls, model, cbs=cbs)
# Push one batch through the `before_batch` event by hand so the augmented
# views produced by the SWAV callback can be displayed.
b = dls.one_batch()
learn._split(b)
learn('before_batch')
learn.swav.show(n=5);
# Switch the learner to fp16 training.
learn.to_fp16();
learn.summary()
# NOTE(review): `lr` and `wd` are not referenced again in the visible part of
# this file; `epochs` is only used to build the checkpoint name.
lr,wd,epochs=1e-2,1e-2,100
learn.lr_find()
Creating the custom `embed_callback`:
class TrainEmbedCallback(Callback):
    """Make the SwAV encoder trainable while training and frozen while validating.

    Reads the encoder from `self.learn.model.encoder`.
    """

    def before_train(self):
        "Put the encoder in training mode and enable gradients for its parameters."
        encoder = self.learn.model.encoder
        # `.train()` switches the module and all of its sub-modules, whereas
        # assigning `.training` only flips the flag on the top-level module.
        encoder.train()
        encoder.requires_grad_(True)

    def before_validate(self):
        "Put the encoder in eval mode and disable gradients for its parameters."
        # fastai dispatches the event named `before_validate`; a method called
        # `before_validation` is never invoked by the training loop.
        encoder = self.learn.model.encoder
        encoder.eval()
        encoder.requires_grad_(False)

    # Keep the original method name available for any direct callers.
    before_validation = before_validate
Self-supervised pre-training
Here, we pre-train the ViT model for just 1 epoch, since even a single epoch takes a significant amount of time. It is recommended to pre-train the ViT model for 100 epochs (according to this tutorial).
# Unfreeze the whole model before fitting.
learn.unfreeze()
# A single one-cycle epoch with a maximum learning rate of 1e-3.
# The callbacks are passed as classes rather than instances.
learn.fit_one_cycle(1, 1e-3, pct_start=0.5, cbs = [TrainEmbedCallback, GradientAccumulation])
Saving the parameters of the encoder :
# NOTE(review): `epochs` is 100 at this point while the fit above ran for a
# single epoch, so the file name reports more training than was actually
# done — confirm whether this is intended.
save_name = f'swav_iwang_sz{size}_epc{epochs}'
# Checkpoint of the full learner.
learn.save(save_name)
# Encoder-only weights, written next to the learner checkpoint; this path is
# what the fine-tuning step loads.
encoder_path = learn.path/learn.model_dir/f'{save_name}_encoder.pth'
torch.save(learn.model.encoder.state_dict(), encoder_path)
learn.recorder.plot_loss()
# Notebook cell output: display the current batch size and image size.
bs, size
def get_dls(size, bs, workers=None):
    """Build the ImageWang `DataLoaders` used for supervised fine-tuning.

    size: side length handed to `RandomResizedCrop`; values up to 160 select
        the `IMAGEWANG_160` archive, larger values select `IMAGEWANG`.
    bs: batch size.
    workers: forwarded as `num_workers` to `Datasets.dataloaders`.
    """
    url = URLs.IMAGEWANG if size > 160 else URLs.IMAGEWANG_160
    # Only the labelled `train` and `val` folders are used.
    image_files = get_image_files(untar_data(url), folders=['train', 'val'])
    # Files under a `val` grandparent folder form the validation set.
    train_valid_idxs = GrandparentSplitter(valid_name='val')(image_files)
    x_pipeline = [PILImage.create, ToTensor,
                  RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    y_pipeline = [parent_label, Categorize()]
    datasets = Datasets(image_files, tfms=[x_pipeline, y_pipeline], splits=train_valid_idxs)
    # Batches are converted to float and normalised with the ImageNet statistics.
    normalise = Normalize.from_stats(*imagenet_stats)
    return datasets.dataloaders(bs=bs, num_workers=workers,
                                after_batch=[IntToFloatTensor, normalise])
def split_func(m):
    "Return two parameter groups: the parameters of `m[0]` and those of `m[1]`."
    param_groups = L(m[0], m[1])
    return param_groups.map(params)
def create_learner(encoder_path, size=size, arch='xresnet34'):
    """Build a fine-tuning `Learner` around the pre-trained ViT encoder.

    encoder_path: path of the encoder `state_dict` file to load.
    size: image size passed to `get_dls`.
    arch: unused; kept so that existing callers keep working.

    Returns a `Learner` whose model is `nn.Sequential(encoder, classifier)`.
    """
    dls = get_dls(size, bs=bs)
    # Load onto the CPU so weights saved from a GPU run can be read on any
    # machine; `load_state_dict` then copies them into the encoder's own
    # parameters.
    pretrained_encoder = torch.load(encoder_path, map_location='cpu')
    # Rebuild the same encoder layout that was used for pre-training.
    vis_rec_ob = VisRecTrans('vit_small_patch16_224', 0, False)
    model = vis_rec_ob.create_model()
    vis_rec_ob.initialize(model)
    encoder = nn.Sequential(model[0], model[1], model[2], Encoder_transcv())
    encoder.load_state_dict(pretrained_encoder)
    # Dummy forward pass only to read the feature width; no autograd graph
    # is needed for that.
    with torch.no_grad():
        nf = encoder(torch.randn(2, 3, 224, 224)).size(-1)
    classifier = create_cls_module(nf, dls.c)
    model = nn.Sequential(encoder, classifier)
    # `split_func` puts the encoder and the classifier in separate parameter groups.
    learn = Learner(dls, model, splitter=split_func,
                    metrics=top_k_accuracy, loss_func=LabelSmoothingCrossEntropy())
    return learn
# Build the fine-tuning learner from the encoder weights saved after pre-training.
learn = create_learner(encoder_path, size)
# Unfreeze the whole model, then run the learning-rate finder.
learn.unfreeze()
learn.lr_find()
class TrainEmbedFinetuneCallback(Callback):
    """Make one encoder stage trainable while training and frozen while validating.

    Operates on `self.learn.model[0][1]`, i.e. the second stage of the encoder
    that sits at index 0 of the fine-tuning model.
    """

    def before_train(self):
        "Put the stage in training mode and enable gradients for its parameters."
        stage = self.learn.model[0][1]
        # `.train()` switches the module and all of its sub-modules, whereas
        # assigning `.training` only flips the flag on the top-level module.
        stage.train()
        stage.requires_grad_(True)

    def before_validate(self):
        "Put the stage in eval mode and disable gradients for its parameters."
        # fastai dispatches the event named `before_validate`; a method called
        # `before_validation` is never invoked by the training loop.
        stage = self.learn.model[0][1]
        stage.eval()
        stage.requires_grad_(False)

    # Keep the original method name available for any direct callers.
    before_validation = before_validate
# Fine-tune for 5 one-cycle epochs at a maximum learning rate of 1e-3.
# The callbacks are passed as classes rather than instances.
learn.fit_one_cycle(5, 1e-3, cbs = [TrainEmbedFinetuneCallback, GradientClip])
We can see that the ViT model has a very low top_k_accuracy after fine-tuning. This is because the model was pre-trained for just 1 epoch.