[complex32] sqrt-rsqrt : cuda (pytorch#77490)
Follows pytorch#74537

cc @kshitij12345!

Pull Request resolved: pytorch#77490
Approved by: https://github.com/ngimel
khushi-411 authored and pytorchmergebot committed Jun 6, 2022
1 parent 530dcc2 commit e7b96ad
Showing 2 changed files with 36 additions and 13 deletions.
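
For context, a minimal usage sketch of what this change enables, namely torch.sqrt and torch.rsqrt on complex32 ("chalf") CUDA tensors. It assumes a CUDA build that includes this commit; complex32 support was still experimental at the time.

import torch

# complex32 factories were still limited at the time, so build the input via a cast.
x = torch.randn(8, dtype=torch.complex64, device="cuda").to(torch.complex32)
print(torch.sqrt(x).dtype)   # torch.complex32
print(torch.rsqrt(x).dtype)  # torch.complex32
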
aten/src/ATen/native/cuda/UnaryOpsKernel.cu: 9 additions, 7 deletions
@@ -13,6 +13,7 @@
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/Math.cuh>
 #include <ATen/NumericUtils.h>
+#include <ATen/OpMathType.h>
 #include <c10/cuda/CUDAMathCompat.h>
 #include <c10/core/Scalar.h>
 #include <c10/util/complex.h>
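
The new OpMathType.h include provides at::opmath_type<T>, the trait that maps a reduced-precision storage type to the wider type the arithmetic is actually done in (c10::complex<Half> maps to c10::complex<float>). A loose Python analogue of that mapping follows; the OPMATH_TYPE dict is purely illustrative, not a PyTorch API.

import torch

# Illustrative mirror of at::opmath_type<T> for the reduced-precision dtypes
# involved here; the kernel computes in the wider dtype and rounds back.
OPMATH_TYPE = {
    torch.half: torch.float,
    torch.bfloat16: torch.float,
    torch.complex32: torch.complex64,
}
print(OPMATH_TYPE[torch.complex32])  # torch.complex64
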
@@ -103,18 +104,18 @@ void rsqrt_kernel_cuda(TensorIteratorBase& iter) {
           const T one = T{1};
           return one / std::sqrt(x);
         }); // rsqrt_string
-    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
       jitted_gpu_kernel<
           /*name=*/rsqrt_name,
           /*return_dtype=*/scalar_t,
           /*common_dtype=*/scalar_t,
           /*arity=*/1>(iter, rsqrt_string);
     });
 #else
-    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "rsqrt_cuda", [&]() {
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-        // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float.
-        return rsqrt_wrapper(a);
+        using opmath_t = at::opmath_type<scalar_t>;
+        return rsqrt_wrapper(static_cast<opmath_t>(a));
       });
     });
 #endif
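
The #else branch above now upcasts each chalf element to its opmath type (complex<float>), calls rsqrt_wrapper there, and lets the return cast round the result back to chalf. A rough Python rendering of the same idea, used here only as a reference to sanity-check the kernel (requires a CUDA device; rsqrt_via_upcast is a hypothetical helper, and the tolerances are ballpark chalf precision, not the values used by the test suite):

import torch

def rsqrt_via_upcast(x):
    # Do the math in complex64 (the opmath type of complex32), round back to chalf.
    return torch.rsqrt(x.to(torch.complex64)).to(torch.complex32)

x = torch.tensor([1 + 1j, 4 - 2j, -3 + 4j], dtype=torch.complex64, device="cuda")
x = x.to(torch.complex32)
torch.testing.assert_close(torch.rsqrt(x), rsqrt_via_upcast(x), atol=1e-3, rtol=1e-3)
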
@@ -141,17 +142,18 @@ void sqrt_kernel_cuda(TensorIteratorBase& iter) {
         T sqrt_kernel(T x) {
           return std::sqrt(x);
         }); // sqrt_string
-    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
       jitted_gpu_kernel<
           /*name=*/sqrt_name,
           /*return_dtype=*/scalar_t,
           /*common_dtype=*/scalar_t,
           /*arity=*/1>(iter, sqrt_string);
     });
 #else
-    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "sqrt_cuda", [&]() {
+    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return std::sqrt(a);
+        using opmath_t = at::opmath_type<scalar_t>;
+        return ::sqrt(static_cast<opmath_t>(a));
       });
     });
 #endif
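
The sqrt branch follows the same pattern: upcast to the opmath type, call ::sqrt there, round back on return. A quick consistency check of the two kernels together, hedged to within chalf rounding (requires a CUDA device; the tolerance is my own ballpark, not from the test suite):

import torch

x = torch.tensor([1 + 1j, 4 - 2j, 0.25 + 0.5j, -3 + 4j],
                 dtype=torch.complex64, device="cuda").to(torch.complex32)
# sqrt(z) * rsqrt(z) should be ~1 for z != 0, up to half-precision rounding;
# upcast before multiplying so only the two new kernels are exercised in chalf.
prod = torch.sqrt(x).to(torch.complex64) * torch.rsqrt(x).to(torch.complex64)
torch.testing.assert_close(prod, torch.ones_like(prod), atol=1e-2, rtol=0)
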
torch/testing/_internal/common_methods_invocations.py: 27 additions, 6 deletions
@@ -10617,7 +10617,6 @@ def error_inputs_mean(op_info, device, **kwargs):
         supports_fwgrad_bwgrad=True,
         dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
-        backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         assert_autodiffed=True,
         decorators=[
             DecorateInfo(
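
Dropping backward_dtypesIfCUDA lets the backward dtypes fall back to the CUDA forward dtypes, which now include torch.chalf. For reference, torch.chalf is an alias of torch.complex32, and the all_types_and_complex_and(...) helper used in these entries appends the extra dtypes to the standard set. A small sketch; the import path is PyTorch's internal testing module and may move between releases.

import torch
from torch.testing._internal.common_dtype import all_types_and_complex_and

assert torch.chalf == torch.complex32  # "chalf" = complex number backed by two float16s

dtypes = all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16)
print(torch.complex32 in dtypes)  # True
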
@@ -15631,36 +15630,58 @@ def error_inputs_mean(op_info, device, **kwargs):
         ref=lambda x: np.reciprocal(np.sqrt(x)),
         domain=(0, None),
         dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
         decorators=(precisionOverride({torch.half: 5e-2}),),
         assert_autodiffed=True,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         skips=(
             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                          dtypes=(torch.cfloat, torch.cdouble)),
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
+            # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
+            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
+                         dtypes=(torch.chalf,)),
         )),
     UnaryUfuncInfo('sqrt',
         ref=np.sqrt,
         supports_sparse=True,
         domain=(0, None),
         dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
         assert_autodiffed=True,
         supports_forward_ad=True,
         supports_sparse_csr=True,
         supports_fwgrad_bwgrad=True,
-        decorators=(precisionOverride({torch.bfloat16: 7e-2}),),
+        decorators=(
+            precisionOverride({torch.bfloat16: 7e-2}),
+            DecorateInfo(
+                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
+                'TestUnaryUfuncs', 'test_reference_numerics_large'),
+        ),
         skips=(
             # Reference: https://github.com/pytorch/pytorch/issues/47358
             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                         device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
                          active_if=IS_MACOS),
             # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                         dtypes=[torch.bfloat16]),
+                         dtypes=(torch.bfloat16,)),
             DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
                          'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
+            # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_consistency',
+                         dtypes=(torch.chalf,)),
+            # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_inplace',
+                         dtypes=(torch.chalf,)),
+            # RuntimeError: "nonzero_cuda" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_sparse_csr_unary_out',
+                         dtypes=(torch.chalf,)),
+            # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+            DecorateInfo(unittest.expectedFailure, 'TestSparseCSR', 'test_zero_to_zero_correspondence_unary',
+                         dtypes=(torch.chalf,)),
         )),
     UnaryUfuncInfo('square',
         ref=np.square,
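
The new toleranceOverride loosens the comparison for chalf in test_reference_numerics_large (atol=1e-2, rtol=0), and the expectedFailure entries record that several sparse-CSR helpers are not yet implemented for ComplexHalf. A standalone spot-check in the same spirit, comparing chalf sqrt on CUDA against a complex64 reference with the same loose absolute tolerance; this is a sketch, not the test suite itself.

import torch

x64 = torch.complex(torch.randn(512, device="cuda"), torch.randn(512, device="cuda"))
x16 = x64.to(torch.complex32)
torch.testing.assert_close(
    torch.sqrt(x16).to(torch.complex64),  # chalf result, upcast for comparison
    torch.sqrt(x64),                      # complex64 reference
    atol=1e-2, rtol=0,
)
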
