In this tutorial, we will create a Vision Transformer (ViT) and use it for visual recognition.
from fastai.vision.all import *
from transcv.visrectrans import VisRecTrans
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224), bs=8)
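The `label_func` deserves a word: in the Oxford-IIIT Pet dataset, cat breed filenames are capitalised (e.g. `Abyssinian_1.jpg`) while dog breeds are lowercase (e.g. `beagle_1.jpg`), so checking the first character yields a binary cat/dog label. A minimal stand-alone sketch of the same logic:

```python
# Same logic as the lambda passed to label_func: the first character of the
# filename decides the class (True = cat, False = dog in Oxford-IIIT Pet).
def is_cat(fname: str) -> bool:
    return fname[0].isupper()

print(is_cat("Abyssinian_1.jpg"))  # → True  (cat)
print(is_cat("beagle_1.jpg"))      # → False (dog)
```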
dls.show_batch(max_n=9, figsize=(6,7))
vis_rec_ob = VisRecTrans('vit_large_patch16_224', dls.c, True)
model = vis_rec_ob.create_model()
vis_rec_ob.initialize(model)
embed_callback = vis_rec_ob.get_callback()
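`get_callback` returns a fastai `Callback` that will run during training. The general mechanism fastai callbacks rely on can be sketched in plain Python (this illustrates the pattern only, not transcv's actual implementation):

```python
# A toy training loop that fires named events; each callback implements
# handlers only for the events it cares about -- the pattern fastai follows.
class Callback:
    def __call__(self, event_name):
        handler = getattr(self, event_name, None)
        if handler: handler()

class CountingCallback(Callback):
    def __init__(self): self.batches = 0
    def after_batch(self): self.batches += 1

def run_loop(callbacks, n_batches):
    for cb in callbacks: cb("before_fit")
    for _ in range(n_batches):
        for cb in callbacks: cb("after_batch")
    for cb in callbacks: cb("after_fit")

cb = CountingCallback()
run_loop([cb], n_batches=3)
print(cb.batches)  # → 3
```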
We need a custom split function in order to keep the pre-trained body frozen while leaving the newly initialised head unfrozen.
def ViT_split(m): return L(m[0], m[1], m[2], m[3]).map(params)
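What a splitter does can be sketched without fastai: it partitions the model's parameters into groups so that `freeze()` can disable gradients for every group except the last. (The four-child structure below mirrors what `ViT_split` assumes; the parameter names are made up for illustration.)

```python
class Param:
    """Minimal stand-in for a tensor with a requires_grad flag."""
    def __init__(self, name):
        self.name, self.requires_grad = name, True

# Four top-level children, as ViT_split assumes: patch embedding,
# transformer blocks, final norm, and the new classification head.
model = [[Param("patch.w")], [Param("blocks.w")], [Param("norm.w")], [Param("head.w")]]

def split(m):
    # One parameter group per child, like L(m[0], ..., m[3]).map(params).
    return [list(child) for child in m]

groups = split(model)
for group in groups[:-1]:       # freeze(): all groups except the last
    for p in group:
        p.requires_grad = False

print([p.requires_grad for g in groups for p in g])  # → [False, False, False, True]
```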
learn = Learner(dls, model, splitter=ViT_split, loss_func=CrossEntropyLossFlat(), metrics=accuracy).to_fp16()
learn.freeze()
learn.summary()
Sequential (Input shape: 8)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     8 x 1024 x 14 x 14
Conv2d                                    787456     False
Identity
EmbedBlock
Dropout
LayerNorm                                 2048       True
____________________________________________________________________________
(… 24 transformer encoder blocks, each with frozen Linear attention and
GELU MLP layers, and trainable LayerNorms of 2048 params …)
____________________________________________________________________________
                     8 x 2
Linear                                    2050       True
____________________________________________________________________________
Total params: 303,100,930
Total trainable params: 102,402
Total non-trainable params: 302,998,528
Optimizer used: <function Adam at 0x7f36cbd18b00>
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #3
Callbacks:
  - TrainEvalCallback
  - MixedPrecision
  - Recorder
  - ProgressCallback
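The trainable-parameter count in the summary can be checked by hand: with the body frozen, only the LayerNorms (each over the 1024-dim embedding, so weight + bias = 2048 params) and the new 2-class head remain trainable. Assuming two LayerNorms per encoder block across 24 blocks, plus the final norm before the head:

```python
# Sanity-check the "Total trainable params: 102,402" line from learn.summary().
embed_dim = 1024
layernorm_params = 2 * embed_dim          # weight + bias = 2048
n_layernorms = 24 * 2 + 1                 # two per encoder block + final norm
head_params = embed_dim * 2 + 2           # Linear(1024, 2): 2048 weights + 2 biases

total_trainable = n_layernorms * layernorm_params + head_params
print(total_trainable)  # → 102402
```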
learn.lr_find()
SuggestedLRs(valley=0.0003981071640737355)
NB: it is recommended to always use the embed_callback when training a ViT model.
learn.fit_one_cycle(1, 2e-4, cbs=[embed_callback, GradientClip])
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.595548 | 0.593401 | 0.694858 | 34:27 |
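`fit_one_cycle` drives the learning rate with the 1cycle policy: a cosine warm-up from `lr_max/div` to `lr_max` over the first portion of training, then a cosine anneal down to `lr_max/div_final`. A sketch with fastai's default hyper-parameters (`pct_start=0.25`, `div=25`, `div_final=1e5`) and the 2e-4 peak used above:

```python
import math

def one_cycle_lr(pct, lr_max, pct_start=0.25, div=25.0, div_final=1e5):
    """Learning rate at training progress pct in [0, 1] under the 1cycle policy."""
    def cos_anneal(start, end, p):
        # Cosine interpolation from start (p=0) to end (p=1).
        return start + (1 - math.cos(math.pi * p)) * (end - start) / 2
    if pct < pct_start:
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

print(one_cycle_lr(0.0, 2e-4))   # → 8e-06 (lr_max / 25 at the start)
print(one_cycle_lr(0.25, 2e-4))  # → 0.0002 (the peak)
```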
learn.save('stage-1')
Path('/root/.fastai/data/oxford-iiit-pet/images/models/stage-1.pth')
learn.unfreeze()
learn.lr_find()
SuggestedLRs(valley=3.630780702224001e-05)
learn.fit_one_cycle(1, 1e-5, cbs=[embed_callback, GradientClip])
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.561189 | 0.559480 | 0.700271 | 46:17 |
learn.save('stage-2')
Path('/root/.fastai/data/oxford-iiit-pet/images/models/stage-2.pth')