forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Reland] Adds an aten::_ops namespace with unambiguous function names (…
…pytorch#59018) Summary: Pull Request resolved: pytorch#59018 Fixes pytorch#58044. This PR: - adds `ATEN_FN(op)` and `ATEN_FN2(op, overload)` macros that resolve to an non-overloaded function in aten::_ops that calls the desired operator (without default arguments). The motivation for this is two-fold: 1) Using aten operators with templates is hard if the operator is overloaded (e.g. add.Tensor and add.Scalar). 2) Method-only operators require special handling; pointers-to-method are different from function pointers. `ATEN_FN2(add_, Tensor)` returns a function instead of a method. There is some interesting behavior for out= operations. `ATEN_FN2(sin, "out")` gives a function that is *faithful* to the schema; that is, the order of arguments is exactly what it looks like in the schema. This makes it so that you can directly register `ATEN_FN2(sin,"out")` (or a function wrapping it using the same signature) as an override for a DispatchKey. Test Plan: - New tests that ATEN_FN2 works on function and method-only operators - New test that ATEN_FN works - New test that ATEN_FN macro returns a "faithful" function. Codegen output: Operators.h and Operators.cpp are both here: https://gist.github.com/zou3519/c2c6a900410b571f0d7d127019ca5175 Reviewed By: bdhirsh Differential Revision: D28721206 Pulled By: zou3519 fbshipit-source-id: a070017f98e8f4038cb0c64be315eef45d264217
- Loading branch information
1 parent
8805093
commit 970096b
Showing
7 changed files
with
207 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#include <ATen/Operators.h> | ||
|
||
namespace at { namespace _ops { | ||
|
||
Tensor & requires_grad_(Tensor & self, bool requires_grad) { | ||
self.requires_grad_(requires_grad); | ||
return self; | ||
} | ||
|
||
${definitions} | ||
|
||
}} // namespace at::_ops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#pragma once | ||
|
||
// ${generated_comment} | ||
|
||
#include <ATen/Functions.h> | ||
#include <ATen/Tensor.h> | ||
|
||
// Extension writers: do you write wrapper functions? Are you frustrated with | ||
// resolving overloads of operators? Are you frustrated with dealing with | ||
// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no | ||
// further, this is the utility for you. | ||
// | ||
// Given an operator schema: aten::op.overload(... | ||
// | ||
// Use ATEN_FN2(op, overload) to get a *function* version of the operator | ||
// that is guaranteed to not be overloaded. This means that you can safely | ||
// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args. | ||
// | ||
// Given an operator schema without an overload name: aten::op(... | ||
// | ||
// Use ATEN_FN(op) to get an unambiguous *function* version of the operator. | ||
// | ||
// There is some interesting behavior for out= operations. | ||
// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema; | ||
// that is, the order of arguments is exactly what it looks like in the schema. | ||
|
||
#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload | ||
#define ATEN_FN(op_name) at::_ops::op_name | ||
|
||
// WARNING: Please do not call any of the ops in the _ops namespace directly. | ||
// Use the ATEN_FN macros. We do not guarantee stability of the naming | ||
// scheme for the functions in at::_ops | ||
namespace at { namespace _ops { | ||
|
||
// NB: We are forced to special case requires_grad_. This is because all | ||
// of the auto-generated inplace method signatures in TensorMethods.h are | ||
// codegen'ed to return Tensor&, but requires_grad_ has a `manual_cpp_binding` | ||
// with a different signature that returns `const Tensor&`. | ||
// | ||
// Eventually, the plan is to kill Tensor& from all C++ signatures and use | ||
// const Tensor&. When that happens, we can remove this special case and just | ||
// let the codegen handle it. | ||
TORCH_API Tensor & requires_grad_(Tensor & self, bool requires_grad); | ||
|
||
${declarations} | ||
|
||
}} // namespace at::_ops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#include <gtest/gtest.h> | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/Operators.h> | ||
|
||
using namespace at; | ||
|
||
template <class F, F Func, class Output, class... Args> | ||
Output pass_through_wrapper(Args... args) { | ||
return Func(std::forward<Args>(args)...); | ||
} | ||
|
||
TEST(OperatorsTest, TestFunctionDecltype) { | ||
Tensor a = at::randn({5, 5}); | ||
Tensor b = at::randn({5, 5}); | ||
auto expected = a * b; | ||
|
||
auto result = pass_through_wrapper< | ||
decltype(&ATEN_FN2(mul, Tensor)), &ATEN_FN2(mul, Tensor), | ||
Tensor, const Tensor&, const Tensor&>(a, b); | ||
ASSERT_TRUE(at::allclose(result, a * b)); | ||
} | ||
|
||
TEST(OperatorsTest, TestMethodOnlyDecltype) { | ||
Tensor a = at::randn({5, 5}); | ||
Tensor b = at::randn({5, 5}); | ||
auto expected = a * b; | ||
|
||
// NB: add_ overloads are guaranteed to be method-only | ||
// because that is how the tensor API works. | ||
auto& result = pass_through_wrapper< | ||
decltype(&ATEN_FN2(mul_, Tensor)), &ATEN_FN2(mul_, Tensor), | ||
Tensor&, Tensor&, const Tensor&>(a, b); | ||
ASSERT_TRUE(at::allclose(result, expected)); | ||
} | ||
|
||
TEST(OperatorsTest, Test_ATEN_FN) { | ||
Tensor a = at::rand({5, 5}); | ||
|
||
auto result = pass_through_wrapper< | ||
decltype(&ATEN_FN(sin)), &ATEN_FN(sin), | ||
Tensor, const Tensor&>(a); | ||
ASSERT_TRUE(at::allclose(result, a.sin())); | ||
} | ||
|
||
TEST(OperatorsTest, TestOutVariantIsFaithful) { | ||
Tensor a = at::rand({5, 5}); | ||
Tensor b = at::empty({5, 5}); | ||
|
||
auto& result = pass_through_wrapper< | ||
decltype(&ATEN_FN2(sin, out)), &ATEN_FN2(sin, out), | ||
Tensor&, const Tensor&, Tensor&>(a, b); | ||
ASSERT_TRUE(at::allclose(result, a.sin())); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters