Skip to content

Commit

Permalink
[ONNX] Update peephole pass for prim::ListUnpack (pytorch#46264)
Browse files Browse the repository at this point in the history
Summary:
Update pass that handles prim::ListUnpack in peephole file, so that it also covers the case when input to the node is of ListType.

Fixes pytorch#45816

Pull Request resolved: pytorch#46264

Reviewed By: mrshenli

Differential Revision: D24566070

Pulled By: bzinodev

fbshipit-source-id: 32555487054f6a7fe02cc17c66bcbe81ddf9623e
  • Loading branch information
KsenijaS authored and facebook-github-bot committed Nov 5, 2020
1 parent 5977d1d commit 7a59987
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions scripts/onnx/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
"${test_paths[@]}"

# onnxruntime only support py3
Expand Down
20 changes: 20 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,26 @@ def test_groupnorm_noaffine(self):
x = torch.randn(4, 6, 180, 180)
self.run_test(model, x)

@skipIfUnsupportedMinOpsetVersion(9)
def test_listunpack(self):
class ListUnpack(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
a, b = x.shape
return x.new_zeros((a, b))

x = torch.randn(2, 3)
self.run_test(ListUnpack(), x)

class ListUnpackSlice(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
a, b = x.shape[2:]
return x.new_zeros((a, b))

x = torch.randn(2, 3, 4, 5)
self.run_test(ListUnpackSlice(), x)

def test_pow(self):
class PowModule(torch.nn.Module):
def forward(self, x, y):
Expand Down
56 changes: 56 additions & 0 deletions torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,68 @@ static void ReplaceIndexPutWithMaskedScatter(Block* b) {
}
}

// Replaces each output of a prim::ListUnpack with an onnx::Gather of the
// corresponding element, when the unpacked input is an int[] coming from
// some op other than prim::ListConstruct (e.g. Slice or aten::size/Shape).
//
// before the pass
// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
//   %1 : None = prim::Constant()
//   %2 : int[] = aten::size(%x.1) # <string>:7:9
//   %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
//   %5 : int[] = prim::ListConstruct(%a.1, %b.1)
//   %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
//   return (%6)
//
// after the pass:
// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
//   %1 : None = prim::Constant()
//   %2 : int[] = aten::size(%x.1) # <string>:7:9
//   %7 : Tensor = onnx::Constant[value={0}]()
//   %8 : Tensor = onnx::Gather(%2, %7)
//   %9 : Tensor = onnx::Constant[value={1}]()
//   %10 : Tensor = onnx::Gather(%2, %9)
//   %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
//   %5 : int[] = prim::ListConstruct(%8, %10)
//   %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
//   return (%6)
static void fuseListAndListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    // Recurse into nested blocks (e.g. loop/if bodies) first.
    for (auto* child_block : it->blocks()) {
      fuseListAndListUnpack(child_block);
    }
    if (it->kind() != prim::ListUnpack) {
      continue;
    }
    // Eligibility is a property of the node, not of each output: check it
    // once here instead of re-evaluating it inside the per-output loop.
    if (it->inputs().size() != 1) {
      continue;
    }
    Value* input = it->input();
    if (input->node()->kind() == prim::ListConstruct) {
      // ListConstruct feeding ListUnpack is handled by a separate pass.
      continue;
    }
    const auto list_type = input->type()->cast<ListType>();
    if (!list_type || !list_type->getElementType()->cast<IntType>()) {
      continue;
    }
    for (size_t i = 0; i < it->outputs().size(); ++i) {
      auto output = it->outputs().at(i);
      // Emit Constant(i) + Gather(list, i) and redirect all uses of the
      // i-th unpacked value to the Gather result. The ListUnpack node
      // itself is left in place (dead-code elimination can remove it).
      Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
      gather_indices->insertBefore(*it);
      gather_indices->t_(
          attr::value, at::scalar_to_tensor(at::Scalar(static_cast<int>(i))));
      Node* gather_node = b->owningGraph()->create(onnx::Gather, 1);
      gather_node->insertBefore(*it);
      gather_node->addInput(input);
      gather_node->addInput(gather_indices->output());
      output->replaceAllUsesWith(gather_node->output());
    }
  }
}

} // namespace

// Entry point: runs the ONNX pre-processing passes over the whole graph.
// NOTE(review): each pass mutates the graph in place; the ordering here
// presumably matters (e.g. FuseWithListUnpack handles ListConstruct inputs
// before fuseListAndListUnpack handles the remaining list-typed inputs) —
// confirm before reordering.
void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
  FuseWithListUnpack(graph->block());
  ReplaceAddWithConcat(graph->block());
  ReplaceIndexPutWithMaskedScatter(graph->block());
  fuseListAndListUnpack(graph->block());
}

} // namespace jit
Expand Down

0 comments on commit 7a59987

Please sign in to comment.