Skip to content

Commit

Permalink
remove_identity_and_update_affine
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenChou0119 committed Nov 4, 2024
1 parent 3608f12 commit 88bef58
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions mqbench/convert_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ def convert_onnx(model: GraphModule, input_shape_dict, dummy_input, onnx_model_p
do_constant_folding=False,
custom_opsets={'mqbench_custom' : opset_version},
)

import onnx
import copy
def remove_identity_and_update_affine(model):
graph = model.graph
nodes_to_remove = []
updated_nodes = []
initializer_names = [n.name for n in graph.initializer]
updated_initializer = []
for node in graph.node:
if node.op_type == 'LearnablePerTensorAffine' or node.op_type == 'FakeQuantizeLearnablePerchannelAffine':
for quant_name in node.input:
if quant_name.endswith("scale") or quant_name.endswith("zero_point") or quant_name.endswith("weight"):
for prev_node in updated_nodes:
if prev_node.op_type == 'Identity' and prev_node.output[0] == quant_name:
quant_name_identidy_input = prev_node.input[0]
for prev_prev_node in graph.initializer:
if prev_prev_node.name == quant_name_identidy_input:
if quant_name not in initializer_names:
new_initializer = copy.deepcopy(prev_prev_node)
new_initializer.name = quant_name
updated_initializer.append(new_initializer)
node.input[2] = quant_name
prev_node.output[0] = prev_node.input[0]
nodes_to_remove.append(prev_node)
break
updated_nodes.append(node)

graph.node.clear()
graph.node.extend(updated_nodes)
graph.initializer.extend(updated_initializer)

for node in nodes_to_remove:
graph.node.remove(node)

return model
tmp_model = onnx.load(onnx_model_path)
updated_model = remove_identity_and_update_affine(tmp_model)
onnx.save_model(updated_model, onnx_model_path)
import os
os.system("polygraphy surgeon sanitize --fold-constants {} -o {}".format(onnx_model_path, onnx_model_path))

Expand Down

0 comments on commit 88bef58

Please sign in to comment.