From c66329a0397fbc775cd7caf80a4125ca45511711 Mon Sep 17 00:00:00 2001
From: Syed Adeel
Date: Tue, 4 May 2021 20:12:49 +0000
Subject: [PATCH] Add num_labels to eval_linear, change max_accuracy to
 best_acc

---
 eval_linear.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/eval_linear.py b/eval_linear.py
index b17cf7309..649b95788 100644
--- a/eval_linear.py
+++ b/eval_linear.py
@@ -60,7 +60,7 @@ def eval_linear(args):
     # load weights to evaluate
     utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
 
-    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)))
+    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
     linear_classifier = linear_classifier.cuda()
     linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
 
@@ -112,7 +112,7 @@ def eval_linear(args):
         }
         torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
     print("Training of the supervised linear classifier on frozen features completed.\n"
-          "Top-1 test accuracy: {acc:.1f}".format(acc=max_accuracy))
+          "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
 
 
 def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
@@ -165,14 +165,22 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):
         output = linear_classifier(output)
         loss = nn.CrossEntropyLoss()(output, target)
 
-        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        if linear_classifier.module.num_labels >= 5:
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        else:
+            acc1, = utils.accuracy(output, target, topk=(1,))
 
         batch_size = inp.shape[0]
         metric_logger.update(loss=loss.item())
         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
-        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+        if linear_classifier.module.num_labels >= 5:
+            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    if linear_classifier.module.num_labels >= 5:
+        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+    else:
+        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
+              .format(top1=metric_logger.acc1, losses=metric_logger.loss))
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
@@ -180,6 +188,7 @@ class LinearClassifier(nn.Module):
     """Linear layer to train on top of frozen features"""
     def __init__(self, dim, num_labels=1000):
         super(LinearClassifier, self).__init__()
+        self.num_labels = num_labels
         self.linear = nn.Linear(dim, num_labels)
         self.linear.weight.data.normal_(mean=0.0, std=0.01)
         self.linear.bias.data.zero_()
@@ -217,5 +226,6 @@ def forward(self, x):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
     parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
+    parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
     args = parser.parse_args()
     eval_linear(args)
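
Reviewer note: the num_labels >= 5 guard exists because torch.topk raises an
error when k exceeds the number of classes, so top-5 accuracy cannot be
computed for a classifier with fewer than five labels. Below is a minimal,
self-contained sketch of the guarded metric logic; the accuracy helper is a
simplified stand-in for utils.accuracy, not the repository's exact
implementation.

    import torch

    def accuracy(output, target, topk=(1,)):
        """Top-k accuracy (%) over a batch -- simplified stand-in for utils.accuracy."""
        maxk = max(topk)
        batch_size = target.size(0)
        # torch.topk raises if maxk exceeds the number of classes, hence the caller-side guard
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()                                    # shape: (maxk, batch)
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum() * 100.0 / batch_size for k in topk]

    num_labels = 2                                         # e.g. a binary downstream task
    output = torch.randn(8, num_labels)                    # classifier logits
    target = torch.randint(0, num_labels, (8,))

    # Mirrors the patched validate_network logic:
    if num_labels >= 5:
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
    else:
        acc1, = accuracy(output, target, topk=(1,))        # top-5 is undefined here

With the new flag, a small-label-space dataset can be evaluated with e.g.
python eval_linear.py --num_labels 2 ..., matching the linear head to the
dataset and skipping the top-5 metric instead of failing inside
utils.accuracy.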