Skip to content

Commit

Permalink
model weight pushing
Browse files Browse the repository at this point in the history
  • Loading branch information
ihounie committed Mar 16, 2023
1 parent ef32c25 commit 1bee214
Showing 1 changed file with 7 additions and 40 deletions.
47 changes: 7 additions & 40 deletions TrivialAugment/train_pd_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,54 +367,21 @@ def train_and_eval(rank, worldsize, tag, dataroot, test_ratio=0.0, cv_fold=0, re
rs['testtrain'] = run_epoch(rank, worldsize, model, testtrainloader_, criterion, None, desc_default='testtrain', epoch=epoch, writer=writers[3], verbose=True)
rs['test'] = run_epoch(rank, worldsize, model, testloader_, criterion, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=True)
rs['valid'] = run_epoch(rank, worldsize, model, validloader, criterion, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=True)

if metric == 'last' or rs[metric]['top1'] > best_top1:
if metric != 'last':
best_top1 = rs[metric]['top1']
for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'test', 'testtrain']):
if setname in rs and key in rs[setname]:
result['%s_%s' % (key, setname)] = rs[setname][key]
result['epoch'] = epoch

writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch)
writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch)

reporter(
loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'],
loss_test=rs['test']['loss'], top1_test=rs['test']['top1']
)

# save checkpoint
if save_path and C.get().get('save_model', True) and (worldsize <= 1 or torch.distributed.get_rank() == 0):
logger.info('save model@%d to %s' % (epoch, save_path))
torch.save({
'epoch': epoch,
'log': {
'train': rs['train'].get_dict(),
'test': rs['test'].get_dict(),
},
'optimizer': optimizer.state_dict(),
'model': model.state_dict()
}, save_path)
torch.save({
'epoch': epoch,
'log': {
'train': rs['train'].get_dict(),
'test': rs['test'].get_dict(),
},
'optimizer': optimizer.state_dict(),
'model': model.state_dict()
}, save_path.replace('.pth', '_e%d_top1_%.3f_%.3f' % (epoch, rs['train']['top1'], rs['test']['top1']) + '.pth'))

early_finish_epoch = C.get().get('early_finish_epoch', None)
if early_finish_epoch == epoch:
break
if wandb_log:
if epoch == max_epoch:
if epoch == max_epoch:
final_dict = {f"final {k}": v for k,v in result.items()}
wandb.log(final_dict)
if save_path and C.get().get('save_model', True):
wandb.save()
#wandb.save()
print('save model@%d to %s' % (epoch, save_path))
torch.save(model.state_dict(), save_path)
artifact = wandb.Artifact('model', type='model')
artifact.add_file(save_path)
wandb.log_artifact(artifact)
else:
wandb.log({"train": rs["train"].get_dict(), "epoch":epoch, "dualvar": dual_vars})
if epoch % log_interval == 0:
Expand Down

0 comments on commit 1bee214

Please sign in to comment.