Skip to content

Commit

Permalink
record status
Browse files Browse the repository at this point in the history
  • Loading branch information
tangminji committed Aug 2, 2023
1 parent 2541ab8 commit 22f9218
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
18 changes: 15 additions & 3 deletions main_ce1.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def init_args():
parser.add_argument('--params_path', type=str, default='') #params.json
parser.add_argument('--out_tmp', type=str, default='') #result.json
parser.add_argument('--nrun', action='store_true')
parser.add_argument('--record', action='store_true')
args = parser.parse_args()
return args

Expand All @@ -89,7 +90,8 @@ def update_args(params={}):
args.result_dir,
args.dataset,
noise_level,
'{}_{}{}/'.format(args.model_type, args.noise_type,
'{}_{}{}{}/'.format(args.model_type, args.noise_type,
f'_lr{args.lr}_bs{args.batch_size}_wd{args.weight_decay}',
'' if args.model_type == 'ce' else '_{}_J={}_{}_lam={}_wm={}_del={}_eps={}_eta={}_inc={}{}'.format(args.filter,
args.J,
args.f_type,
Expand Down Expand Up @@ -162,6 +164,11 @@ def main(args, params={}):
net_record = torch.zeros([rollwin, len(train_loader.dataset), args.c+1])
delta_smooth = torch.full((len(train_loader.dataset),), args.delta)

# TODO
# 记录net_record
if args.record:
records = torch.zeros([args.n_epoch, len(train_loader.dataset), args.c+1])

for epoch in range(0, args.n_epoch):
if args.model_type == 'ce':
train_loss, train_acc = train_ce(args, net, train_loader, optimizer, epoch, scheduler, criterion_train)
Expand All @@ -177,7 +184,9 @@ def main(args, params={}):
train_loader.dataset.update_corrupted_label(y_cor.cpu().numpy())
else:
assert False, "Check model type, which should be in [ce, ours, ours_cl]~"


if args.record:
records[epoch] = net_record[epoch % args.rollWindow]
# validation
val_best, val_loss, val_acc = evaluate(args, net, val_loader, epoch, criterion_val, val_best)
# evaluate models
Expand All @@ -187,6 +196,8 @@ def main(args, params={}):
val_acc_list.append(val_acc)
test_acc_list.append(test_acc)

if args.record:
torch.save(records, os.path.join(args.log_dir, "record.pt"))
# save model at the last epoch
checkpoint(val_acc, epoch, net, args.log_dir, last=True)
run_time = time.time() - global_t0
Expand Down Expand Up @@ -244,5 +255,6 @@ def main(args, params={}):
res = main(args, params=params)
# TODO
if args.out_tmp:
res['ITERATION'] = params['ITERATION']
if 'ITERATION' in params:
res['ITERATION'] = params['ITERATION']
json.dump(res, open(args.out_tmp, "w+", encoding="utf-8"), ensure_ascii=False)
5 changes: 5 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def train_ours(args, model, loader, optimizer, epoch, scheduler, criterion, net_
correct.update(acc1[0].item(), data.size(0))

scheduler.step()
log_value('train/lr', optimizer.param_groups[0]['lr'], step=epoch)
log_value('train_detail/delta_smooth/avg', delta_smooth.mean(), step=epoch)
log_value('train_detail/delta_smooth/std', delta_smooth.std(), step=epoch)
log_value('train_detail/net_record/avg', net_record.mean(), step=epoch)
log_value('train_detail/net_record/std', net_record.std(), step=epoch)
# Print and log stats for the epoch
log_value('train/loss', train_loss.avg, step=epoch)
log(args.logpath, 'Time for Train-Epoch-{}/{}:{:.1f}s Acc:{}, Loss:{}\n'.format(epoch, args.n_epoch, time.time() - t0, correct.avg, train_loss.avg))
Expand Down

0 comments on commit 22f9218

Please sign in to comment.