diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index 6bbe7fdcc1ec7a..1dd417052cf100 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -461,16 +461,21 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const; using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const; using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const; + using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, c10::optional) const; using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const; using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const; POINTWISE_BOXED(add_.Tensor); // just testing + POINTWISE_BOXED(atan2_); + POINTWISE_BOXED(gcd_); + POINTWISE_BOXED(lcm_); VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); + VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule>)); VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule)); VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule)); diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 8a25a350f950be..1b179a505e9a91 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -46,6 +46,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(absolute); OP_DECOMPOSE(absolute_); OP_DECOMPOSE(arctan2); + OP_DECOMPOSE(arctan2_); OP_DECOMPOSE(argsort); OP_DECOMPOSE(avg_pool1d); OP_DECOMPOSE(adaptive_max_pool1d); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index b0c21421b8b41c..fc28228045f1e9 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3762,23 +3762,18 @@ def test_op_has_batch_rule(self, device, dtype, op): 'addmm', 'addmv', 'addr', - 'atan2', 'baddbmm', 'clamp', 'conj_physical', 'cumprod', 'cumsum', - 'div', - 'div', 'floor_divide', 'fmod', - 'gcd', 'heaviside', 'hypot', 'igamma', 'igammac', 'index_copy', - 'lcm', 'ldexp', 'lerp', 'neg', diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index b25732dfcf5e28..4952f2745b6d72 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -38,7 +38,6 @@ "aten::align_to.ellipsis_idx", "aten::alpha_dropout", "aten::alpha_dropout_", - "aten::arctan2_", "aten::argwhere", "aten::bilinear", "aten::can_cast",