fx quant: fix edge case with copynode after user function (pytorch#55710)

Summary:
Pull Request resolved: pytorch#55710

In the current code, there is an edge case which leads to an error
after the prepare step:

1. The model contains a pattern like this:

```
user_func_unmatched_to_qhandler -> node_matched_to_copy_node_qhandler
```

2. The user function returns a type which is not observable (i.e. not a
Tensor).

3. If this model is run through `prepare_fx`, calibrating it with data leads
to a runtime error, because observers cannot observe non-Tensor types.
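
As a minimal illustration of the failure mode (a sketch, not part of the PR; `MinMaxObserver` stands in for whichever observer the qconfig specifies):

```
import torch
from torch.quantization import MinMaxObserver

obs = MinMaxObserver()
obs(torch.randn(4, 4))        # fine: observers record statistics of Tensors
obs([torch.randn(4, 4)] * 2)  # raises: a list is not a Tensor, so there are
                              # no statistics for the observer to record
```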

This PR fixes the issue. If a node matched to `CopyNodeQuantizeHandler`
comes after an unmatched node, we delete the observer between them.
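
Concretely, the prepare pass now rewrites the observed graph like this (the diagram mirrors the comment added in `handle_copy_nodes` below):

```
# before the fix:
user_node_unmatched -> obs -> copy_node_matched -> next_node

# after the fix, `obs` is deleted:
user_node_unmatched -> copy_node_matched -> next_node
```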

Test Plan:
```
python test/test_quantization.py TestQuantizeFx.test_no_obs_between_unmatched_node_and_copy_node
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27686811

fbshipit-source-id: 320be41b1f383c6352ff89fb39a9f480822a3bb2
vkuzo authored and facebook-github-bot committed Apr 12, 2021
1 parent 3f8d476 commit ec9b20d
Showing 2 changed files with 61 additions and 3 deletions.
test/quantization/test_quantize_fx.py (32 additions, 0 deletions)

```
@@ -138,6 +138,10 @@ def forward(self, x, y):
         x = self.relu(x)
         return x
 
+@torch.fx.wrap
+def _user_func_with_complex_return_type(x):
+    return list(torch.split(x, 1, 1))
+
 class TestFuseFx(QuantizationTestCase):
     def test_fuse_conv_bn_relu(self):
         class M(torch.nn.Module):
```
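
For context, `torch.fx.wrap` keeps symbolic tracing from descending into the helper, so it appears in the traced graph as a single `call_function` node that no quantize handler matches. A minimal sketch (not part of the PR) of what the tracer records:

```
import torch
import torch.fx

@torch.fx.wrap
def _user_func_with_complex_return_type(x):
    return list(torch.split(x, 1, 1))

class M(torch.nn.Module):
    def forward(self, x):
        return _user_func_with_complex_return_type(x)

print(torch.fx.symbolic_trace(M()).graph)
# the helper shows up as one opaque call_function node; FX does not trace
# into torch.split, and the node's output is a list rather than a Tensor
```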
```
@@ -2062,6 +2066,34 @@ def forward(self, x):
                           "mods1_1_scale_0", "mods1_1_zero_point_0"]:
                 self.assertTrue(hasattr(m, attr_name))
 
+    def test_no_obs_between_unmatched_node_and_copy_node(self):
+        """
+        Verifies that an observer is not inserted between an unmatched
+        node and a node matched to CopyNodeQuantizeHandler. This is done
+        because observers require activations to be Tensors, and there is
+        no guarantee that an output of an unmatched node is a Tensor.
+        """
+
+        class M(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = _user_func_with_complex_return_type(x)
+                x1 = x[0] + 1
+                return x1, x[1]
+
+        m = M().eval()
+
+        qconfig_dict = {'': torch.quantization.default_qconfig}
+        mp = prepare_fx(m, qconfig_dict)
+        # if an observer is inserted after _user_func_with_complex_return_type,
+        # the following call will fail
+        mp(torch.randn(4, 4, 4, 4))
+        mc = convert_fx(mp)
+        mc(torch.randn(4, 4, 4, 4))
+
     def test_fold_quant_dequant(self):
         """ Test that the sequence of quant-dequant nodes in the
         graph, get folded and we erase the extra dequant nodes.
```
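
To see where observers actually land, one can print the observed graph right after `prepare_fx` (a sketch reusing the test's `M`; the `activation_post_process_N` observer naming is an assumption about the prepare pass of this era):

```
mp = prepare_fx(M().eval(), {'': torch.quantization.default_qconfig})
for node in mp.graph.nodes:
    if node.op == 'call_module' and 'activation_post_process' in str(node.target):
        print('observer', node.target, 'observes', node.args[0])
# with this fix, no observer takes the output of
# _user_func_with_complex_return_type as its input
```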
torch/quantization/fx/quantize.py (29 additions, 3 deletions)

```
@@ -428,6 +428,7 @@ def handle_copy_nodes(
     observed_nodes: Set[Node] = set()
     copy_nodes: Set[Node] = set()
     non_tensor_input_binary_op_nodes: Set[Node] = set()
+    unmatched_nodes: Set[Node] = set()
     app_to_remove: Set[Node] = set()
     env: Dict[Any, Any] = {}
 
```
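
For orientation, the bookkeeping sets used by `handle_copy_nodes` (my reading of the surrounding pass, stated as comments rather than authoritative documentation):

```
# observed_nodes:   nodes whose outputs are already observed
# copy_nodes:       nodes matched to CopyNodeQuantizeHandler; they propagate
#                   their input's observed-ness instead of adding an observer
# unmatched_nodes:  (new in this PR) non-placeholder nodes that no quantize
#                   handler matched; their outputs may not be Tensors
# app_to_remove:    activation_post_process (observer) modules scheduled for
#                   removal from the observed graph
```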
```
@@ -465,9 +466,33 @@ def in_nodes(a: Argument, nodes: Set[Node]) -> bool:
             copy_nodes.add(node)
             # if previous node is observed, the copy node will be observed as well
             if in_nodes(node.args[0], observed_nodes):
-                observed_nodes.add(node)
+                prev_node = node.args[0]
+                if (
+                    isinstance(prev_node, Node) and
+                    prev_node.op == "call_module" and
+                    is_activation_post_process(modules[prev_node.target])  # type: ignore
+                ):
+                    prev_prev_node = prev_node.args[0]
+                    # If previous node is unmatched, the input to copy node should not
+                    # be observed. For example, in the pattern of
+                    #
+                    # user_node_unmatched -> obs -> copy_node_matched -> next_node
+                    #
+                    # we delete `obs`, because user_node_unmatched is not quantizeable,
+                    # and the input to copy_node_matched does not need observation.
+                    if in_nodes(prev_prev_node, unmatched_nodes):
+                        app_to_remove.add(prev_node)
+                        observed_nodes.remove(prev_node)
+                    else:
+                        observed_nodes.add(node)
+                else:
+                    observed_nodes.add(node)
+
         if all_node_args_have_no_tensors(node, modules, cache_for_no_tensor_check):
             non_tensor_input_binary_op_nodes.add(node)
+        if root_node is None and node.op != 'placeholder':
+            unmatched_nodes.add(node)
 
     # rule 3: for special node, we'll just remove observer for its input
     special_nodes = [
```
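
Read on its own, the new branch implements one small rule. A simplified standalone sketch of it (hypothetical function name; it reuses the `is_activation_post_process` and `in_nodes` helpers visible above):

```
def observer_between_unmatched_and_copy_node(copy_node, modules, unmatched_nodes):
    # the copy node's input must itself be an observer module...
    prev_node = copy_node.args[0]
    if not (
        isinstance(prev_node, Node) and
        prev_node.op == "call_module" and
        is_activation_post_process(modules[prev_node.target])
    ):
        return False
    # ...and the observer's input must come from a node that no quantize
    # handler matched; such an output may not be a Tensor, so the observer
    # must be removed rather than calibrated
    return in_nodes(prev_node.args[0], unmatched_nodes)
```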
```
@@ -1189,8 +1214,9 @@ def _fold_quant_dequant(self, quantized: QuantizedGraphModule) -> QuantizedGraphModule:
                 # and all its inputs.
                 if len(quant_uses) == 1:
                     quantized.graph.erase_node(node)
-                    for arg in quant_args[1 :]:
-                        quantized.graph.erase_node(arg)
+                    for arg in quant_args[1:]:
+                        if isinstance(arg, Node):
+                            quantized.graph.erase_node(arg)
         return quantized
 
     def convert(self, model: GraphModule, is_reference: bool = False,
```
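
A note on the `isinstance` guard added in `_fold_quant_dequant`: after folding, the trailing args of a `quantize_per_tensor` call are a mix of graph Nodes (e.g. `get_attr` nodes for scale and zero_point) and plain Python values (the dtype), and only Nodes can be erased from a graph. A minimal sketch of the distinction (the `quant_args` layout shown in the comment is an assumption):

```
from torch.fx import Node

def erasable_args(quant_args):
    # quant_args ~ (input_node, scale_node, zero_point_node, torch.quint8);
    # a literal like torch.quint8 does not live in the graph, so there is
    # nothing to erase for it
    return [arg for arg in quant_args[1:] if isinstance(arg, Node)]
```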
