diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 2aeb6ca490c9c3..cb6dc08244101f 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -351,7 +351,7 @@ def _model_to_graph(model, args, verbose=False, training=False, if do_constant_folding and _export_onnx_opset_version == 9: params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict) - torch._C._jit_pass_dce(graph) + torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) if verbose: print(graph)