Skip to content

Commit

Permalink
[quant][fx] update observer_fqn to not depend on node.name (pytorch#66767)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#66767

Make the observer FQN in the prepare step independent of the input node / observed node name.
This change names observers as `activation_post_process_{idx}`, where `idx` is incremented for each new observer instance and is therefore guaranteed to be unique.

Test Plan:
python test/test_quantization.py test_observer_fqn

Imported from OSS

Reviewed By: anjali411

Differential Revision: D31752052

fbshipit-source-id: e0995b1ef33a99d5b012133fe92d303d55a73b7d
  • Loading branch information
supriyar authored and facebook-github-bot committed Oct 23, 2021
1 parent 83f70db commit d13829e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 59 deletions.
30 changes: 15 additions & 15 deletions test/quantization/fx/test_equalize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,21 +525,21 @@ def test_input_weight_equalization_activation_values(self):
inp_counter = 0
weight_counter = 0
for node in convert_ref.graph.nodes:
if "w_activation_post" not in node.name and node.op == 'call_module' and \
isinstance(modules[str(node.target)], MinMaxObserver):
# Check min/max values of input activation layers
exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
inp_counter += 1

elif node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
# Check min/max values of weight activation layers
assert("w_activation_post" in node.name)
exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
weight_counter += 1
users = list(node.users)
if node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
if len(users) == 1 and users[0].target == torch.nn.functional.linear and users[0].args[1] == node:
# Check min/max values of weight activation layers
exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
weight_counter += 1
else:
# Check min/max values of input activation layers
exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
inp_counter += 1


def check_orig_and_eq_graphs(self, orig_model, eq_model):
""" Given a non-equalized model and an equalized model, check that the
Expand Down
17 changes: 8 additions & 9 deletions test/quantization/fx/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3143,16 +3143,15 @@ def forward(self, x):
name_list = []
for name, mod in prepared.named_modules():
if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver):
assert "mods" in name
name_list.append(name)
expected_name_list = ['mods1_0_input_activation_post_process_0',
'mods1_0_w_activation_post_process_0',
'mods1_0_output_activation_post_process_0',
'mods1_1_w_activation_post_process_0',
'mods1_1_output_activation_post_process_0',
'mods2_w_activation_post_process_0',
'mods2_output_activation_post_process_0',
'mods3_output_activation_post_process_0']
expected_name_list = ['activation_post_process_0',
'activation_post_process_1',
'activation_post_process_2',
'activation_post_process_3',
'activation_post_process_4',
'activation_post_process_6',
'activation_post_process_7',
'activation_post_process_10']
assert name_list == expected_name_list

def test_linear_lowering(self):
Expand Down
51 changes: 16 additions & 35 deletions torch/ao/quantization/fx/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,6 @@ def insert_observer(
model: torch.nn.Module,
modules: Dict[str, torch.nn.Module],
graph: Graph,
node_name_to_scope: Dict[str, Tuple[str, type]],
input_or_output: str,
) -> Node:
"""
Attaches `observer` to `model`, and creates a node which calls
Expand All @@ -216,15 +214,10 @@ def insert_observer(
if model_device:
observer.to(model_device)
# add observer module as attribute
# NOTE: We get the FQN of the module/op being observed here using the node_name_to_scope
# Please don't change/update this behavior as it might impact how observer stats are transferred
# from the train model to the inference model for some models.
obs_name_prefix, _ = node_name_to_scope[observed_op.name]
obs_name_prefix = node.name if obs_name_prefix == '' else obs_name_prefix
if is_equalization_observer(observer):
prefix = node.name + '_equalization_process_'
else:
prefix = obs_name_prefix + '_' + input_or_output + '_activation_post_process_'
prefix = 'activation_post_process_'
get_new_observer_name = get_new_attr_name_with_prefix(prefix)
observer_name = get_new_observer_name(model)
setattr(model, observer_name, observer)
Expand Down Expand Up @@ -311,7 +304,6 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
node_name_to_target_dtype: Dict[str, Any],
qhandler: Optional[QuantizeHandler],
prepare_custom_config_dict: Dict[str, Any],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Argument:
"""
Given a `node` and an `arg`, inserts an input observer between
Expand All @@ -325,7 +317,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
new_inner_arg = maybe_insert_input_observer_for_arg_or_kwarg(
node, inner_arg, qconfig, model, modules,
graph, node_name_to_target_dtype,
qhandler, prepare_custom_config_dict, node_name_to_scope)
qhandler, prepare_custom_config_dict)
new_arg_to_return.append(new_inner_arg)
return type(arg)(new_arg_to_return)

Expand All @@ -337,7 +329,6 @@ def maybe_insert_input_observer_for_arg_or_kwarg(

is_standalone_module = qhandler is not None and \
isinstance(qhandler, StandaloneModuleQuantizeHandler)
obs_type = "input"
if not is_standalone_module:
# regular flow for most nodes, except standalone modules
is_weight = node_arg_is_weight(node, arg)
Expand All @@ -352,10 +343,6 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
bias_needs_obs = \
(is_bias and activation_dtype(qconfig) == torch.float16) and \
weight_dtype(qconfig) == torch.float16
if weight_needs_obs:
obs_type = "w"
elif bias_needs_obs:
obs_type = "b"
arg_dtype = node_name_to_target_dtype[arg.name]
node_dtype = node_name_to_target_dtype[node.name]
dtype_changes_and_second_dtype_not_float = (
Expand Down Expand Up @@ -425,7 +412,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg(

if existing_obs_node is None:
new_obs_node = insert_observer(
arg, node, new_obs_mod, model, modules, graph, node_name_to_scope, obs_type)
arg, node, new_obs_mod, model, modules, graph)
# set the type, so the next node can read it
node_name_to_target_dtype[new_obs_node.name] = node_dtype
# override this arg to be the observed arg
Expand All @@ -445,7 +432,6 @@ def maybe_insert_input_observers_for_node(
node_name_to_target_dtype: Dict[str, Any],
qhandler: Optional[QuantizeHandler],
prepare_custom_config_dict: Dict[str, Any],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> None:
"""
If needed, inserts observers to the input args and kwargs of `node`.
Expand All @@ -472,15 +458,15 @@ def maybe_insert_input_observers_for_node(
new_arg = maybe_insert_input_observer_for_arg_or_kwarg(
node, arg, qconfig, model, modules, graph,
node_name_to_target_dtype,
qhandler, prepare_custom_config_dict, node_name_to_scope)
qhandler, prepare_custom_config_dict)
new_args.append(new_arg)

new_kwargs = {}
for k, kwarg in node.kwargs.items():
new_kwarg = maybe_insert_input_observer_for_arg_or_kwarg(
node, kwarg, qconfig, model, modules, graph,
node_name_to_target_dtype,
qhandler, prepare_custom_config_dict, node_name_to_scope)
qhandler, prepare_custom_config_dict)
new_kwargs[k] = new_kwarg

# assign the new args and kwargs to the node, inplace
Expand All @@ -495,7 +481,6 @@ def maybe_insert_input_equalization_observers_for_node(
graph: Graph,
node_name_to_target_dtype: Dict[str, Any],
is_branch: bool,
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> None:
"""
If `node` needs to be equalized, find the input/weight observers it needs in
Expand Down Expand Up @@ -525,7 +510,7 @@ def maybe_insert_input_equalization_observers_for_node(

new_eq_obs_mod = act_eq_process_ctr()
new_eq_obs_node = insert_observer(
arg, node, new_eq_obs_mod, model, modules, graph, node_name_to_scope, "input")
arg, node, new_eq_obs_mod, model, modules, graph)

# set the type, so the next node can read it
node_name_to_target_dtype[new_eq_obs_node.name] = node_name_to_target_dtype[arg.name]
Expand All @@ -544,7 +529,6 @@ def maybe_insert_output_observer_for_node(
node_name_to_target_dtype: Dict[str, Any],
matched_pattern: Any,
qhandler: Optional[QuantizeHandler],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Optional[Node]:
"""
If `node` needs an output observer, creates it, inserts it into `graph`
Expand Down Expand Up @@ -585,7 +569,7 @@ def maybe_insert_output_observer_for_node(
qconfig,
matched_pattern)
observer = act_post_process_ctr()
new_obs = insert_observer(node, node, observer, model, modules, graph, node_name_to_scope, "output")
new_obs = insert_observer(node, node, observer, model, modules, graph)
# set the type, so the next node can read it
node_name_to_target_dtype[new_obs.name] = \
node_name_to_target_dtype[node.name]
Expand All @@ -601,7 +585,6 @@ def maybe_insert_observers_before_graph_output(
model: torch.nn.Module,
modules: Dict[str, torch.nn.Module],
graph: Graph,
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> None:
"""
If the output needs to be quantized and there are any nodes
Expand Down Expand Up @@ -631,7 +614,6 @@ def _recursive_maybe_replace_node_with_obs(
model: torch.nn.Module,
modules: Dict[str, torch.nn.Module],
graph: Graph,
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Argument:
"""
Navigate an arbitrary data structure of lists, tuples, dicts.
Expand Down Expand Up @@ -661,7 +643,7 @@ def _recursive_maybe_replace_node_with_obs(
'Quantizing the output node without a qconfig is not supported'
observer_mod = qconfig.activation()
observer_node = insert_observer(
maybe_node, maybe_node, observer_mod, model, modules, graph, node_name_to_scope, "input")
maybe_node, maybe_node, observer_mod, model, modules, graph)
return observer_node
else:
return maybe_node
Expand All @@ -670,7 +652,7 @@ def _recursive_maybe_replace_node_with_obs(
for inner_node in maybe_node:
results.append(_recursive_maybe_replace_node_with_obs(
inner_node, target_dtype, node_name_to_target_dtype,
qconfig_map, model, modules, graph, node_name_to_scope))
qconfig_map, model, modules, graph))
if isinstance(maybe_node, list):
return results
else:
Expand All @@ -680,7 +662,7 @@ def _recursive_maybe_replace_node_with_obs(
for k, inner_v in maybe_node.items():
results_dict[k] = _recursive_maybe_replace_node_with_obs(
inner_v, target_dtype, node_name_to_target_dtype,
qconfig_map, model, modules, graph, node_name_to_scope)
qconfig_map, model, modules, graph)
return results_dict
else:
return results
Expand All @@ -690,7 +672,7 @@ def _recursive_maybe_replace_node_with_obs(
new_args.append(
_recursive_maybe_replace_node_with_obs(
old_arg, output_target_dtype, node_name_to_target_dtype,
qconfig_map, model, modules, graph, node_name_to_scope))
qconfig_map, model, modules, graph))

graph_output_node.args = new_args # type: ignore[assignment]

Expand Down Expand Up @@ -865,7 +847,6 @@ def insert_observers_for_model(
equalization_config_map: Dict[str, Any],
input_quantized_idxs: List[int],
output_quantized_idxs: List[int],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Optional[Node]:
"""
Inserts observers, using the following high level algorithm:
Expand Down Expand Up @@ -992,12 +973,12 @@ def insert_observers_for_model(
maybe_insert_input_observers_for_node(
node, qconfig, model, modules, graph,
node_name_to_target_dtype,
qhandler, prepare_custom_config_dict, node_name_to_scope)
qhandler, prepare_custom_config_dict)

# Insert equalization input observers if needed
maybe_insert_input_equalization_observers_for_node(
node, equalization_qconfig, model, modules, graph,
node_name_to_target_dtype, is_quantized_branch, node_name_to_scope)
node_name_to_target_dtype, is_quantized_branch)

is_last_node_of_pattern = root_node is node
is_general_tensor_value_op = \
Expand All @@ -1010,7 +991,7 @@ def insert_observers_for_model(
# this returns the new observer node if it was needed
maybe_output_obs_node = maybe_insert_output_observer_for_node(
node, model, modules, graph, matches,
node_name_to_target_dtype, pattern, qhandler, node_name_to_scope)
node_name_to_target_dtype, pattern, qhandler)
if maybe_output_obs_node is not None:
# Update users of original node to use the output observer
# instead. For example, change
Expand Down Expand Up @@ -1047,7 +1028,7 @@ def insert_observers_for_model(
maybe_insert_observers_before_graph_output(
node, output_quantized_idxs,
node_name_to_target_dtype, qconfig_map,
model, modules, graph, node_name_to_scope)
model, modules, graph)

#
# After this point, the current node has input and output observers
Expand Down Expand Up @@ -1225,7 +1206,7 @@ def prepare(
model, modules, matches, qconfig_map,
model.graph, prepare_custom_config_dict,
equalization_qconfig_map,
input_quantized_idxs, output_quantized_idxs, node_name_to_scope)
input_quantized_idxs, output_quantized_idxs)

save_state(model, qconfig_map, node_name_to_scope, patterns,
prepare_custom_config_dict, equalization_qconfig_map)
Expand Down

0 comments on commit d13829e

Please sign in to comment.