Skip to content

Commit

Permalink
net_size is now part of a model checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
kainoj committed Nov 28, 2018
1 parent 9ab3c06 commit a2639ec
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ def save_checkpoint(self, epoch):
'epoch': epoch,
'model_state_dict': self.net.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'losses': self.loss_history
'losses': self.loss_history,
'net_size': self.net_size
}, full_path)

self.current_model_name = full_path
Expand All @@ -233,6 +234,7 @@ def load_checkpoint(self, model_checkpoint):
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.loss_history = checkpoint['losses']
self.start_epoch = checkpoint['epoch'] + 1
self.net_size = checkpoint['net_size']
self.current_model_name = model_checkpoint


Expand Down

0 comments on commit a2639ec

Please sign in to comment.