Skip to content

Commit

Permalink
Add core of c10::complex [resubmit] (pytorch#36626)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#36626

This reverts commit 9216c67.

Test Plan: Imported from OSS

Differential Revision: D21140441

Pulled By: anjali411

fbshipit-source-id: 488530088e2ff87dc27e70d21ace88ff2967e7ab
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Apr 24, 2020
1 parent 6ac0f67 commit 20328f6
Show file tree
Hide file tree
Showing 11 changed files with 1,046 additions and 2 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ cc_library(
"c10/cuda/impl/*.h",
"c10/macros/*.h",
"c10/util/*.h",
"c10/util/*.hpp",
]) + [
"c10/macros/cmake_macros.h",
"c10/cuda/impl/cuda_cmake_macros.h",
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/NumericUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <complex>
#include <type_traits>
#include <c10/util/BFloat16.h>
#include <c10/util/Complex.h>
#include <c10/util/LegacyComplex.h>
#include <c10/macros/Macros.h>

namespace at {
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/type_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_apply_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_stream_test.cpp
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/test/complex_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <gtest/gtest.h>
#include <c10/util/Complex.h>
#include <c10/util/LegacyComplex.h>

template<typename T, typename int_t>
static void TestBinaryOpsForIntType(T real, T img, int_t num) {
Expand Down
88 changes: 88 additions & 0 deletions aten/src/ATen/test/cuda_complex_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <c10/test/util/complex_test_common.h>

__global__ void test_thrust_kernel() {
// thrust conversion
{
constexpr float num1 = float(1.23);
constexpr float num2 = float(4.56);
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).real() == num1);
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).imag() == num2);
}
{
constexpr double num1 = double(1.23);
constexpr double num2 = double(4.56);
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).real() == num1);
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).imag() == num2);
}
// thrust assignment
auto tup = assignment::one_two_thrust();
assert(std::get<c10::complex<double>>(tup).real() == double(1));
assert(std::get<c10::complex<double>>(tup).imag() == double(2));
assert(std::get<c10::complex<float>>(tup).real() == float(1));
assert(std::get<c10::complex<float>>(tup).imag() == float(2));
}

__global__ void test_std_functions_kernel() {
assert(std::abs(c10::complex<float>(3, 4)) == float(5));
assert(std::abs(c10::complex<double>(3, 4)) == double(5));

assert(std::abs(std::arg(c10::complex<float>(0, 1)) - PI / 2) < 1e-6);
assert(std::abs(std::arg(c10::complex<double>(0, 1)) - PI / 2) < 1e-6);

assert(std::abs(c10::polar(float(1), float(PI / 2)) - c10::complex<float>(0, 1)) < 1e-6);
assert(std::abs(c10::polar(double(1), double(PI / 2)) - c10::complex<double>(0, 1)) < 1e-6);
}

__global__ void test_reinterpret_cast() {
std::complex<float> z(1, 2);
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
assert(zz.real() == float(1));
assert(zz.imag() == float(2));

std::complex<double> zzz(1, 2);
c10::complex<double> zzzz = *reinterpret_cast<c10::complex<double>*>(&zzz);
assert(zzzz.real() == double(1));
assert(zzzz.imag() == double(2));
}

int safeDeviceCount() {
int count;
cudaError_t err = cudaGetDeviceCount(&count);
if (err == cudaErrorInsufficientDriver || err == cudaErrorNoDevice) {
return 0;
}
return count;
}

#define SKIP_IF_NO_GPU() \
do { \
if (safeDeviceCount() == 0) { \
return; \
} \
} while(0)

TEST(DeviceTests, ThrustConversion) {
SKIP_IF_NO_GPU();
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
cudaDeviceSynchronize();
test_thrust_kernel<<<1, 1>>>();
cudaDeviceSynchronize();
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
}

TEST(DeviceTests, StdFunctions) {
SKIP_IF_NO_GPU();
cudaDeviceSynchronize();
test_std_functions_kernel<<<1, 1>>>();
cudaDeviceSynchronize();
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
}

TEST(DeviceTests, ReinterpretCast) {
SKIP_IF_NO_GPU();
cudaDeviceSynchronize();
test_reinterpret_cast<<<1, 1>>>();
cudaDeviceSynchronize();
ASSERT_EQ(cudaGetLastError(), cudaSuccess);
}

3 changes: 3 additions & 0 deletions aten/tools/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ fi
if [[ -x ./cuda_tensor_interop_test ]]; then
./cuda_tensor_interop_test
fi
if [[ -x ./cuda_complex_test ]]; then
./cuda_complex_test
fi
if [ "$VALGRIND" == "ON" ]
then
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic --gtest_filter='-*CUDA'
Expand Down
1 change: 1 addition & 0 deletions c10/test/util/complex_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include <c10/test/util/complex_test_common.h>
Loading

0 comments on commit 20328f6

Please sign in to comment.