Skip to content

Commit

Permalink
vmap: simple inplace batch rule (pytorch#113513)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#113513
Approved by: https://github.com/zou3519
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Nov 21, 2023
1 parent f66add9 commit 1a3dbf5
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/functorch/BatchRulesBinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,16 +461,21 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, c10::optional<c10::string_view>) const;
using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;

POINTWISE_BOXED(add_.Tensor); // just testing
POINTWISE_BOXED(atan2_);
POINTWISE_BOXED(gcd_);
POINTWISE_BOXED(lcm_);
VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::add_, const Scalar&, const Scalar&>));
VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::sub_, const Scalar&>));
VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::sub_, const Scalar&, const Scalar&>));
VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::mul_>));
VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::mul_, const Scalar&>));
VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::div_>));
VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, c10::optional<c10::string_view>>));
VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>));
VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>));
VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>));
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(absolute);
OP_DECOMPOSE(absolute_);
OP_DECOMPOSE(arctan2);
OP_DECOMPOSE(arctan2_);
OP_DECOMPOSE(argsort);
OP_DECOMPOSE(avg_pool1d);
OP_DECOMPOSE(adaptive_max_pool1d);
Expand Down
5 changes: 0 additions & 5 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3762,23 +3762,18 @@ def test_op_has_batch_rule(self, device, dtype, op):
'addmm',
'addmv',
'addr',
'atan2',
'baddbmm',
'clamp',
'conj_physical',
'cumprod',
'cumsum',
'div',
'div',
'floor_divide',
'fmod',
'gcd',
'heaviside',
'hypot',
'igamma',
'igammac',
'index_copy',
'lcm',
'ldexp',
'lerp',
'neg',
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_vmap_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"aten::align_to.ellipsis_idx",
"aten::alpha_dropout",
"aten::alpha_dropout_",
"aten::arctan2_",
"aten::argwhere",
"aten::bilinear",
"aten::can_cast",
Expand Down

0 comments on commit 1a3dbf5

Please sign in to comment.