linear eval rn50
Mathilde Caron committed Jul 15, 2021
1 parent 5783b4b commit 9085367
Showing 2 changed files with 38 additions and 15 deletions.
eval_linear.py: 51 changes (36 additions, 15 deletions)
@@ -22,6 +22,7 @@
 import torch.backends.cudnn as cudnn
 from torchvision import datasets
 from torchvision import transforms as pth_transforms
+from torchvision import models as torchvision_models

 import utils
 import vision_transformer as vits
@@ -65,14 +66,29 @@ def eval_linear(args):
     print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")

     # ============ building network ... ============
-    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
+    # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
+    if args.arch in vits.__dict__.keys():
+        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
+        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
+    # if the network is an XCiT
+    elif "xcit" in args.arch:
+        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
+        embed_dim = model.embed_dim
+    # otherwise, we check if the architecture is in torchvision models
+    elif args.arch in torchvision_models.__dict__.keys():
+        model = torchvision_models.__dict__[args.arch]()
+        embed_dim = model.fc.weight.shape[1]
+        model.fc = nn.Identity()
+    else:
+        print(f"Unknown architecture: {args.arch}")
+        sys.exit(1)
     model.cuda()
     model.eval()
-    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
     # load weights to evaluate
     utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
+    print(f"Model {args.arch} built.")

-    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
+    linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels)
     linear_classifier = linear_classifier.cuda()
     linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
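In the torchvision branch above, the feature dimension is read off the classifier head before the head is replaced. A minimal sketch of how this resolves for resnet50 (standard torchvision shapes, not part of this diff):

import torch.nn as nn
from torchvision import models as torchvision_models

model = torchvision_models.resnet50()  # randomly initialized here; DINO weights are loaded afterwards
embed_dim = model.fc.weight.shape[1]   # fc is nn.Linear(2048, 1000), so embed_dim == 2048
model.fc = nn.Identity()               # forward() now returns the 2048-d pooled features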

@@ -139,11 +155,14 @@ def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):

         # forward
         with torch.no_grad():
-            intermediate_output = model.get_intermediate_layers(inp, n)
-            output = [x[:, 0] for x in intermediate_output]
-            if avgpool:
-                output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
-            output = torch.cat(output, dim=-1)
+            if "vit" in args.arch:
+                intermediate_output = model.get_intermediate_layers(inp, n)
+                output = [x[:, 0] for x in intermediate_output]
+                if avgpool:
+                    output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
+                output = torch.cat(output, dim=-1)
+            else:
+                output = model(inp)
         output = linear_classifier(output)

         # compute cross entropy loss
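For ViT architectures the head still consumes the concatenated [CLS] tokens of the last n blocks, optionally plus the average-pooled patch tokens of the final block. A shape-level sketch with stand-in tensors (sizes are hypothetical, chosen to mimic ViT-S/16 at 224px inputs):

import torch

# stand-ins for model.get_intermediate_layers(inp, n):
# n tensors of shape [batch, 1 + num_patches, embed_dim]
n, batch, num_patches, dim = 4, 2, 196, 384
intermediate_output = [torch.randn(batch, 1 + num_patches, dim) for _ in range(n)]

avgpool = True
output = [x[:, 0] for x in intermediate_output]                # [CLS] token of each block
if avgpool:
    output.append(intermediate_output[-1][:, 1:].mean(dim=1))  # pooled patch tokens, last block
output = torch.cat(output, dim=-1)
assert output.shape == (batch, dim * (n + int(avgpool)))       # matches embed_dim above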
@@ -178,11 +197,14 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):

         # forward
         with torch.no_grad():
-            intermediate_output = model.get_intermediate_layers(inp, n)
-            output = [x[:, 0] for x in intermediate_output]
-            if avgpool:
-                output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
-            output = torch.cat(output, dim=-1)
+            if "vit" in args.arch:
+                intermediate_output = model.get_intermediate_layers(inp, n)
+                output = [x[:, 0] for x in intermediate_output]
+                if avgpool:
+                    output.append(torch.mean(intermediate_output[-1][:, 1:], dim=1))
+                output = torch.cat(output, dim=-1)
+            else:
+                output = model(inp)
         output = linear_classifier(output)
         loss = nn.CrossEntropyLoss()(output, target)
@@ -229,8 +251,7 @@ def forward(self, x):
     parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
         help="""Whether or not to concatenate the global average pooled features to the [CLS] token.
         We typically set this to False for ViT-Small and to True with ViT-Base.""")
-    parser.add_argument('--arch', default='vit_small', type=str,
-        choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
+    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
     parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
     parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
     parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
utils.py: 2 changes (2 additions, 0 deletions)
@@ -99,6 +99,8 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
             url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
         elif model_name == "xcit_medium_24_p8":
             url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
+        elif model_name == "resnet50":
+            url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
         if url is not None:
             print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
             state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
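Since the final URL is the base concatenated with the per-model path, fetching the reference ResNet-50 weights reduces to roughly:

import torch

state_dict = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
    map_location="cpu",
)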
