diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index 7890516154..edda69fb8f 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -65,7 +65,10 @@ def average_checkpoints(inputs): averaged_params = collections.OrderedDict() for k, v in params_dict.items(): averaged_params[k] = v - averaged_params[k].div_(num_models) + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models new_state['model'] = averaged_params return new_state