[dynamo][numpy] Support ndarray methods (pytorch#97537)
This PR adds universal support for ndarray methods. After pytorch#100839, each `NumpyNdarrayVariable` should wrap a `torch.Tensor`. This PR adds a `numpy_method_wrapper` which converts the `torch.Tensor` to a `torch_np.ndarray` and then calls the numpy ndarray method. The result is then converted back to a `torch.Tensor` (or returned as-is if the value is not ndarray-like).
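
A minimal sketch of the user-facing behavior this enables (assuming the numpy–pytorch interop is available and enabled; the `eager` backend is just for illustration):

```python
import torch

def f(x):
    a = x.numpy()                  # traced as a NumpyNdarrayVariable
    return a.reshape([1, a.size])  # ndarray method call, no graph break

compiled = torch.compile(f, backend="eager")
out = compiled(torch.randn(2, 3))  # method runs through the traced graph
```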

Pull Request resolved: pytorch#97537
Approved by: https://github.com/ezyang
larryliu0820 authored and pytorchmergebot committed Jun 12, 2023
1 parent 18f203a commit 2eac8bd
Showing 4 changed files with 63 additions and 2 deletions.
24 changes: 24 additions & 0 deletions test/dynamo/test_functions.py
@@ -1086,6 +1086,30 @@ def test_return_multiple_numpy_ndarray(x):
        a = x.numpy()
        return a.T, a.imag, a.real

    @requires_numpy_pytorch_interop
    @make_test
    def test_ndarray_method(x):
        a = x.numpy()
        return a.copy()

    @requires_numpy_pytorch_interop
    @make_test
    def test_ndarray_transpose(x):
        a = x.numpy()
        return a.transpose(0, 1)

    @requires_numpy_pytorch_interop
    @make_test
    def test_ndarray_reshape(x):
        a = x.numpy()
        return a.reshape([1, a.size])

    @requires_numpy_pytorch_interop
    @make_test
    def test_ndarray_methods_returning_scalar(x):
        a = x.numpy()
        return a.max(axis=0), a.all(axis=0)


def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
4 changes: 3 additions & 1 deletion torch/_dynamo/output_graph.py
@@ -75,6 +75,7 @@
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
@@ -746,7 +747,8 @@ def append_prefix_insts():
        if (
            stack_values
            and all(
-                not isinstance(v, UnspecializedPythonVariable) for v in stack_values
+                not isinstance(v, (UnspecializedPythonVariable, NumpyNdarrayVariable))
+                for v in stack_values
            )
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
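The `output_graph.py` hunk above tightens the fast path that returns stack values directly: a `NumpyNdarrayVariable`'s FX-level value is a `torch.Tensor`, but the Python-level result must be ndarray-like, so codegen now takes the path that can emit the conversion. A toy illustration of the condition (hypothetical stand-in classes, not dynamo's real ones):

```python
class UnspecializedPythonVariable: ...
class NumpyNdarrayVariable: ...
class TensorVariable: ...

stack_values = [TensorVariable(), NumpyNdarrayVariable()]
fast_path_ok = all(
    not isinstance(v, (UnspecializedPythonVariable, NumpyNdarrayVariable))
    for v in stack_values
)
print(fast_path_ok)  # False: the ndarray variable forces reconstruction codegen
```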
19 changes: 19 additions & 0 deletions torch/_dynamo/utils.py
@@ -1630,6 +1630,25 @@ def numpy_attr_wrapper(obj, name):
    return numpy_to_tensor(out)


class numpy_method_wrapper:
    """Convert obj from torch.Tensor to torch_np.ndarray and call method. Then convert result back to torch.Tensor."""

    def __init__(self, method: str):
        self.method = method
        self.__name__ = "wrapped_" + self.method

    def __repr__(self):
        return f"<Wrapped method <original {self.method}>>"

    def __call__(self, *args, **kwargs):
        obj = args[0]
        if isinstance(obj, torch.Tensor):
            obj = torch_np.ndarray(obj)
        method_callable = getattr(obj, self.method)
        out = method_callable(*args[1:], **kwargs)
        return numpy_to_tensor(out)


def defake(x):
    if not isinstance(x, FakeTensor):
        return x
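Since `numpy_method_wrapper` is a plain callable, it can be exercised directly; a minimal sketch (assuming `torch_np` is importable, which the wrapper requires):

```python
import torch
from torch._dynamo.utils import numpy_method_wrapper

wrapped = numpy_method_wrapper("reshape")
t = torch.arange(6)
out = wrapped(t, (2, 3))  # Tensor -> torch_np.ndarray -> .reshape((2, 3)) -> Tensor
assert isinstance(out, torch.Tensor) and out.shape == (2, 3)
print(wrapped)  # <Wrapped method <original reshape>>
```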
18 changes: 17 additions & 1 deletion torch/_dynamo/variables/tensor.py
@@ -779,7 +779,23 @@ def call_method(
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
-        unimplemented(f"numpy_ndarray.{name}()")
+        options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])
+        from torch._dynamo.variables.builder import wrap_fx_proxy_cls
+        from ..utils import numpy_method_wrapper
+
+        result = wrap_fx_proxy_cls(
+            target_cls=NumpyNdarrayVariable,
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                numpy_method_wrapper(name),
+                *proxy_args_kwargs([self] + list(args), kwargs),
+            ),
+            example_value=None,
+            **options,
+        )
+
+        return result


class UnspecializedPythonVariable(TensorVariable):
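To see what the new `call_method` records, the equivalent FX node can be built by hand; a sketch (illustrative only — dynamo constructs it via `create_proxy` and `wrap_fx_proxy_cls`):

```python
import torch.fx as fx
from torch._dynamo.utils import numpy_method_wrapper

g = fx.Graph()
x = g.placeholder("x")
# An ndarray method call lowers to a call_function node targeting the
# wrapper, with the wrapped tensor as the first argument:
node = g.call_function(numpy_method_wrapper("copy"), (x,))
g.output(node)
print(g)  # the node's target prints as <Wrapped method <original copy>>
```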
