Skip to content

Commit

Permalink
修复global_state没有更新的bug
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jul 24, 2020
1 parent ed9737b commit 0e11758
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tools/det_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ def train(net, optimizer, loss_func, train_loader, eval_loader, to_use_device,
start = time.time()
global_step += 1
logger.info(f'train_loss: {train_loss / len(train_loader)}')
global_state['start_epoch'] = epoch
global_state['best_model'] = best_model
global_state['global_step'] = global_step
if (epoch + 1) % train_options['val_interval'] == 0:
global_state['start_epoch'] = epoch
global_state['best_model'] = best_model
global_state['global_step'] = global_step
net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
if train_options['ckpt_save_type'] == 'HighestAcc':
Expand Down
3 changes: 3 additions & 0 deletions tools/rec_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_us
f"time:{interval_batch_time:.4f}")
start = time.time()
if (i + 1) >= train_options['val_interval'] and (i + 1) % train_options['val_interval'] == 0:
global_state['start_epoch'] = epoch
global_state['best_model'] = best_model
global_state['global_step'] = global_step
net_save_path = f"{train_options['checkpoint_save_dir']}/latest.pth"
save_checkpoint(net_save_path, net, optimizer, logger, cfg, global_state=global_state)
if train_options['ckpt_save_type'] == 'HighestAcc':
Expand Down

0 comments on commit 0e11758

Please sign in to comment.