forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add core of c10::complex [resubmit] (pytorch#36626)
Summary: Pull Request resolved: pytorch#36626 This reverts commit 9216c67. Test Plan: Imported from OSS Differential Revision: D21140441 Pulled By: anjali411 fbshipit-source-id: 488530088e2ff87dc27e70d21ace88ff2967e7ab
- Loading branch information
1 parent
6ac0f67
commit 20328f6
Showing
11 changed files
with
1,046 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include <c10/test/util/complex_test_common.h> | ||
|
||
__global__ void test_thrust_kernel() { | ||
// thrust conversion | ||
{ | ||
constexpr float num1 = float(1.23); | ||
constexpr float num2 = float(4.56); | ||
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).real() == num1); | ||
assert(c10::complex<float>(thrust::complex<float>(num1, num2)).imag() == num2); | ||
} | ||
{ | ||
constexpr double num1 = double(1.23); | ||
constexpr double num2 = double(4.56); | ||
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).real() == num1); | ||
assert(c10::complex<double>(thrust::complex<double>(num1, num2)).imag() == num2); | ||
} | ||
// thrust assignment | ||
auto tup = assignment::one_two_thrust(); | ||
assert(std::get<c10::complex<double>>(tup).real() == double(1)); | ||
assert(std::get<c10::complex<double>>(tup).imag() == double(2)); | ||
assert(std::get<c10::complex<float>>(tup).real() == float(1)); | ||
assert(std::get<c10::complex<float>>(tup).imag() == float(2)); | ||
} | ||
|
||
__global__ void test_std_functions_kernel() { | ||
assert(std::abs(c10::complex<float>(3, 4)) == float(5)); | ||
assert(std::abs(c10::complex<double>(3, 4)) == double(5)); | ||
|
||
assert(std::abs(std::arg(c10::complex<float>(0, 1)) - PI / 2) < 1e-6); | ||
assert(std::abs(std::arg(c10::complex<double>(0, 1)) - PI / 2) < 1e-6); | ||
|
||
assert(std::abs(c10::polar(float(1), float(PI / 2)) - c10::complex<float>(0, 1)) < 1e-6); | ||
assert(std::abs(c10::polar(double(1), double(PI / 2)) - c10::complex<double>(0, 1)) < 1e-6); | ||
} | ||
|
||
__global__ void test_reinterpret_cast() { | ||
std::complex<float> z(1, 2); | ||
c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z); | ||
assert(zz.real() == float(1)); | ||
assert(zz.imag() == float(2)); | ||
|
||
std::complex<double> zzz(1, 2); | ||
c10::complex<double> zzzz = *reinterpret_cast<c10::complex<double>*>(&zzz); | ||
assert(zzzz.real() == double(1)); | ||
assert(zzzz.imag() == double(2)); | ||
} | ||
|
||
int safeDeviceCount() { | ||
int count; | ||
cudaError_t err = cudaGetDeviceCount(&count); | ||
if (err == cudaErrorInsufficientDriver || err == cudaErrorNoDevice) { | ||
return 0; | ||
} | ||
return count; | ||
} | ||
|
||
#define SKIP_IF_NO_GPU() \ | ||
do { \ | ||
if (safeDeviceCount() == 0) { \ | ||
return; \ | ||
} \ | ||
} while(0) | ||
|
||
TEST(DeviceTests, ThrustConversion) { | ||
SKIP_IF_NO_GPU(); | ||
ASSERT_EQ(cudaGetLastError(), cudaSuccess); | ||
cudaDeviceSynchronize(); | ||
test_thrust_kernel<<<1, 1>>>(); | ||
cudaDeviceSynchronize(); | ||
ASSERT_EQ(cudaGetLastError(), cudaSuccess); | ||
} | ||
|
||
TEST(DeviceTests, StdFunctions) { | ||
SKIP_IF_NO_GPU(); | ||
cudaDeviceSynchronize(); | ||
test_std_functions_kernel<<<1, 1>>>(); | ||
cudaDeviceSynchronize(); | ||
ASSERT_EQ(cudaGetLastError(), cudaSuccess); | ||
} | ||
|
||
TEST(DeviceTests, ReinterpretCast) { | ||
SKIP_IF_NO_GPU(); | ||
cudaDeviceSynchronize(); | ||
test_reinterpret_cast<<<1, 1>>>(); | ||
cudaDeviceSynchronize(); | ||
ASSERT_EQ(cudaGetLastError(), cudaSuccess); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#include <c10/test/util/complex_test_common.h> |
Oops, something went wrong.