Skip to content

Commit

Permalink
[functorch] disable C++ Function under functorch transforms (pytorch#103957)
Browse files Browse the repository at this point in the history

Fixes pytorch#102720

Pull Request resolved: pytorch#103957
Approved by: https://github.com/zou3519
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Jun 23, 2023
1 parent ec24f1e commit 47894bb
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/FuncTorchTLS.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct TORCH_API FuncTorchTLSBase {
virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;

virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
virtual void checkSupportsCppAutogradFunction() const = 0;
virtual void checkSupportsInplaceRequiresGrad() const = 0;
virtual void checkSupportsRetainGrad() const = 0;
};
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/functorch/DynamicLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ class FuncTorchTLS : public FuncTorchTLSBase {
return 0;
}

// Raise if any functorch transform (vmap, grad, vjp, ...) is currently
// active. functorch support for autograd Functions is handled on the
// Python side only, so letting a C++ torch::autograd::Function run under
// a transform would be silently incorrect — fail loudly instead.
void checkSupportsCppAutogradFunction() const override {
  TORCH_CHECK(
      // A non-empty DynamicLayer stack means at least one transform is live.
      dynamicLayerStack.empty(),
      "cannot use C++ torch::autograd::Function with functorch transforms (vmap, grad, vjp, etc)");
}

void checkSupportsInplaceRequiresGrad() const override {
TORCH_CHECK(dynamicLayerStack.empty() || allow_inplace_requires_grad_,
"You are attempting to call Tensor.requires_grad_() (or perhaps using ",
Expand Down
23 changes: 23 additions & 0 deletions test/cpp_extensions/identity.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include <torch/extension.h>
#include <torch/torch.h>

using namespace torch::autograd;

// Minimal custom autograd Function used by the functorch error test:
// forward returns its input unchanged, and backward passes the incoming
// gradient straight through.
class Identity : public Function<Identity> {
 public:
  // Identity forward: no state needs saving on the context.
  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
    return input;
  }

  // Gradient of the identity is the incoming gradient itself.
  static tensor_list backward(AutogradContext* ctx, tensor_list grads) {
    tensor_list grad_inputs;
    grad_inputs.push_back(grads[0]);
    return grad_inputs;
  }
};

// Free-function wrapper so pybind11 can expose Identity::apply to Python.
torch::Tensor identity(torch::Tensor input) {
  auto output = Identity::apply(input);
  return output;
}

// Module entry point; the module name is chosen at build time via
// TORCH_EXTENSION_NAME. Binds the single `identity` function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("identity", &identity, "identity");
}
15 changes: 15 additions & 0 deletions test/test_cpp_extensions_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,21 @@ def test_custom_compound_op_autograd(self):
for fast_mode in (True, False):
gradcheck(torch.ops.my.add, [a, b], eps=1e-2, fast_mode=fast_mode)

def test_custom_functorch_error(self):
    """A C++ torch::autograd::Function must raise under functorch transforms."""
    # JIT-compile the minimal identity extension defined in identity.cpp.
    identity_m = torch.utils.cpp_extension.load(
        name="identity",
        sources=["cpp_extensions/identity.cpp"],
    )

    t = torch.randn(3, requires_grad=True)

    msg = r"cannot use C\+\+ torch::autograd::Function with functorch"
    # Both transforms must reject the C++ Function with the same error.
    for transform in (torch.func.vmap, torch.func.grad):
        with self.assertRaisesRegex(RuntimeError, msg):
            transform(identity_m.identity)(t)

# Allow running this test file directly, outside the CI test runner.
if __name__ == "__main__":
    common.run_tests()
8 changes: 8 additions & 0 deletions torch/csrc/autograd/custom_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
-> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
const auto& functorch_tls = at::functorch::functorchTLSAccessor();
if (functorch_tls) {
// Function support for functorch is handled in Python.
// Here we are dealing with a (C++) Function, which is not supported.
// Let's raise an error instead of being silently incorrect.
functorch_tls->checkSupportsCppAutogradFunction();
}

std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
variable_list input_vars;
Expand Down

0 comments on commit 47894bb

Please sign in to comment.