Skip to content

Commit

Permalink
fix load model.
Browse files Browse the repository at this point in the history
  • Loading branch information
donnyyou committed Nov 19, 2019
1 parent 49f090d commit 1f03a01
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions model/tools/module_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,15 @@ def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
model.load_state_dict(load_dict)

else:
pretrained_dict = torch.load(pretrained)
pretrained_dict = torch.load(pretrained, map_location=map_location)
model_dict = model.state_dict()
load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
load_dict = dict()
for k, v in pretrained_dict.items():
if 'prefix.{}'.format(k) in model_dict:
load_dict['prefix.{}'.format(k)] = v
elif k in model_dict:
load_dict[k] = v

Log.info('Matched Keys: {}'.format(load_dict.keys()))
model_dict.update(load_dict)
model.load_state_dict(model_dict)
Expand Down

0 comments on commit 1f03a01

Please sign in to comment.