Skip to content

Commit

Permalink
[ONNX] Fix for expand -1 dim value (pytorch#34069)
Browse files Browse the repository at this point in the history
Summary:
PyTorch expand allows size with -1 dim value. -1 dim value means to infer the dimension from input tensor. This can be exported to ONNX expand with 1 dim value since ONNX expand supports two-way broadcast.
Pull Request resolved: pytorch#34069

Reviewed By: hl475

Differential Revision: D20195532

Pulled By: houseroad

fbshipit-source-id: c90e7d51b9d7422c09c5ed6e135ca8263105b8c9
  • Loading branch information
neginraoof authored and facebook-github-bot committed Mar 16, 2020
1 parent 1bac5fd commit 480d184
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
76 changes: 73 additions & 3 deletions test/onnx/expect/TestOperators.test_expand.expect
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,80 @@ graph {
}
}
node {
input: "0"
input: "1"
output: "2"
name: "Expand_1"
name: "Shape_1"
op_type: "Shape"
}
node {
input: "2"
output: "3"
name: "ConstantOfShape_2"
op_type: "ConstantOfShape"
attribute {
name: "value"
t {
dims: 1
data_type: 7
raw_data: "\001\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
output: "4"
name: "Constant_3"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 7
raw_data: "\377\377\377\377\377\377\377\377"
}
type: TENSOR
}
}
node {
input: "3"
input: "4"
output: "5"
name: "Mul_4"
op_type: "Mul"
}
node {
output: "6"
name: "Constant_5"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 3
data_type: 7
raw_data: "\004\000\000\000\000\000\000\000\006\000\000\000\000\000\000\000\002\000\000\000\000\000\000\000"
}
type: TENSOR
}
}
node {
input: "6"
input: "5"
output: "7"
name: "Equal_6"
op_type: "Equal"
}
node {
input: "7"
input: "3"
input: "1"
output: "8"
name: "Where_7"
op_type: "Where"
}
node {
input: "0"
input: "8"
output: "9"
name: "Expand_8"
op_type: "Expand"
}
name: "torch-jit-export"
Expand All @@ -41,7 +111,7 @@ graph {
}
}
output {
name: "2"
name: "9"
type {
tensor_type {
elem_type: 1
Expand Down
24 changes: 24 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,30 @@ def forward(self, input, indices):
indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
self.run_test(GatherModel(), input=(input, indices))

@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):
class ExpandModel(torch.nn.Module):
def forward(self, input):
return input.expand(2, 3, -1)

input = torch.randn(2, 1, 4)
self.run_test(ExpandModel(), input=(input))

class ExpandInferDimModel(torch.nn.Module):
def forward(self, input):
return input.expand(-1, input.size(0))

input = torch.randn(3, 1)
self.run_test(ExpandInferDimModel(), input=(input))

class ExpandTensorSizeModel(torch.nn.Module):
def forward(self, input, size):
return input.expand(size)

input = torch.randn(3,)
size = torch.tensor([-1])
self.run_test(ExpandTensorSizeModel(), input=(input, size))

def test_multinomial(self):
class Multinomial(torch.nn.Module):
def forward(self, weight):
Expand Down
9 changes: 9 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,15 @@ def expand(g, self, size, implicit):
size = sym_help._maybe_get_const(size, 'is')
if not sym_help._is_value(size):
size = g.op("Constant", value_t=torch.LongTensor(size))
elif sym_help._is_packed_list(size):
# Expand with -1 dim value means dim is unchanged.
# Since onnx::expand supports two-way broadcasting,
# -1 dim value can be exported to onnx as 1
size = view(g, stack(g, size, 0), [-1])
dtype = 4 # dim type is int64
ones = ones_like(g, size, dtype)
neg_ones = mul(g, ones, g.op("Constant", value_t=torch.tensor(-1)))
size = where(g, g.op("Equal", size, neg_ones), ones, size)
return g.op("Expand", self, size)


Expand Down

0 comments on commit 480d184

Please sign in to comment.