Refactor print progress of ImageNet training (pytorch#529)
* refactored with lr_scheduler

* Revert "refactored with lr_scheduler"

This reverts commit 847cd87.

* refactored printing of progress
Philip Meier authored and soumith committed Mar 27, 2019
1 parent 3b349ad commit 27a6244
65 changes: 40 additions & 25 deletions imagenet/main.py
@@ -257,11 +257,13 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, args):
-    batch_time = AverageMeter()
-    data_time = AverageMeter()
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
+                             top5, prefix="Epoch: [{}]".format(epoch))
 
     # switch to train mode
     model.train()
@@ -295,21 +297,16 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         end = time.time()
 
         if i % args.print_freq == 0:
-            print('Epoch: [{0}][{1}/{2}]\t'
-                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
-                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   epoch, i, len(train_loader), batch_time=batch_time,
-                   data_time=data_time, loss=losses, top1=top1, top5=top5))
+            progress.print(i)
 
 
 def validate(val_loader, model, criterion, args):
-    batch_time = AverageMeter()
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
+                             prefix='Test: ')
 
     # switch to evaluate mode
     model.eval()
@@ -336,14 +333,9 @@ def validate(val_loader, model, criterion, args):
             end = time.time()
 
             if i % args.print_freq == 0:
-                print('Test: [{0}/{1}]\t'
-                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
-                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
-                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
-                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                       i, len(val_loader), batch_time=batch_time, loss=losses,
-                       top1=top1, top5=top5))
+                progress.print(i)
 
+        # TODO: this should also be done with the ProgressMeter
         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
               .format(top1=top1, top5=top5))
 
@@ -358,7 +350,9 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self):
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
         self.reset()
 
     def reset(self):
@@ -373,6 +367,27 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, *meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def print(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
 
 def adjust_learning_rate(optimizer, epoch, args):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
