diff --git a/src/utils/convert_model_to_torchscript.py b/src/utils/convert_model_to_torchscript.py index c24cfa9..aa1e0a7 100644 --- a/src/utils/convert_model_to_torchscript.py +++ b/src/utils/convert_model_to_torchscript.py @@ -49,6 +49,7 @@ def load_and_convert(args): io_groups.add_argument("--model_param_path", required=True, type=str, default=None) io_groups.add_argument("--out_dir", type=str, default="./") + parser.add_argument("--cpu", type=bool, default=True) args = parser.parse_args() # overwrite network model parameter from json file if specified