Skip to content

Commit

Permalink
Add support for training apis to support custom ops (microsoft#16601)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Jul 14, 2023
1 parent 19169af commit 9889f0f
Show file tree
Hide file tree
Showing 20 changed files with 602 additions and 107 deletions.
Binary file not shown.
Binary file not shown.
200 changes: 147 additions & 53 deletions orttraining/orttraining/python/orttraining_pybind_state.cc

Large diffs are not rendered by default.

62 changes: 40 additions & 22 deletions orttraining/orttraining/python/training/api/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import OrtValue, get_ort_device_type
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector, SessionOptions
from onnxruntime.training.api.checkpoint_state import CheckpointState


Expand All @@ -33,6 +33,7 @@ class Module:
state: The checkpoint state object.
eval_model_uri: The path to the evaluation model.
device: The device to run the model on. Default is "cpu".
session_options: The session options to use for the model.
"""

training: bool
Expand All @@ -43,11 +44,13 @@ def __init__(
state: CheckpointState,
eval_model_uri: os.PathLike | None = None,
device: str = "cpu",
session_options: SessionOptions | None = None,
) -> None:
self.training = True
options = device.split(":")
self._device_type = options[0]
device_id = 0 if len(options) < 2 else int(options[1])
self._session_options = session_options if session_options is not None else SessionOptions()

self._device = C.OrtDevice(
get_ort_device_type(self._device_type, device_id),
Expand All @@ -59,42 +62,57 @@ def __init__(
state._state,
os.fspath(eval_model_uri) if eval_model_uri is not None else None,
self._device,
self._session_options,
)
self._state = state

def __call__(self, *user_inputs) -> tuple[np.ndarray] | np.ndarray:
def __call__(self, *user_inputs) -> tuple[np.ndarray, ...] | np.ndarray | tuple[OrtValue, ...] | OrtValue:
"""Invokes either the training or the evaluation step of the model.
Args:
*user_inputs: The inputs to the model.
The user inputs can be either numpy arrays or OrtValues.
Returns:
The outputs of the model.
"""
is_np_input = False
forward_inputs = OrtValueVector()
forward_inputs.reserve(len(user_inputs))
for tensor in user_inputs:
if isinstance(tensor, np.ndarray):
is_np_input = True
forward_inputs.push_back(OrtValue.ortvalue_from_numpy(tensor)._ortvalue)
elif isinstance(tensor, OrtValue):
forward_inputs.push_back(tensor._ortvalue)
else:
raise ValueError(f"Expected input of type: numpy array or OrtValue, actual: {type(tensor)}")
fetches = OrtValueVector()

if self.training:
self._model.train_step(forward_inputs, fetches)
else:
self._model.eval_step(forward_inputs, fetches)
def _has_np_input(user_inputs):
return any(isinstance(user_input, np.ndarray) for user_input in user_inputs)

def _take_generic_step(forward_inputs):
fetches = OrtValueVector()
if self.training:
self._model.train_step(forward_inputs, fetches)
else:
self._model.eval_step(forward_inputs, fetches)

if len(fetches) == 1:
if is_np_input:
if len(fetches) == 1:
return fetches[0].numpy()

return fetches[0]
return tuple(val.numpy() for val in fetches)

def _take_step_with_ortvalues(forward_inputs):
ort_values = OrtValueVector()
ort_values.reserve(len(forward_inputs))
fetches = OrtValueVector()

for tensor in forward_inputs:
ort_values.push_back(tensor._ortvalue)

if self.training:
self._model.train_step_with_ort_values(ort_values, fetches)
else:
self._model.eval_step_with_ort_values(ort_values, fetches)

if len(fetches) == 1:
return OrtValue(fetches[0])

return tuple(OrtValue(val) for val in fetches)

if _has_np_input(user_inputs):
return _take_generic_step([*user_inputs])

return tuple(val.numpy() for val in fetches) if is_np_input else tuple(fetches)
return _take_step_with_ortvalues(user_inputs)

def train(self, mode: bool = True) -> Module:
"""Sets the Module in training mode.
Expand Down
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/api/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class Optimizer:
"""

def __init__(self, optimizer_uri: str | os.PathLike, module: Module):
self._optimizer = C.Optimizer(os.fspath(optimizer_uri), module._state._state, module._device)
self._optimizer = C.Optimizer(
os.fspath(optimizer_uri), module._state._state, module._device, module._session_options
)

def step(self) -> None:
"""Updates the model parameters based on the computed gradients.
Expand Down
23 changes: 21 additions & 2 deletions orttraining/orttraining/python/training/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import contextlib
import logging
import os
import pathlib
Expand Down Expand Up @@ -61,6 +62,8 @@ def generate_artifacts(
If None, the current working directory is used.
prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library (str | os.PathLike): The path to the custom op library.
If not specified, no custom op library is used.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
Expand Down Expand Up @@ -121,14 +124,30 @@ def build(self, *inputs_to_loss):
training_model = None
eval_model = None
model_params = None
with onnxblock.base(model):

custom_op_library = extra_options.get("custom_op_library", None)
if custom_op_library is not None:
logging.info("Custom op library provided: %s", custom_op_library)
custom_op_library = pathlib.Path(custom_op_library)

with onnxblock.base(model), onnxblock.custom_op_library(
custom_op_library
) if custom_op_library is not None else contextlib.nullcontext():
_ = training_block(*[output.name for output in model.graph.output])
training_model, eval_model = training_block.to_model_proto()
model_params = training_block.parameters()

def _export_to_ort_format(model_path, output_dir, extra_options):
if extra_options.get("ort_format", False):
convert_onnx_models_to_ort(model_path, output_dir=output_dir, optimization_styles=[OptimizationStyle.Fixed])
custom_op_library = extra_options.get("custom_op_library", None)
if custom_op_library is not None:
custom_op_library = pathlib.Path(custom_op_library)
convert_onnx_models_to_ort(
model_path,
output_dir=output_dir,
custom_op_library_path=custom_op_library,
optimization_styles=[OptimizationStyle.Fixed],
)

if artifact_directory is None:
artifact_directory = pathlib.Path.cwd()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import onnxruntime.training.onnxblock.optim as optim
from onnxruntime.training.onnxblock.blocks import Block
from onnxruntime.training.onnxblock.checkpoint_utils import load_checkpoint_to_model, save_checkpoint
from onnxruntime.training.onnxblock.model_accessor import base, empty_base
from onnxruntime.training.onnxblock.model_accessor import base, custom_op_library, empty_base
from onnxruntime.training.onnxblock.onnxblock import ForwardBlock, TrainingBlock

__all__ = [
Expand All @@ -21,5 +21,6 @@
"load_checkpoint_to_model",
"save_checkpoint",
"base",
"custom_op_library",
"empty_base",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Licensed under the MIT License.

import copy
import os
from typing import List, Optional, Set, Tuple, Union

import onnx

from onnxruntime import SessionOptions
from onnxruntime.capi._pybind_state import GradientGraphBuilder, get_optimized_model


Expand Down Expand Up @@ -66,17 +68,25 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti


def _gradient_model_for(
model: onnx.ModelProto, requires_grad: Set[str], output_names: List[str], loss_name: str
model: onnx.ModelProto,
requires_grad: Set[str],
output_names: List[str],
loss_name: str,
options: Optional[SessionOptions] = None,
) -> onnx.ModelProto:
"""Builds the gradient graph on top of the given input forward only graph."""

builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name)
builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options)
builder.build()
return onnx.load_from_string(builder.get_model())


def build_gradient_graph(
model: onnx.ModelProto, requires_grad: Set[str], frozen_params: Set[str], output_names: Union[List[str], str]
model: onnx.ModelProto,
requires_grad: Set[str],
frozen_params: Set[str],
output_names: Union[List[str], str],
custom_op_library: Optional[str] = None,
) -> Tuple[onnx.ModelProto, onnx.ModelProto]:
"""Prepare the training model and the eval model.
Expand Down Expand Up @@ -106,10 +116,14 @@ def build_gradient_graph(
eval_model = copy.deepcopy(model)
_disable_training_mode(eval_model)

optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad))
options = SessionOptions()
if custom_op_library is not None:
options.register_custom_ops_library(os.fspath(custom_op_library))

optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options))

# Assumption is that the first graph output is the loss output
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0])
gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options)

_reorder_outputs(gradient_model, output_names, requires_grad)

Expand Down
49 changes: 49 additions & 0 deletions orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

import numpy as np
import onnx

import onnxruntime.training.onnxblock._graph_utils as _graph_utils
Expand Down Expand Up @@ -427,3 +428,51 @@ def build(self, cast_input_name: str):
self.base.graph.node.append(cast_node)

return cast_output_name


class Linear(Block):
def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0):
super().__init__()

self._in_features = in_features
self._bias = bias
self._out_features = out_features
self._alpha = alpha
self._beta = beta

def build(self, linear_input_name: str):
# Weight initializer
linear_node_weight_name = _graph_utils.generate_graph_name("linear.weight")

self.base.graph.initializer.append(
onnx.numpy_helper.from_array(
np.random.randn(self._in_features, self._out_features).astype(np.float32), linear_node_weight_name
)
)

linear_node_input_names = [linear_input_name, linear_node_weight_name]

# Bias initializer
if self._bias:
linear_node_bias_name = _graph_utils.generate_graph_name("linear.bias")
self.base.graph.initializer.append(
onnx.numpy_helper.from_array(
np.random.randn(self._out_features).astype(np.float32), linear_node_bias_name
)
)
linear_node_input_names.append(linear_node_bias_name)

linear_node_output_name = _graph_utils.generate_graph_name("linear.output")
linear_node_output_names = [linear_node_output_name]
linear_node = onnx.helper.make_node(
"Gemm",
linear_node_input_names,
linear_node_output_names,
_graph_utils.generate_graph_name("linear"),
alpha=self._alpha,
beta=self._beta,
)

self.base.graph.node.append(linear_node)

return linear_node_output_name
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import annotations

import copy
import os
from contextlib import contextmanager
from typing import Optional

import onnx

Expand All @@ -29,10 +31,11 @@ def model(self) -> onnx.ModelProto:
return self._model


# This variable resides in the global namespace.
# These variable resides in the global namespace.
# Different methods can access this global model and manipulate it.
# Its construction and destruction is managed by the base and empty_base contextmanagers
_GLOBAL_ACCESSOR = None
_GLOBAL_CUSTOM_OP_LIBRARY = None


@contextmanager
Expand Down Expand Up @@ -74,7 +77,7 @@ def base(model: onnx.ModelProto):


@contextmanager
def empty_base(opset_version: Optional[int] = None):
def empty_base(opset_version: int | None = None):
"""Registers an empty base model to be manipulated by the onnx blocks.
Example:
Expand All @@ -89,8 +92,7 @@ def empty_base(opset_version: Optional[int] = None):
model_handle.
Args:
opset_version (int, optional): The opset version to use for the model.
Defaults to onnx.defs.onnx_opset_version()
opset_version: The opset version to use for the model. Defaults to onnx.defs.onnx_opset_version()
Returns:
ModelAccessor: The model accessor that contains the modified model.
Expand All @@ -115,3 +117,35 @@ def empty_base(opset_version: Optional[int] = None):
yield _GLOBAL_ACCESSOR
finally:
_GLOBAL_ACCESSOR = None


@contextmanager
def custom_op_library(custom_op_library_path: os.PathLike):
"""Registers the custom op library to be used by the onnx blocks.
Example:
>>> with onnxblock.custom_op_library(custom_op_library_path):
>>> # manipulate the model using blocks
>>> ...
In this example, custom_op_library will register the given input custom op library path to be used
during the model manipulation (gradient graph building and optimization).
Args:
custom_op_library_path: The path to the custom op library.
Returns:
ModelAccessor: The model accessor that contains the modified model.
"""
global _GLOBAL_CUSTOM_OP_LIBRARY # pylint: disable=global-statement # noqa: PLW0603
if _GLOBAL_CUSTOM_OP_LIBRARY is not None:
raise RuntimeError("CustomOp library already set. Cannot set multiple custom op libraries.")

if not os.path.exists(custom_op_library_path):
raise RuntimeError(f"Custom op library path {custom_op_library_path} does not exist.")

_GLOBAL_CUSTOM_OP_LIBRARY = copy.copy(custom_op_library_path) # noqa: PLW0603
try:
yield _GLOBAL_CUSTOM_OP_LIBRARY
finally:
_GLOBAL_CUSTOM_OP_LIBRARY = None
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ def __call__(self, *args, **kwargs):
# The order of model inputs after gradient graph building is: user inputs, model parameters as inputs
# The order of the model outputs is: user outputs, model parameter gradients (in the order of parameter inputs)
self._training_model, self._eval_model = _training_graph_utils.build_gradient_graph(
model,
self._requires_grad,
self._frozen_params,
output,
model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY
)

_training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad)
Expand Down
Loading

0 comments on commit 9889f0f

Please sign in to comment.