Skip to content

Commit

Permalink
[ONNX] Export embedding_bag (pytorch#41234)
Browse files Browse the repository at this point in the history
Summary:
Enable export of embedding_bag op to ONNX

Pull Request resolved: pytorch#41234

Reviewed By: houseroad

Differential Revision: D22567470

Pulled By: bzinodev

fbshipit-source-id: 2fcf74e54f3a9dee4588d7877a4ac9eb6c2a3629
  • Loading branch information
neginraoof authored and facebook-github-bot committed Jul 17, 2020
1 parent 7eb71b4 commit 346c69a
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
63 changes: 63 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,69 @@ def forward(self, input):
x = torch.tensor([False, True, True])
self.run_test(model, x)

@unittest.skip("Enable once jit trace Tensor.numel as constant is fixed.")
def test_embedding_bag_dynamic(self):
class EmbeddingModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.embeddingbag = torch.nn.EmbeddingBag(40, 12, mode='sum')

def forward(self, input):
return self.embeddingbag(input)

model = EmbeddingModel()
x = torch.randint(7, (10, 5))
y = torch.randint(10, (20, 5))
self.run_test(model, x, test_with_inputs=[y],
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': [0],
'output': [0]
})

@skipIfUnsupportedMinOpsetVersion(10)
def test_embedding_bag(self):
model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
input = torch.randint(10, (7,))
offset = torch.tensor([0, 2, 5, 6])
self.run_test(model, (input, offset))

model = torch.nn.EmbeddingBag(10, 5, mode='sum', include_last_offset=True)
input = torch.randint(10, (7,))
offset = torch.tensor([0, 2, 5, 6])
self.run_test(model, (input, offset))

model = torch.nn.EmbeddingBag(10, 5, mode='max')
input = torch.randint(10, (7, 5))
self.run_test(model, (input))

@skipIfUnsupportedMinOpsetVersion(10)
def test_embedding_bag_1d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, offset, weights):
return torch.nn.functional.embedding_bag(embedding_matrix, input, offsets=offset,
mode='sum', per_sample_weights=weights)

model = EmbeddingModel()
x = torch.randint(7, (6,))
w = torch.randn(6,)
offset = torch.tensor([0, 2, 5])
embedding_matrix = torch.rand(10, 15)
self.run_test(model, (embedding_matrix, x, offset, w))

@skipIfUnsupportedMinOpsetVersion(10)
def test_embedding_bag_2d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
def forward(self, embedding_matrix, input, weights):
return torch.nn.functional.embedding_bag(embedding_matrix, input,
mode='sum', per_sample_weights=weights)

embedding_matrix = torch.rand(10, 15)
model = EmbeddingModel()
x = torch.randint(7, (2, 3))
w = torch.randn(2, 3)
self.run_test(model, (embedding_matrix, x, w))

@skipIfUnsupportedMinOpsetVersion(8)
def test_meshgrid(self):
class Meshgrid(torch.nn.Module):
Expand Down
84 changes: 84 additions & 0 deletions torch/onnx/symbolic_opset10.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from sys import maxsize

import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
Expand Down Expand Up @@ -179,6 +180,89 @@ def flip(g, input, dims):
def fmod(g, input, other):
return g.op("Mod", input, other, fmod_i=1)


@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i')
def embedding_bag(g,
embedding_matrix,
indices,
offsets,
scale_grad_by_freq,
mode,
sparse,
per_sample_weights,
include_last_offset):
if scale_grad_by_freq and sym_help._training_mode:
return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode')

from torch.onnx.symbolic_opset9 import size, div, select

# Check if initial indices was 2D. In functional.py:
# offsets is set to torch.arange(0, indices.numel(), indices.size(1))
# Then indices is reshaped to 1D: indices.reshape(-1)
if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \
and len(indices.node().inputs().__next__().type().sizes()) == 2:
# Assert include_last_offset is False
assert not include_last_offset
embeddings = g.op("Gather", embedding_matrix, indices)
dim_0 = size(g, offsets, g.op("Constant", value_t=torch.LongTensor([0])))
dim_1 = div(g, size(g, indices, g.op("Constant", value_t=torch.LongTensor([0]))), dim_0)
dim_2 = g.op("Constant", value_t=torch.LongTensor([-1]))

shape = [dim_0, dim_1, dim_2]
shape = g.op("Concat", *shape, axis_i=0)

if not sym_help._is_none(per_sample_weights):
per_sample_weights = g.op("Unsqueeze", per_sample_weights, axes_i=[1])
embeddings = g.op("Mul", embeddings, per_sample_weights)

embeddings = g.op("Reshape", embeddings, shape)
if mode == 0:
embeddings = g.op("ReduceSum", embeddings, axes_i=[1], keepdims_i=0)
elif mode == 1:
embeddings = g.op("ReduceMean", embeddings, axes_i=[1], keepdims_i=0)
else:
embeddings = g.op("ReduceMax", embeddings, axes_i=[1], keepdims_i=0)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return embeddings, None, None, None
elif offsets.type().sizes() is not None:
if include_last_offset:
offset_len = offsets.type().sizes()[0] - 1
offsets_extended = offsets
else:
offset_len = offsets.type().sizes()[0]
offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))]
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
list_ = []
for i in range(offset_len):
start_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), axes_i=[0])
end_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), axes_i=[0])
axes_ = g.op("Constant", value_t=torch.tensor([0]))
indices_row = g.op("Slice", indices, start_, end_, axes_)

embeddings = g.op("Gather", embedding_matrix, indices_row)
if not sym_help._is_none(per_sample_weights):
per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_)
per_sample_weights_row = g.op("Unsqueeze", per_sample_weights_row, axes_i=[1])
embeddings = g.op("Mul", embeddings, per_sample_weights_row)
if mode == 0:
embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
elif mode == 1:
embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
else:
embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

embeddings = g.op("Unsqueeze", embeddings, axes_i=[0])
list_.append(embeddings)

output = g.op("Concat", *list_, axis_i=0)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
return output, None, None, None
else:
return sym_help._onnx_unsupported('embedding_bag with unknown shape of indices')


@parse_args('v', 't', 'i', 'i', 'i')
def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
if quant_min not in [0, -128] or quant_max not in [127, 255]:
Expand Down

0 comments on commit 346c69a

Please sign in to comment.