[Reland] Adds an aten::_ops namespace with unambiguous function names (pytorch#59018)

Summary:
Pull Request resolved: pytorch#59018

Fixes pytorch#58044.

This PR:
- adds `ATEN_FN(op)` and `ATEN_FN2(op, overload)` macros that resolve to
a non-overloaded function in `at::_ops` that calls the desired operator
(without default arguments).

The motivation for this is two-fold:
1) Using aten operators with templates is hard if the operator is
overloaded (e.g. add.Tensor and add.Scalar).
2) Method-only operators require special handling: pointer-to-member
functions are different from function pointers. `ATEN_FN2(add_, Tensor)`
returns a plain function instead of a method (see the sketch below).
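
As a concrete sketch of both points, the wrapper below shows the operator's
address being taken in a template context (illustrative only; `wrap` and
`example` are made-up names, not part of this PR):

```cpp
#include <utility>

#include <ATen/ATen.h>
#include <ATen/Operators.h>

// Passing an operator as a non-type template parameter requires a single,
// unambiguous function pointer.
template <class F, F Func, class... Args>
auto wrap(Args&&... args) {
  return Func(std::forward<Args>(args)...);
}

at::Tensor example(const at::Tensor& a, const at::Tensor& b) {
  // decltype(&at::add) is ill-formed: at::add names an overload set
  // (add.Tensor, add.Scalar, ...). ATEN_FN2 names exactly one function.
  // The _ops function has no default arguments, so alpha is passed explicitly.
  return wrap<decltype(&ATEN_FN2(add, Tensor)),
              &ATEN_FN2(add, Tensor)>(a, b, 1);
}
```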

There is some interesting behavior for out= operations.
`ATEN_FN2(sin, out)` gives a function that is *faithful* to the schema;
that is, the order of arguments is exactly what it looks like in the
schema. This makes it so that you can directly register
`ATEN_FN2(sin, out)` (or a function wrapping it with the same signature)
as an override for a DispatchKey.
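
Concretely, a sketch for the `sin.out(Tensor self, *, Tensor(!a) out)` schema
(`FaithfulSinOut` and `example` are illustrative names, not from this PR):

```cpp
#include <ATen/ATen.h>
#include <ATen/Operators.h>

// Faithful: arguments follow schema order, with out last, so the pointer
// type matches what the dispatcher expects for "sin.out":
using FaithfulSinOut = decltype(&ATEN_FN2(sin, out));
// i.e. at::Tensor& (*)(const at::Tensor& self, at::Tensor& out).
// By contrast, the conventional C++ API puts the out argument first:
//   at::Tensor& at::sin_out(at::Tensor& out, const at::Tensor& self);

at::Tensor& example(const at::Tensor& a, at::Tensor& out) {
  // Call through the faithful function: self first, out last.
  return ATEN_FN2(sin, out)(a, out);
}
```

Because the signature already matches the schema, such a function (or a
wrapper with the same signature) can be passed to `m.impl("sin.out", ...)`
inside a `TORCH_LIBRARY_IMPL` block without any argument reordering.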

Test Plan:
- New tests that ATEN_FN2 works on function and method-only operators
- New test that ATEN_FN works
- New test that the out= function obtained via `ATEN_FN2` is "faithful" to the schema.

Codegen output:
Operators.h and Operators.cpp are both here:
https://gist.github.com/zou3519/c2c6a900410b571f0d7d127019ca5175

Reviewed By: bdhirsh

Differential Revision: D28721206

Pulled By: zou3519

fbshipit-source-id: a070017f98e8f4038cb0c64be315eef45d264217
zou3519 authored and facebook-github-bot committed Jun 2, 2021
1 parent 8805093 commit 970096b
Showing 7 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -141,6 +141,8 @@ genrule(
"aten/src/ATen/Functions.cpp",
"aten/src/ATen/RedispatchFunctions.h",
"aten/src/ATen/RedispatchFunctions.cpp",
"aten/src/ATen/Operators.h",
"aten/src/ATen/Operators.cpp",
"aten/src/ATen/NativeFunctions.h",
"aten/src/ATen/MetaFunctions.h",
"aten/src/ATen/core/TensorBody.h",
12 changes: 12 additions & 0 deletions aten/src/ATen/templates/Operators.cpp
@@ -0,0 +1,12 @@
#include <ATen/Operators.h>

namespace at { namespace _ops {

Tensor & requires_grad_(Tensor & self, bool requires_grad) {
self.requires_grad_(requires_grad);
return self;
}

${definitions}

}} // namespace at::_ops
47 changes: 47 additions & 0 deletions aten/src/ATen/templates/Operators.h
@@ -0,0 +1,47 @@
#pragma once

// ${generated_comment}

#include <ATen/Functions.h>
#include <ATen/Tensor.h>

// Extension writers: do you write wrapper functions? Are you frustrated with
// resolving overloads of operators? Are you frustrated with dealing with
// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
// further, this is the utility for you.
//
// Given an operator schema: aten::op.overload(...
//
// Use ATEN_FN2(op, overload) to get a *function* version of the operator
// that is guaranteed to not be overloaded. This means that you can safely
// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
//
// Given an operator schema without an overload name: aten::op(...
//
// Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
//
// There is some interesting behavior for out= operations.
// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
// that is, the order of arguments is exactly what it looks like in the schema.
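//
// Example (illustrative, not generated code): given aten::add.Tensor,
//   using Fn = decltype(&ATEN_FN2(add, Tensor)); // OK: names one function
// whereas decltype(&at::add) is ill-formed because at::add is an overload set.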

#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload
#define ATEN_FN(op_name) at::_ops::op_name

// WARNING: Please do not call any of the ops in the _ops namespace directly.
// Use the ATEN_FN macros. We do not guarantee stability of the naming
// scheme for the functions in at::_ops
namespace at { namespace _ops {

// NB: We are forced to special case requires_grad_. This is because all
// of the auto-generated inplace method signatures in TensorMethods.h are
// codegen'ed to return Tensor&, but requires_grad_ has a `manual_cpp_binding`
// with a different signature that returns `const Tensor&`.
//
// Eventually, the plan is to kill Tensor& from all C++ signatures and use
// const Tensor&. When that happens, we can remove this special case and just
// let the codegen handle it.
TORCH_API Tensor & requires_grad_(Tensor & self, bool requires_grad);

${declarations}

}} // namespace at::_ops
1 change: 1 addition & 0 deletions aten/src/ATen/test/CMakeLists.txt
@@ -25,6 +25,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/operators_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/math_kernel_test.cpp
54 changes: 54 additions & 0 deletions aten/src/ATen/test/operators_test.cpp
@@ -0,0 +1,54 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Operators.h>

using namespace at;

template <class F, F Func, class Output, class... Args>
Output pass_through_wrapper(Args... args) {
return Func(std::forward<Args>(args)...);
}

TEST(OperatorsTest, TestFunctionDecltype) {
Tensor a = at::randn({5, 5});
Tensor b = at::randn({5, 5});
auto expected = a * b;

auto result = pass_through_wrapper<
decltype(&ATEN_FN2(mul, Tensor)), &ATEN_FN2(mul, Tensor),
Tensor, const Tensor&, const Tensor&>(a, b);
ASSERT_TRUE(at::allclose(result, expected));
}

TEST(OperatorsTest, TestMethodOnlyDecltype) {
Tensor a = at::randn({5, 5});
Tensor b = at::randn({5, 5});
auto expected = a * b;

// NB: mul_ overloads are guaranteed to be method-only
// because that is how the tensor API works.
auto& result = pass_through_wrapper<
decltype(&ATEN_FN2(mul_, Tensor)), &ATEN_FN2(mul_, Tensor),
Tensor&, Tensor&, const Tensor&>(a, b);
ASSERT_TRUE(at::allclose(result, expected));
}

TEST(OperatorsTest, Test_ATEN_FN) {
Tensor a = at::rand({5, 5});

auto result = pass_through_wrapper<
decltype(&ATEN_FN(sin)), &ATEN_FN(sin),
Tensor, const Tensor&>(a);
ASSERT_TRUE(at::allclose(result, a.sin()));
}

TEST(OperatorsTest, TestOutVariantIsFaithful) {
Tensor a = at::rand({5, 5});
Tensor b = at::empty({5, 5});

auto& result = pass_through_wrapper<
decltype(&ATEN_FN2(sin, out)), &ATEN_FN2(sin, out),
Tensor&, const Tensor&, Tensor&>(a, b);
ASSERT_TRUE(at::allclose(result, a.sin()));
}
1 change: 1 addition & 0 deletions aten/tools/run_tests.sh
@@ -25,6 +25,7 @@ VALGRIND=${VALGRIND:=ON}
./NamedTensor_test
./cpu_generator_test
./vmap_test
./operators_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
90 changes: 90 additions & 0 deletions tools/codegen/gen.py
@@ -190,6 +190,87 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
return f'm.def({cpp_string(str(f.func))});\n'


def _num_leading_spaces(line: str) -> int:
return len(line) - len(line.lstrip())


# Unindents all lines in code. Each line gets unindented the same amount;
# that amount is equal to the smallest number of leading spaces across all lines
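# For example (illustrative): deindent("  a\n    b") == "a\n  b".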
def deindent(code: str) -> str:
lines = code.split('\n')
min_leading_spaces = min(map(_num_leading_spaces, lines))
lines = [line[min_leading_spaces:] for line in lines]
return '\n'.join(lines)


# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
target: Union[
Literal[Target.DECLARATION],
Literal[Target.DEFINITION]
]

@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
# NB: requires_grad_ is the only exception to the rule because
# its const correctness is questionable.
if str(f.func.name) in set(['requires_grad_']):
return None

if self.target is Target.DECLARATION:
return self.gen_declaration(f)
if self.target is Target.DEFINITION:
return self.gen_definition(f)
else:
assert_never(self.target)

# NB: This must be synchronized with the naming scheme in
# aten/src/ATen/templates/Operators.h
# Given a function schema "aten::op.overload(...)",
# If there is no overload name, this returns f"{op}"
# If there is an overload name, this returns f"{op}_{overload}"
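# For example: "aten::add.Tensor" -> "add_Tensor"; "aten::sin" -> "sin".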
def unambiguous_function_name(self, f: NativeFunction) -> str:
base_name = str(f.func.name.name)
overload_name = f.func.name.overload_name
if overload_name:
return f'{base_name}_{overload_name}'
return base_name

def gen_declaration(self, f: NativeFunction) -> str:
unambiguous_name = self.unambiguous_function_name(f)
sig = DispatcherSignature.from_schema(f.func)
return f"TORCH_API {sig.decl(unambiguous_name)};"

def most_faithful_name(self, f: NativeFunction) -> str:
sig_group = CppSignatureGroup.from_native_function(f, method=False)
sig = sig_group.most_faithful_signature()
return sig.name()

def invocation(self, f: NativeFunction) -> str:
faithful_op_name = self.most_faithful_name(f)
args = tuple(arg.name for arg in dispatcher.arguments(f.func))
# Method only
if Variant.function not in f.variants:
return f"{args[0]}.{faithful_op_name}({', '.join(args[1:])})"
return f"at::{faithful_op_name}({', '.join(args)})"

def gen_definition(self, f: NativeFunction) -> str:
unambiguous_name = self.unambiguous_function_name(f)
args = dispatcher.arguments(f.func)
sig = DispatcherSignature.from_schema(f.func)

return deindent(f"""\
{sig.defn(unambiguous_name)} {{
return {self.invocation(f)};
}}\
""")


# Generates Functions.cpp and Functions.h. These files provide the
# functional public C++ API, and the scaffolding to call into
# the dispatcher from these functions. See also compute_tensor_method.
@@ -992,6 +1073,15 @@ def make_file_manager(install_dir: str) -> FileManager:
'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
})

cpu_fm.write('Operators.cpp', lambda: {
'definitions': list(mapMaybe(ComputeOperators(
Target.DEFINITION), native_functions)),
})
cpu_fm.write('Operators.h', lambda: {
'declarations': list(mapMaybe(ComputeOperators(
Target.DECLARATION), native_functions)),
})

cpu_fm.write('Functions.h', lambda: {
'function_declarations': list(mapMaybe(ComputeFunction(
Target.DECLARATION, static_dispatch_backend_index=static_dispatch_idx, is_redispatching_fn=False), native_functions)),
