Skip to content

Commit

Permalink
Merge pull request facebookresearch#144 from facebookresearch/parallel
Browse files Browse the repository at this point in the history
fix for saving models trained with DataParallel
  • Loading branch information
ajfisch authored May 17, 2018
2 parents 3b19410 + c3a19f8 commit 50d0e49
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions drqa/reader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,11 @@ def decode_candidates(score_s, score_e, candidates, top_n=1, max_len=None):
# --------------------------------------------------------------------------

def save(self, filename):
state_dict = copy.copy(self.network.state_dict())
if self.parallel:
network = self.network.module
else:
network = self.network
state_dict = copy.copy(network.state_dict())
if 'fixed_embedding' in state_dict:
state_dict.pop('fixed_embedding')
params = {
Expand All @@ -409,8 +413,12 @@ def save(self, filename):
logger.warning('WARN: Saving failed... continuing anyway.')

def checkpoint(self, filename, epoch):
if self.parallel:
network = self.network.module
else:
network = self.network
params = {
'state_dict': self.network.state_dict(),
'state_dict': network.state_dict(),
'word_dict': self.word_dict,
'feature_dict': self.feature_dict,
'args': self.args,
Expand Down

0 comments on commit 50d0e49

Please sign in to comment.