# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
**Extending the ONNX exporter operator support** ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_
Extending the ONNX Exporter Operator Support
============================================
**Authors:** `Ti-Tai Wang <[email protected]>`_, `Justin Chu <[email protected]>`_
"""
###############################################################################
# Overview
# --------
#
# This tutorial describes how you can create an ONNX implementation for an unsupported PyTorch operator
# or replace an existing implementation with your own.
#
# We will cover three scenarios that require extending the ONNX exporter's operator support:
#
# * Overriding the implementation of an existing PyTorch operator
# * Using custom ONNX operators
# * Supporting a custom PyTorch operator
#
# What you will learn:
#
# - How to override or add support for PyTorch operators in ONNX.
# - How to integrate custom ONNX operators for specialized runtimes.
# - How to implement and translate custom PyTorch operators to ONNX.
#
# Prerequisites
# ~~~~~~~~~~~~~
#
# Before starting this tutorial, make sure you have completed the following prerequisites:
#
# * ``torch >= 2.6``
# * The target PyTorch operator
# * Completed the
# `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_
# before proceeding
# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__
#
# Overriding the implementation of an existing PyTorch operator
# -------------------------------------------------------------
#
# Although the ONNX exporter team does its best to support all PyTorch operators, some of them
# might not be supported yet. In this section, we will demonstrate how you can add
# unsupported PyTorch operators to the ONNX Registry.
#
# .. note::
# The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existing
# PyTorch operator with a custom one.
#     Because we don't actually have an unsupported PyTorch operator to use in this tutorial,
#     we instead replace the implementation of ``torch.ops.aten.add.Tensor`` with a custom one,
#     exactly the same way we would if the operator were not implemented by the ONNX exporter.
#
# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message
# similar to:
#
# .. code-block:: python
#
# No decompositions registered for [...]
#
# The error message indicates that the unsupported PyTorch operator is ``torch.ops.aten.add.Tensor``.
# The operator is of type ``<class 'torch._ops.OpOverload'>``, and this operator is what we will use as the
# target to register our custom implementation.
import torch
import onnxscript

# Opset 18 is the standard supported version as of PyTorch 2.6
from onnxscript import opset18 as op
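
# As a quick sanity check (a minimal aside, not required for the export), we can
# confirm that the registration target is indeed an ``OpOverload`` object:
print(type(torch.ops.aten.add.Tensor))  # <class 'torch._ops.OpOverload'>
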

# Create a model that uses the operator torch.ops.aten.add.Tensor
class Model(torch.nn.Module):
    def forward(self, input_x, input_y):
        return torch.ops.aten.add.Tensor(input_x, input_y)


# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# All attributes must be annotated with type hints.
def custom_aten_add(self, other, alpha: float = 1.0):
    if alpha != 1.0:
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    # To distinguish the custom implementation from the builtin one, we switch the order of the inputs
    return op.Add(other, self)


x = torch.tensor([1.0])
y = torch.tensor([2.0])
# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.
onnx_program = torch.onnx.export(
    Model().eval(),
    (x, y),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.add.Tensor: custom_aten_add,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
######################################################################
# Now let's inspect the model and verify that it is using the custom implementation.
print(onnx_program.model)
######################################################################
# The translation uses our custom implementation: in node ``node_Add_0``, ``input_y`` now
# comes first, and ``input_x`` comes second.
#
# We can use ONNX Runtime to run the model and verify the results by calling
# the :class:`torch.onnx.ONNXProgram` directly on the input tensors.
result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))
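######################################################################
# If you prefer to run the exported model with ONNX Runtime directly, a minimal
# sketch is shown below. The input names (``input_x``, ``input_y``) and the file
# name are assumptions for illustration; inspect ``session.get_inputs()`` to
# confirm the names for your exported model.
#
# .. code-block:: python
#
#     import onnxruntime as ort
#
#     onnx_program.save("custom_add.onnx")
#     session = ort.InferenceSession("custom_add.onnx", providers=["CPUExecutionProvider"])
#     (ort_result,) = session.run(None, {"input_x": x.numpy(), "input_y": y.numpy()})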
######################################################################
# Using custom ONNX operators
# ---------------------------
#
# In this case, we create a model with standard PyTorch operators, but the runtime
# (such as Microsoft's ONNX Runtime) can provide a custom implementation for that kernel, effectively replacing the
# existing implementation.
#
# In the following example, we use the ``com.microsoft.Gelu`` operator provided by ONNX Runtime,
# which is not the same as the ``Gelu`` operator defined in the ONNX spec.
class GeluModel(torch.nn.Module):
    def forward(self, input_x):
        return torch.ops.aten.gelu(input_x)


# Create a namespace for the custom operator using ONNX Script
# ``com.microsoft`` is an official ONNX Runtime namespace
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)

# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
# NOTE: All attributes must be annotated with type hints.
# The function must be scripted using the ``@onnxscript.script()`` decorator when
# using operators from custom domains. This may be improved in future versions.
from onnxscript import FLOAT


@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
    return microsoft_op.Gelu(self)
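
# As a quick check (a minimal sketch; ``to_function_proto`` is part of the ONNX Script
# API for scripted functions), you can inspect the ONNX function the script produces:
print(custom_aten_gelu.to_function_proto())
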
onnx_program = torch.onnx.export(
    GeluModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.aten.gelu.default: custom_aten_gelu,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
######################################################################
# Let's inspect the model and verify that it uses the op_type ``Gelu``
# from the ``com.microsoft`` namespace.
#
print(onnx_program.model)
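######################################################################
# You can also check the operator domains programmatically. This is a minimal
# sketch assuming the ``onnxscript.ir`` based representation returned by
# ``onnx_program.model``, where iterating over a graph yields its nodes:
for node in onnx_program.model.graph:
    print(f"{node.domain}::{node.op_type}")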
######################################################################
# Similar to the previous example, we can use ONNX Runtime to run the model and verify the results.
result = onnx_program(x)[0]
torch.testing.assert_close(result, torch.ops.aten.gelu(x))
######################################################################
# Supporting a custom PyTorch operator
# ------------------------------------
#
# In this case, the operator is one that the user has implemented and registered with PyTorch.
#
# In the following example, we would like to use a custom operator
# that takes one tensor input and returns one output. The operator adds
# the input to itself and returns the rounded result.
#
# First, we assume the custom operator is implemented and registered with ``torch.library.custom_op()``.
# You can refer to `Creating new custom ops in Python <https://pytorch.org/docs/stable/library.html#torch.library.custom_op>`_
# for a detailed guide on how to create custom operators.
# Define and use the operator in PyTorch
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
    return torch.round(input + input)


# Register a fake (meta) implementation so the operator can be traced symbolically
@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
    return torch.empty_like(tensor_x)
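
# A quick eager-mode sanity check (a hypothetical aside, not required for the export):
# the custom operator should behave like ``torch.round(input + input)``.
print(add_and_round_op(torch.tensor([1.2])))  # tensor([2.])
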
class AddAndRoundModel(torch.nn.Module):
    def forward(self, input):
        return add_and_round_op(input)


# Implement the custom operator in ONNX using ONNX Script
def onnx_add_and_round(input):
    return op.Round(op.Add(input, input))
onnx_program = torch.onnx.export(
    AddAndRoundModel().eval(),
    (x,),
    dynamo=True,
    custom_translation_table={
        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
    },
)

# Optimize the ONNX graph to remove redundant nodes
onnx_program.optimize()
print(onnx_program)
######################################################################
# The translation uses our custom implementation to translate the ``torch.ops.mylibrary.add_and_round_op.default``
# operator in the :class:`torch.export.ExportedProgram` to the ONNX operators ``Add`` and ``Round``.
#
######################################################################
# Finally we verify the results.
result = onnx_program(x)[0]
torch.testing.assert_close(result, add_and_round_op(x))
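######################################################################
# To persist the exported model for deployment, you can use
# :class:`torch.onnx.ONNXProgram`'s ``save`` method (the file name here is
# just an example):
#
# .. code-block:: python
#
#     onnx_program.save("add_and_round.onnx")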
######################################################################
# Conclusion
# ----------
#
# Congratulations! In this tutorial, we explored the ``custom_translation_table`` option and
# discovered how to create custom implementations for unsupported or existing PyTorch operators
# using ONNX Script.
#
# Finally, we leveraged ONNX Runtime to execute the model and compare the results with PyTorch,
# providing us with a comprehensive understanding of handling unsupported
# operators in the ONNX ecosystem.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
# :hidden:
#