[quant] change observer FQNs generated in prepare step (pytorch#65420)
Summary:
Pull Request resolved: pytorch#65420

Context: In some FB use cases we need to map observer stats from a train model checkpoint to the inference model. We observed that some buffer names differ because the intermediate activation tensors
are generated differently across the train and inference models. More details in https://fb.quip.com/PtGcAR0S5CQP

Currently, for each observer (activation_post_process) module, the FQN of the inserted module is derived from the FQN of the input tensor it is observing.

In this change, the observer FQN is instead based on the FQN of the op/module it is observing, together with a qualifier ("input"/"output" for activations, "w"/"b" for weight/bias arguments), rather than on tensor/intermediate op names.

Before
```
def forward(self, x):
    x_activation_post_process_0 = self.x_activation_post_process_0(x);  x = None
    mods1_w = self.mods1.w
    mods1_w_activation_post_process_0 = self.mods1_w_activation_post_process_0(mods1_w);  mods1_w = None
    mods1_b = self.mods1.b
    linear = torch.nn.functional.linear(x_activation_post_process_0, mods1_w_activation_post_process_0, bias = mods1_b);  x_activation_post_process_0 = mods1_w_activation_post_process_0 = mods1_b = None
    linear_activation_post_process_0 = self.linear_activation_post_process_0(linear);  linear = None
    return linear_activation_post_process_0
```

After
```
def forward(self, x):
    mods1_input_activation_post_process_0 = self.mods1_input_activation_post_process_0(x);  x = None
    mods1_w = self.mods1.w
    mods1_w_activation_post_process_0 = self.mods1_w_activation_post_process_0(mods1_w);  mods1_w = None
    mods1_b = self.mods1.b
    linear = torch.nn.functional.linear(mods1_input_activation_post_process_0, mods1_w_activation_post_process_0, bias = mods1_b);  mods1_input_activation_post_process_0 = mods1_w_activation_post_process_0 = mods1_b = None
    mods1_output_activation_post_process_0 = self.mods1_output_activation_post_process_0(linear);  linear = None
    return mods1_output_activation_post_process_0
```
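
The new names follow the pattern `<observed module FQN>_<input|output|w|b>_activation_post_process_<index>`. A quick way to see this on any model is to prepare it and list the observer submodules. A minimal sketch, assuming the FX quantization entry points of this era (`prepare_fx` taking a qconfig dict); the module `M` and its `lin` attribute are illustrative only:

```python
import torch
from torch.ao.quantization import default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.lin(x)

# prepare_fx inserts observers according to the qconfig dict
prepared = prepare_fx(M().eval(), {"": default_qconfig})

# Observer FQNs now embed the FQN of the module they observe, e.g.
# lin_input_activation_post_process_0 and lin_output_activation_post_process_0
for name, mod in prepared.named_modules():
    if "activation_post_process" in name:
        print(name, type(mod).__name__)
```

This mirrors what the new `test_observer_fqn` test below asserts.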

Test Plan:
```
python test/test_quantization.py test_observer_fqn
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D31088652

fbshipit-source-id: 2f1526f578a13000b34cfd30d11f16f402fd3447
supriyar authored and facebook-github-bot committed Sep 23, 2021
1 parent a012216 commit 767a104
Showing 3 changed files with 92 additions and 21 deletions.
test/quantization/fx/test_equalize_fx.py (2 additions, 2 deletions)
```diff
@@ -523,7 +523,7 @@ def test_input_weight_equalization_activation_values(self):
         inp_counter = 0
         weight_counter = 0
         for node in convert_ref.graph.nodes:
-            if "weight" not in node.name and node.op == 'call_module' and \
+            if "w_activation_post" not in node.name and node.op == 'call_module' and \
                     isinstance(modules[str(node.target)], MinMaxObserver):
                 # Check min/max values of input activation layers
                 exp_min_val, exp_max_val = exp_inp_act_vals[inp_counter]
@@ -533,7 +533,7 @@ def test_input_weight_equalization_activation_values(self):
 
             elif node.op == 'call_module' and isinstance(modules[str(node.target)], MinMaxObserver):
                 # Check min/max values of weight activation layers
-                assert("weight" in node.name)
+                assert("w_activation_post" in node.name)
                 exp_min_val, exp_max_val = exp_weight_act_vals[weight_counter]
                 self.assertEqual(modules[str(node.target)].min_val, exp_min_val)
                 self.assertEqual(modules[str(node.target)].max_val, exp_max_val)
```
test/quantization/fx/test_quantize_fx.py (54 additions, 0 deletions)
```diff
@@ -3044,6 +3044,60 @@ def forward(self, x):
         }
         self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence)
 
+    def test_observer_fqn(self):
+        """
+        Test to make sure the observer FQN is based on the quantizable op/module that it is observing
+        and uses the module's FQN to determine the observer name.
+        """
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(
+                    Linear(),
+                    Linear()
+                )
+                self.mods2 = Linear()
+                self.mods3 = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.mods1(x)
+                x = torch.add(x, 4)
+                x = self.mods2(x)
+                y = torch.add(x, 2)
+                z = torch.mul(x, 5)
+                a = self.mods3(y)
+                return a, z
+
+        model = M().eval()
+
+        prepared = prepare_fx(model, {"": default_qconfig})
+        name_list = []
+        for name, mod in prepared.named_modules():
+            if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver):
+                assert "mods" in name
+                name_list.append(name)
+        expected_name_list = ['mods1_0_input_activation_post_process_0',
+                              'mods1_0_w_activation_post_process_0',
+                              'mods1_0_output_activation_post_process_0',
+                              'mods1_1_w_activation_post_process_0',
+                              'mods1_1_output_activation_post_process_0',
+                              'mods2_w_activation_post_process_0',
+                              'mods2_output_activation_post_process_0',
+                              'mods3_output_activation_post_process_0']
+        assert name_list == expected_name_list
+
 
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
```
torch/ao/quantization/fx/prepare.py (36 additions, 19 deletions)
```diff
@@ -209,10 +209,13 @@ def update_qconfig_for_fusion(
 
 def insert_observer(
     node: Node,
+    observed_op: Node,
     observer: ObserverBase,
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    input_or_output: str,
 ) -> Node:
     """
     Attaches `observer` to `model`, and creates a node which calls
@@ -222,10 +225,15 @@
     if model_device:
         observer.to(model_device)
     # add observer module as attribute
+    # NOTE: We get the FQN of the module/op being observed here using the node_name_to_scope
+    # Please don't change/update this behavior as it might impact how observer stats are transferred
+    # from the train model to the inference model for some models.
+    obs_name_prefix, _ = node_name_to_scope[observed_op.name]
+    obs_name_prefix = node.name if obs_name_prefix == '' else obs_name_prefix
     if is_equalization_observer(observer):
         prefix = node.name + '_equalization_process_'
     else:
-        prefix = node.name + '_activation_post_process_'
+        prefix = obs_name_prefix + '_' + input_or_output + '_activation_post_process_'
     get_new_observer_name = get_new_attr_name_with_prefix(prefix)
     observer_name = get_new_observer_name(model)
     setattr(model, observer_name, observer)
```
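
The naming logic added above reduces to the following standalone sketch (a hypothetical helper, not the library code itself); `node_name_to_scope` is the mapping recorded during symbolic tracing from a node's name to the `(module FQN, module type)` pair of its enclosing module:

```python
from typing import Dict, Tuple

def observer_name_prefix(
    observed_op_name: str,
    node_name: str,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    input_or_output: str,  # "input", "output", "w", or "b"
    is_equalization: bool,
) -> str:
    if is_equalization:
        # equalization observers keep the old node-name-based scheme
        return node_name + "_equalization_process_"
    # prefer the FQN of the module/op being observed; fall back to the
    # node's own name when the scope FQN is empty (top-level graph code)
    fqn, _module_type = node_name_to_scope[observed_op_name]
    prefix = node_name if fqn == "" else fqn
    return prefix + "_" + input_or_output + "_activation_post_process_"
```

`get_new_attr_name_with_prefix(prefix)` then appends a running index (`0`, `1`, ...) so that repeated observers on the same op get unique attribute names. The remaining hunks below thread `node_name_to_scope` through the call chain so that `insert_observer` can perform this lookup.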
```diff
@@ -312,6 +320,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
     node_name_to_target_dtype: Dict[str, Any],
     qhandler: Optional[QuantizeHandler],
     prepare_custom_config_dict: Dict[str, Any],
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -325,20 +334,19 @@
             new_inner_arg = maybe_insert_input_observer_for_arg_or_kwarg(
                 node, inner_arg, qconfig, model, modules,
                 graph, node_name_to_target_dtype,
-                qhandler, prepare_custom_config_dict)
+                qhandler, prepare_custom_config_dict, node_name_to_scope)
             new_arg_to_return.append(new_inner_arg)
         return type(arg)(new_arg_to_return)
 
     if not isinstance(arg, Node):
         return arg
     assert isinstance(arg, Node)
 
     # default (no observer)
     new_arg = arg
 
     is_standalone_module = qhandler is not None and \
         isinstance(qhandler, StandaloneModuleQuantizeHandler)
 
+    obs_type = "input"
     if not is_standalone_module:
         # regular flow for most nodes, except standalone modules
         is_weight = node_arg_is_weight(node, arg)
@@ -353,7 +361,10 @@
         bias_needs_obs = \
             (is_bias and activation_dtype(qconfig) == torch.float16) and \
             weight_dtype(qconfig) == torch.float16
-
+        if weight_needs_obs:
+            obs_type = "w"
+        elif bias_needs_obs:
+            obs_type = "b"
         arg_dtype = node_name_to_target_dtype[arg.name]
         node_dtype = node_name_to_target_dtype[node.name]
         dtype_changes_and_second_dtype_not_float = (
@@ -423,7 +434,7 @@
 
         if existing_obs_node is None:
             new_obs_node = insert_observer(
-                arg, new_obs_mod, model, modules, graph)
+                arg, node, new_obs_mod, model, modules, graph, node_name_to_scope, obs_type)
             # set the type, so the next node can read it
             node_name_to_target_dtype[new_obs_node.name] = node_dtype
             # override this arg to be the observed arg
@@ -443,6 +454,7 @@ def maybe_insert_input_observers_for_node(
     node_name_to_target_dtype: Dict[str, Any],
     qhandler: Optional[QuantizeHandler],
     prepare_custom_config_dict: Dict[str, Any],
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -469,15 +481,15 @@
         new_arg = maybe_insert_input_observer_for_arg_or_kwarg(
             node, arg, qconfig, model, modules, graph,
             node_name_to_target_dtype,
-            qhandler, prepare_custom_config_dict)
+            qhandler, prepare_custom_config_dict, node_name_to_scope)
         new_args.append(new_arg)
 
     new_kwargs = {}
     for k, kwarg in node.kwargs.items():
         new_kwarg = maybe_insert_input_observer_for_arg_or_kwarg(
             node, kwarg, qconfig, model, modules, graph,
             node_name_to_target_dtype,
-            qhandler, prepare_custom_config_dict)
+            qhandler, prepare_custom_config_dict, node_name_to_scope)
         new_kwargs[k] = new_kwarg
 
     # assign the new args and kwargs to the node, inplace
@@ -492,6 +504,7 @@ def maybe_insert_input_equalization_observers_for_node(
     graph: Graph,
     node_name_to_target_dtype: Dict[str, Any],
     is_branch: bool,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> None:
     """
     If `node` needs to be equalized, find the input/weight observers it needs in
@@ -521,7 +534,7 @@
 
         new_eq_obs_mod = act_eq_process_ctr()
         new_eq_obs_node = insert_observer(
-            arg, new_eq_obs_mod, model, modules, graph)
+            arg, node, new_eq_obs_mod, model, modules, graph, node_name_to_scope, "input")
 
         # set the type, so the next node can read it
         node_name_to_target_dtype[new_eq_obs_node.name] = node_name_to_target_dtype[arg.name]
@@ -540,6 +553,7 @@ def maybe_insert_output_observer_for_node(
     node_name_to_target_dtype: Dict[str, Any],
     matched_pattern: Any,
     qhandler: Optional[QuantizeHandler],
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> Optional[Node]:
     """
     If `node` needs an output observer, creates it, inserts it into `graph`
@@ -581,7 +595,7 @@
                 matched_pattern,
                 act_post_process_ctr)
         observer = act_post_process_ctr()
-        new_obs = insert_observer(node, observer, model, modules, graph)
+        new_obs = insert_observer(node, node, observer, model, modules, graph, node_name_to_scope, "output")
         # set the type, so the next node can read it
         node_name_to_target_dtype[new_obs.name] = \
             node_name_to_target_dtype[node.name]
@@ -597,6 +611,7 @@ def maybe_insert_observers_before_graph_output(
     model: torch.nn.Module,
     modules: Dict[str, torch.nn.Module],
     graph: Graph,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> None:
     """
     If the output needs to be quantized and there are any nodes
@@ -626,6 +641,7 @@ def _recursive_maybe_replace_node_with_obs(
         model: torch.nn.Module,
         modules: Dict[str, torch.nn.Module],
         graph: Graph,
+        node_name_to_scope: Dict[str, Tuple[str, type]],
     ) -> Argument:
         """
         Navigate an arbitrary data structure of lists, tuples, dicts.
@@ -655,7 +671,7 @@
                 'Quantizing the output node without a qconfig is not supported'
             observer_mod = qconfig.activation()
             observer_node = insert_observer(
-                maybe_node, observer_mod, model, modules, graph)
+                maybe_node, maybe_node, observer_mod, model, modules, graph, node_name_to_scope, "input")
             return observer_node
         else:
             return maybe_node
@@ -664,7 +680,7 @@
             for inner_node in maybe_node:
                 results.append(_recursive_maybe_replace_node_with_obs(
                     inner_node, target_dtype, node_name_to_target_dtype,
-                    qconfig_map, model, modules, graph))
+                    qconfig_map, model, modules, graph, node_name_to_scope))
             if isinstance(maybe_node, list):
                 return results
             else:
@@ -674,7 +690,7 @@
             for k, inner_v in maybe_node.items():
                 results_dict[k] = _recursive_maybe_replace_node_with_obs(
                     inner_v, target_dtype, node_name_to_target_dtype,
-                    qconfig_map, model, modules, graph)
+                    qconfig_map, model, modules, graph, node_name_to_scope)
             return results_dict
         else:
             return results
@@ -684,7 +700,7 @@
         new_args.append(
             _recursive_maybe_replace_node_with_obs(
                 old_arg, output_target_dtype, node_name_to_target_dtype,
-                qconfig_map, model, modules, graph))
+                qconfig_map, model, modules, graph, node_name_to_scope))
 
     graph_output_node.args = new_args  # type: ignore[assignment]
 
@@ -859,6 +875,7 @@ def insert_observers_for_model(
     equalization_config_map: Dict[str, Any],
     input_quantized_idxs: List[int],
     output_quantized_idxs: List[int],
+    node_name_to_scope: Dict[str, Tuple[str, type]],
 ) -> Optional[Node]:
     """
     Inserts observers, using the following high level algorithm:
@@ -985,12 +1002,12 @@
                 maybe_insert_input_observers_for_node(
                     node, qconfig, model, modules, graph,
                     node_name_to_target_dtype,
-                    qhandler, prepare_custom_config_dict)
+                    qhandler, prepare_custom_config_dict, node_name_to_scope)
 
                 # Insert equalization input observers if needed
                 maybe_insert_input_equalization_observers_for_node(
                     node, equalization_qconfig, model, modules, graph,
-                    node_name_to_target_dtype, is_quantized_branch)
+                    node_name_to_target_dtype, is_quantized_branch, node_name_to_scope)
 
                 is_last_node_of_pattern = root_node is node
                 is_general_tensor_value_op = \
@@ -1003,7 +1020,7 @@
                 # this returns the new observer node if it was needed
                 maybe_output_obs_node = maybe_insert_output_observer_for_node(
                     node, model, modules, graph, matches,
-                    node_name_to_target_dtype, pattern, qhandler)
+                    node_name_to_target_dtype, pattern, qhandler, node_name_to_scope)
                 if maybe_output_obs_node is not None:
                     # Update users of original node to use the output observer
                     # instead. For example, change
@@ -1040,7 +1057,7 @@
             maybe_insert_observers_before_graph_output(
                 node, output_quantized_idxs,
                 node_name_to_target_dtype, qconfig_map,
-                model, modules, graph)
+                model, modules, graph, node_name_to_scope)
 
     #
     # After this point, the current node has input and output observers
@@ -1218,7 +1235,7 @@ def prepare(
         model, modules, matches, qconfig_map,
         model.graph, prepare_custom_config_dict,
         equalization_qconfig_map,
-        input_quantized_idxs, output_quantized_idxs)
+        input_quantized_idxs, output_quantized_idxs, node_name_to_scope)
 
     save_state(model, qconfig_map, node_name_to_scope, patterns,
                prepare_custom_config_dict, equalization_qconfig_map)
```
