Skip to content

Commit

Permalink
[MaskRCNN/PyT] Update AMP API for inference (NVIDIA#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
sharathts authored Jan 14, 2021
1 parent 2badf6e commit 0cbadd7
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions PyTorch/Segmentation/MaskRCNN/pytorch/tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,14 @@ def main():
model = build_detection_model(cfg)
model.to(cfg.MODEL.DEVICE)

# Initialize mixed-precision if necessary
# Initialize mixed-precision
if args.fp16:
use_mixed_precision = True
else:
use_mixed_precision = cfg.DTYPE == "float16"
amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

amp_opt_level = 'O1' if use_mixed_precision else 'O0'
model = amp.initialize(model, opt_level=amp_opt_level)

output_dir = cfg.OUTPUT_DIR
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
_ = checkpointer.load(cfg.MODEL.WEIGHT)
Expand Down

0 comments on commit 0cbadd7

Please sign in to comment.