Skip to content

Commit

Permalink
Update checkpoint.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware authored Apr 9, 2021
1 parent aa3ab7f commit dbec79e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions kospeech/checkpoint/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def save(self):
'optimizer': self.optimizer,
'trainset_list': self.trainset_list,
'validset': self.validset,
'epoch': self.epoch
'epoch': self.epoch,
}
torch.save(trainer_states, os.path.join(os.getcwd(), self.TRAINER_STATE_NAME))
torch.save(self.model, os.path.join(os.getcwd(), self.MODEL_NAME))
Expand Down Expand Up @@ -111,9 +111,11 @@ def load(self, path):
model.flatten_parameters()

return Checkpoint(
model=model, optimizer=resume_checkpoint['optimizer'], epoch=resume_checkpoint['epoch'],
model=model,
optimizer=resume_checkpoint['optimizer'],
epoch=resume_checkpoint['epoch'],
trainset_list=resume_checkpoint['trainset_list'],
validset=resume_checkpoint['validset']
validset=resume_checkpoint['validset'],
)

def get_latest_checkpoint(self):
Expand Down

0 comments on commit dbec79e

Please sign in to comment.