Add num_labels to eval_linear, change max_accuracy to best_acc
user1234554321 authored and Mathilde Caron committed May 13, 2021
1 parent 0466678 commit c66329a
Showing 1 changed file with 15 additions and 5 deletions.

eval_linear.py
@@ -60,7 +60,7 @@ def eval_linear(args):
     # load weights to evaluate
     utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
 
-    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)))
+    linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)), num_labels=args.num_labels)
     linear_classifier = linear_classifier.cuda()
     linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
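The first argument here is the width of the frozen feature vector fed to the head: one embed_dim-wide CLS token from each of the last n_last_blocks blocks, plus one extra embed_dim slot when the averaged patch tokens are concatenated. A minimal sketch of that arithmetic, assuming DINO's default ViT-S/16 linear-eval settings (the numbers below are illustrative, not part of this commit):

# Sketch only: values assume the default ViT-S/16 linear-eval setup.
embed_dim = 384               # width of one ViT-S token
n_last_blocks = 4             # CLS tokens kept from the last 4 blocks
avgpool_patchtokens = False   # no extra averaged-patch-token slot

dim = embed_dim * (n_last_blocks + int(avgpool_patchtokens))
print(dim)  # 1536 -> in_features of the linear head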

@@ -112,7 +112,7 @@ def eval_linear(args):
         }
         torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
     print("Training of the supervised linear classifier on frozen features completed.\n"
-          "Top-1 test accuracy: {acc:.1f}".format(acc=max_accuracy))
+          "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc))
 
 
 def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
@@ -165,21 +165,30 @@ def validate_network(val_loader, model, linear_classifier, n, avgpool):
         output = linear_classifier(output)
         loss = nn.CrossEntropyLoss()(output, target)
 
-        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        if linear_classifier.module.num_labels >= 5:
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+        else:
+            acc1, = utils.accuracy(output, target, topk=(1,))
 
         batch_size = inp.shape[0]
         metric_logger.update(loss=loss.item())
         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
-        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
-          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+        if linear_classifier.module.num_labels >= 5:
+            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+    if linear_classifier.module.num_labels >= 5:
+        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+              .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+    else:
+        print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}'
+              .format(top1=metric_logger.acc1, losses=metric_logger.loss))
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
 class LinearClassifier(nn.Module):
     """Linear layer to train on top of frozen features"""
     def __init__(self, dim, num_labels=1000):
         super(LinearClassifier, self).__init__()
+        self.num_labels = num_labels
         self.linear = nn.Linear(dim, num_labels)
         self.linear.weight.data.normal_(mean=0.0, std=0.01)
         self.linear.bias.data.zero_()
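The new num_labels >= 5 guards exist because top-5 accuracy is only meaningful with at least 5 classes: utils.accuracy takes the top-k of the logits, and torch's topk raises an error once k exceeds the class dimension. A standalone illustration of the failure the guard avoids (not code from the repo):

import torch

logits = torch.randn(8, 2)   # batch of 8 predictions over only 2 classes

print(logits.topk(1, dim=1).indices.shape)  # top-1 works: torch.Size([8, 1])
try:
    logits.topk(5, dim=1)    # top-5 is undefined with 2 classes
except RuntimeError as err:
    print("topk failed:", err)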
@@ -217,5 +226,6 @@ def forward(self, x):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
     parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
+    parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier')
     args = parser.parse_args()
     eval_linear(args)
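For a sense of how the new flag flows end to end, here is a hypothetical round trip: parse --num_labels the way the script's argparse block does, then size a head from it. The 1536 input width and the 10-class value are assumptions for illustration only:

import argparse
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--num_labels', default=1000, type=int,
                    help='Number of labels for linear classifier')
args = parser.parse_args(['--num_labels', '10'])  # e.g. a 10-class dataset

head = nn.Linear(1536, args.num_labels)  # 1536 assumed, see the sketch above
print(head)  # Linear(in_features=1536, out_features=10, bias=True)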
