Skip to content

Commit

Permalink
Merge pull request NVIDIA#693 from hXl3s/RN50/ngc-checkpoint-update
Browse files: browse the repository at this point in the history
[ConvNets/PyT] Fixed distributed checkpoint loading
  • Loading branch information
nv-kkudrynski authored Sep 18, 2020
2 parents a74236a + 72f40b8 commit 94518be
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 5 additions & 1 deletion PyTorch/Classification/ConvNets/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ def main(args):

if args.weights is not None:
weights = torch.load(args.weights)

#Temporary fix to allow NGC checkpoint loading
weights = {k.replace("module.", ""): v for k, v in weights.items()}

model.load_state_dict(weights)

model = model.cuda()

if args.precision in ["AMP", "FP16"]:
model = network_to_half(model)
model = model.half()


model.eval()
Expand Down
4 changes: 4 additions & 0 deletions PyTorch/Classification/ConvNets/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ def _worker_init_fn(id):
)
)
pretrained_weights = torch.load(args.pretrained_weights)

#Temporary fix to allow NGC checkpoint loading

pretrained_weights = {k.replace("module.", ""): v for k, v in pretrained_weights.items()}
else:
print("=> no pretrained weights found at '{}'".format(args.resume))

Expand Down

0 comments on commit 94518be

Please sign in to comment.