Add mutex to THC random number generator (#6527)
* Add mutex to THC random number generator

* Add test for CUDA RNG multithread

* fix lint

* Rename gen_state to state and remove unnecessary mutex lock

* Remove RNG test from cpp_extensions

* Add CUDA RNG test to libtorch

* Build test_rng only if CUDA exists

* Move test to aten/src/ATen/test/

* Separate ATen build and test, and run ATen test in CI test phase

* Don't test ATen in ASAN build

* Fix bug in ATen scalar_test

* Fix bug in ATen native_test

* Add FIXME to some CUDA tests in scalar_tensor_test

* Valgrind doesn't work well with CUDA, seed the CPU and CUDA RNG separately instead
yf225 authored Apr 18, 2018
1 parent c25f097 commit e089849
Showing 25 changed files with 167 additions and 93 deletions.
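Before the per-file diffs, it helps to see the shape of the fix: the generator's fields move into a THCGeneratorState struct guarded by a per-generator std::mutex; lazy initialization and reseeding take the lock, and the Philox offset used by sampling kernels is reserved with an atomic add so each caller gets a distinct offset even when several threads sample at once. A minimal sketch of the locking pattern the commit introduces (all names here are illustrative stand-ins, not the actual THC API):

#include <atomic>
#include <cstdint>
#include <mutex>
#include <utility>

// Illustrative stand-in for THCGenerator; field names are hypothetical.
struct Generator {
  std::mutex mutex;                      // guards lazy init and reseeding
  bool initialized = false;
  uint64_t initial_seed = 0;
  std::atomic<int64_t> philox_offset{0}; // reserved with an atomic add
};

// Thread-safe lazy initialization (mirrors THCRandom_getGenerator below).
Generator& get_generator(Generator& gen, uint64_t default_seed) {
  std::lock_guard<std::mutex> lock(gen.mutex);
  if (!gen.initialized) {
    gen.initial_seed = default_seed;     // expensive CUDA setup would go here
    gen.initialized = true;
  }
  return gen;
}

// Each caller reserves a distinct (seed, offset) pair for Philox kernels.
std::pair<uint64_t, uint64_t> next_philox_seed(Generator& gen) {
  uint64_t offset = gen.philox_offset.fetch_add(1);
  return {gen.initial_seed, offset};
}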
9 changes: 6 additions & 3 deletions .jenkins/pytorch/build.sh
@@ -29,10 +29,13 @@ fi

 python setup.py install
 
-# Test ATen
+# Add the ATen test binaries so that they won't be git clean'ed away
+git add -f aten/build/src/ATen/test
+
+# Testing ATen install
 if [[ "$BUILD_ENVIRONMENT" != *cuda* ]]; then
-  echo "Testing ATen"
-  time tools/run_aten_tests.sh
+  echo "Testing ATen install"
+  time tools/test_aten_install.sh
 fi
 
 # Test C FFI plugins
8 changes: 8 additions & 0 deletions .jenkins/pytorch/test.sh
@@ -29,6 +29,14 @@ fi

 time python test/run_test.py --verbose
 
+# Test ATen
+if [[ "$BUILD_ENVIRONMENT" != *asan* ]]; then
+  echo "Testing ATen"
+  TORCH_LIB_PATH=$(python -c "import site; print(site.getsitepackages()[0])")/torch/lib
+  ln -s "$TORCH_LIB_PATH"/libATen.so aten/build/src/ATen/libATen.so
+  aten/tools/run_tests.sh aten/build
+fi
+
 rm -rf ninja
 
 echo "Installing torchvision at branch master"
5 changes: 3 additions & 2 deletions aten/src/ATen/native/cuda/Distributions.cu
@@ -10,6 +10,7 @@

 #include <THC/THCGeneral.h>
 #include <THC/THCTensorRandom.h>
+#include <THC/THCGenerator.h>
 #include <THC/THCApply.cuh>
 #include <THC/THCNumerics.cuh>

@@ -21,8 +22,8 @@ THCGenerator* THCRandom_getGenerator(THCState* state);
 namespace {
 std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen) {
   auto gen_ = THCRandom_getGenerator(at::globalContext().thc_state);
-  uint64_t offset = THAtomicAddLong(&gen_->philox_seed_offset, 1);
-  return std::make_pair(gen_->initial_seed, offset);
+  uint64_t offset = THAtomicAddLong(&gen_->state.philox_seed_offset, 1);
+  return std::make_pair(gen_->state.initial_seed, offset);
 }
 
 template <typename scalar_t>
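The (seed, offset) pair from next_philox_seed feeds counter-based Philox sampling on the device: every caller shares the seed but reserves a fresh offset via the atomic add, so concurrent launches draw from disjoint random subsequences. A rough sketch of how such a pair is typically consumed by a kernel (illustrative code, not the actual ATen kernels):

#include <cstdint>
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Hypothetical kernel: all launches share the seed, each gets its own offset.
__global__ void sample_uniform(float* out, int n,
                               uint64_t seed, uint64_t offset) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  curandStatePhilox4_32_10_t state;
  // Distinct offsets keep concurrent calls on disjoint Philox streams.
  curand_init(seed, idx, offset, &state);
  out[idx] = curand_uniform(&state);
}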
5 changes: 5 additions & 0 deletions aten/src/ATen/test/CMakeLists.txt
@@ -45,6 +45,11 @@ if(NOT NO_CUDA)
   target_link_libraries(integer_divider_test ATen)
 endif()
 
+if(NOT NO_CUDA)
+  cuda_add_executable(cuda_rng_test cuda_rng_test.cpp)
+  target_link_libraries(cuda_rng_test ATen)
+endif()
+
 if (CUDNN_FOUND)
   add_executable(cudnn_test cudnn_test.cpp)
   target_link_libraries(cudnn_test ATen)
3 changes: 2 additions & 1 deletion aten/src/ATen/test/atest.cpp
@@ -24,7 +24,8 @@ void trace() {

 TEST_CASE( "atest", "[]" ) {
 
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
+  manual_seed(123, at::Backend::CUDA);
 
   auto foo = rand(CPU(kFloat), {12,6});
   REQUIRE(foo.data<float>() == foo.toFloatData());
4 changes: 2 additions & 2 deletions aten/src/ATen/test/basic.cpp
@@ -274,13 +274,13 @@ static void test(Type & type) {
 }
 
 TEST_CASE( "basic tests CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   test(CPU(kFloat));
 }
 
 TEST_CASE( "basic tests GPU", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
 
   if(at::hasCUDA()) {
     test(CUDA(kFloat));
2 changes: 1 addition & 1 deletion aten/src/ATen/test/broadcast_test.cpp
@@ -8,7 +8,7 @@ using namespace at;

 TEST_CASE( "broadcast", "[]" ) {
 
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   Type & T = CPU(kFloat);
 
27 changes: 27 additions & 0 deletions aten/src/ATen/test/cuda_rng_test.cpp
@@ -0,0 +1,27 @@
+#define CATCH_CONFIG_MAIN
+#include "catch.hpp"
+
+#include "ATen/ATen.h"
+#include "cuda.h"
+#include "cuda_runtime.h"
+#include <thread>
+
+void makeRandomNumber() {
+  cudaSetDevice(std::rand() % 2);
+  auto x = at::CUDA(at::kFloat).randn({1000});
+}
+
+void testCudaRNGMultithread() {
+  auto threads = std::vector<std::thread>();
+  for (auto i = 0; i < 1000; i++) {
+    threads.emplace_back(makeRandomNumber);
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+};
+
+TEST_CASE( "CUDA RNG test", "[cuda]" ) {
+  SECTION( "multithread" )
+    testCudaRNGMultithread();
+}
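Without the new lock in THCRandom_getGenerator, two of these threads could both observe initf == 0 on first use and run the generator setup concurrently, racing on shared state; a thousand threads spread across two devices stress exactly that path. Note the test assumes at least two visible CUDA devices, since it round-robins with std::rand() % 2.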
2 changes: 1 addition & 1 deletion aten/src/ATen/test/cudnn_test.cpp
@@ -10,7 +10,7 @@ using namespace at;
 using namespace at::native;
 
 TEST_CASE( "cudnn", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
 
 #if CUDNN_VERSION < 7000
   auto handle = getCudnnHandle();
2 changes: 1 addition & 1 deletion aten/src/ATen/test/dlconvertor_test.cpp
@@ -13,7 +13,7 @@ using namespace at;

 TEST_CASE( "dlconvertor", "[cpu]" ) {
 
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   INFO( "convert ATen to DLTensor" );
 
10 changes: 5 additions & 5 deletions aten/src/ATen/test/native_test.cpp
@@ -163,9 +163,9 @@ void test(Type & T, Type & AccT) {
   auto ct1 = randn(T, {3, 4});
   auto ct2 = randn(T, {3, 4});
   auto t1 = randn(T.toBackend(Backend::CPU), {3, 4});
-  REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(ct2), "not implemented");
-  REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(t1), "not implemented");
-  REQUIRE_THROWS_WITH(t1._standard_gamma_grad(ct2), "CUDA Backend");
+  REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(ct2), Catch::Contains("not implemented"));
+  REQUIRE_THROWS_WITH(ct1._standard_gamma_grad(t1), Catch::Contains("not implemented"));
+  REQUIRE_THROWS_WITH(t1._standard_gamma_grad(ct2), Catch::Contains("CUDA Backend"));
   }
 }

@@ -189,13 +189,13 @@ void test(Type & T, Type & AccT) {
 }
 
 TEST_CASE( "native test CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   test(CPU(kFloat), CPU(kDouble));
 }
 
 TEST_CASE( "native test CUDA", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
 
   if (at::hasCUDA()) {
     test(CUDA(kFloat), CUDA(kDouble));
8 changes: 5 additions & 3 deletions aten/src/ATen/test/scalar_tensor_test.cpp
@@ -132,7 +132,9 @@ void test(Type &T) {
   if (t.numel() != 0) {
     REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
   } else {
-    REQUIRE(t.sum(0).equal(T.tensor({0})));
+    if (!T.is_cuda()) { // FIXME: out of range exception in CUDA
+      REQUIRE(t.sum(0).equal(T.tensor({0})));
+    }
   }
 
   // reduce (with dimension argument and with 2 return arguments)
@@ -273,13 +275,13 @@ void test(Type &T) {
 }
 
 TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   test(CPU(kFloat));
 }
 
 TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CUDA);
 
   if (at::hasCUDA()) {
     test(CUDA(kFloat));
5 changes: 3 additions & 2 deletions aten/src/ATen/test/scalar_test.cpp
@@ -72,7 +72,8 @@ void test_overflow() {

 TEST_CASE( "scalar test", "[]" ) {
 
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
+  manual_seed(123, at::Backend::CUDA);
 
   Scalar what = 257;
   Scalar bar = 3.0;
@@ -83,7 +84,7 @@ TEST_CASE( "scalar test", "[]" ) {
   REQUIRE_NOTHROW(gen.seed());
   auto && C = at::globalContext();
   if(at::hasCUDA()) {
-    auto & CUDAFloat = C.getType(Backend::CPU,ScalarType::Float);
+    auto & CUDAFloat = C.getType(Backend::CUDA,ScalarType::Float);
     auto t2 = zeros(CUDAFloat, {4,4});
     cout << &t2 << "\n";
     cout << "AFTER GET TYPE " << &CUDAFloat << "\n";
2 changes: 1 addition & 1 deletion aten/src/ATen/test/tbb_init_test.cpp
@@ -24,7 +24,7 @@ void test(int given_num_threads) {
 }
 
 int main() {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   test(-1);
   std::thread t1(test, -1);
2 changes: 1 addition & 1 deletion aten/src/ATen/test/test_parallel.cpp
@@ -13,7 +13,7 @@ using namespace at;

 TEST_CASE( "parallel", "[cpu]" ) {
 
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
   set_num_threads(1);
 
   Tensor a = rand(CPU(at::kFloat), {1,3});
9 changes: 5 additions & 4 deletions aten/src/ATen/test/test_seed.h
@@ -2,10 +2,11 @@

 #include "ATen/ATen.h"
 
-void manual_seed(uint64_t seed) {
-  at::Generator & cpu_gen = at::globalContext().defaultGenerator(at::Backend::CPU);
-  cpu_gen.manualSeed(seed);
-  if (at::hasCUDA()) {
+void manual_seed(uint64_t seed, at::Backend backend) {
+  if (backend == at::Backend::CPU) {
+    at::Generator & cpu_gen = at::globalContext().defaultGenerator(at::Backend::CPU);
+    cpu_gen.manualSeed(seed);
+  } else if (backend == at::Backend::CUDA && at::hasCUDA()) {
     at::Generator & cuda_gen = at::globalContext().defaultGenerator(at::Backend::CUDA);
     cuda_gen.manualSeed(seed);
   }
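This ties back to the last commit-message bullet: because Valgrind interacts badly with CUDA, seeding per backend lets CPU-only runs avoid touching the CUDA generator at all. In a test body, the helper is now called once per backend the test exercises, e.g.:

  manual_seed(123, at::Backend::CPU);   // always seeds the CPU generator
  manual_seed(123, at::Backend::CUDA);  // silently skipped when at::hasCUDA() is false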
2 changes: 1 addition & 1 deletion aten/src/ATen/test/undefined_tensor_test.cpp
@@ -9,7 +9,7 @@
 using namespace at;
 
 TEST_CASE( "undefined tensor test", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   // mainly test ops on undefined tensors don't segfault and give a reasonable errror message.
   Tensor und;
2 changes: 1 addition & 1 deletion aten/src/ATen/test/wrapdim_test.cpp
@@ -7,7 +7,7 @@
 using namespace at;
 
 TEST_CASE( "wrapdim test", "[]" ) {
-  manual_seed(123);
+  manual_seed(123, at::Backend::CPU);
 
   Type & T = CPU(kFloat);
 
19 changes: 19 additions & 0 deletions aten/src/THC/THCGenerator.h
@@ -0,0 +1,19 @@
+#ifndef THC_GENERATOR_INC
+#define THC_GENERATOR_INC
+
+#include <mutex>
+
+typedef struct THCGeneratorState {
+  struct curandStateMtgp32* gen_states;
+  struct mtgp32_kernel_params *kernel_params;
+  int initf;
+  uint64_t initial_seed;
+  int64_t philox_seed_offset;
+} THCGeneratorState;
+
+struct THCGenerator {
+  std::mutex mutex; /* mutex for using this generator */
+  THCGeneratorState state;
+};
+
+#endif
43 changes: 25 additions & 18 deletions aten/src/THC/THCTensorRandom.cpp
@@ -1,4 +1,5 @@
 #include "THCTensorRandom.h"
+#include "THCGenerator.h"
 
 #include <random>
 #include <curand.h>
@@ -11,15 +12,16 @@ void createGeneratorState(THCGenerator* gen, uint64_t seed);
 /* Frees memory allocated during setup. */
 void destroyGenerator(THCState *state, THCGenerator* gen)
 {
-  if (gen->gen_states)
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  if (gen->state.gen_states)
   {
-    THCudaCheck(THCudaFree(state, gen->gen_states));
-    gen->gen_states = NULL;
+    THCudaCheck(THCudaFree(state, gen->state.gen_states));
+    gen->state.gen_states = NULL;
   }
-  if (gen->kernel_params)
+  if (gen->state.kernel_params)
   {
-    THCudaCheck(THCudaFree(state, gen->kernel_params));
-    gen->kernel_params = NULL;
+    THCudaCheck(THCudaFree(state, gen->state.kernel_params));
+    gen->state.kernel_params = NULL;
   }
 }

@@ -39,11 +41,12 @@ void THCRandom_init(THCState* state, int devices, int current_device)
   std::random_device rd;
   for (int i = 0; i < rng_state->num_devices; ++i)
   {
-    rng_state->gen[i].initf = 0;
-    rng_state->gen[i].initial_seed = createSeed(rd);
-    rng_state->gen[i].philox_seed_offset = 0;
-    rng_state->gen[i].gen_states = NULL;
-    rng_state->gen[i].kernel_params = NULL;
+    new (&rng_state->gen[i].mutex) std::mutex();
+    rng_state->gen[i].state.initf = 0;
+    rng_state->gen[i].state.initial_seed = createSeed(rd);
+    rng_state->gen[i].state.philox_seed_offset = 0;
+    rng_state->gen[i].state.gen_states = NULL;
+    rng_state->gen[i].state.kernel_params = NULL;
   }
 }

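One detail worth pausing on: the std::mutex member is constructed with placement new. That strongly suggests the gen array is raw-allocated C-style (e.g. with malloc), so no constructors run and the mutex must be built in place by hand. A minimal, self-contained illustration of the idiom (hypothetical Gen type, not THC code):

#include <cstdlib>
#include <mutex>
#include <new>

struct Gen {
  std::mutex mutex;  // non-trivial member: needs a real constructor call
  int initf;
};

int main() {
  // malloc gives raw bytes; no constructor has run, so gen->mutex is garbage.
  Gen* gen = static_cast<Gen*>(std::malloc(sizeof(Gen)));
  new (&gen->mutex) std::mutex();  // placement new constructs it in place
  gen->initf = 0;
  {
    std::lock_guard<std::mutex> lock(gen->mutex);  // now safe to lock
  }
  gen->mutex.~mutex();  // destroy explicitly before freeing raw memory
  std::free(gen);
  return 0;
}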
@@ -74,18 +77,20 @@ static THCGenerator* THCRandom_rawGenerator(THCState* state)
 THCGenerator* THCRandom_getGenerator(THCState* state)
 {
   THCGenerator* gen = THCRandom_rawGenerator(state);
-  if (gen->initf == 0)
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  if (gen->state.initf == 0)
   {
     initializeGenerator(state, gen);
-    createGeneratorState(gen, gen->initial_seed);
-    gen->initf = 1;
+    createGeneratorState(gen, gen->state.initial_seed);
+    gen->state.initf = 1;
   }
   return gen;
 }
 
 struct curandStateMtgp32* THCRandom_generatorStates(struct THCState* state)
 {
-  return THCRandom_getGenerator(state)->gen_states;
+  THCGenerator* gen = THCRandom_getGenerator(state);
+  return gen->state.gen_states;
 }
 
 /* Random seed */
@@ -109,8 +114,9 @@ uint64_t THCRandom_seedAll(THCState* state)
 void THCRandom_manualSeed(THCState* state, uint64_t seed)
 {
   THCGenerator* gen = THCRandom_rawGenerator(state);
-  gen->initial_seed = seed;
-  if (gen->initf) {
+  std::lock_guard<std::mutex> lock(gen->mutex);
+  gen->state.initial_seed = seed;
+  if (gen->state.initf) {
     createGeneratorState(gen, seed);
   }
 }
@@ -130,5 +136,6 @@ void THCRandom_manualSeedAll(THCState* state, uint64_t seed)
 /* Get the initial seed */
 uint64_t THCRandom_initialSeed(THCState* state)
 {
-  return THCRandom_getGenerator(state)->initial_seed;
+  THCGenerator* gen = THCRandom_getGenerator(state);
+  return gen->state.initial_seed;
 }
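A locking detail that follows from the diff above: THCRandom_manualSeed fetches the generator through THCRandom_rawGenerator and takes the mutex itself, because going through THCRandom_getGenerator would re-acquire the same non-recursive mutex and self-deadlock. A compressed sketch of the hazard, with hypothetical names:

#include <mutex>

std::mutex m;

void locking_accessor() {
  std::lock_guard<std::mutex> lock(m);  // e.g. THCRandom_getGenerator
}

void reseed_bad() {
  std::lock_guard<std::mutex> lock(m);
  locking_accessor();  // relocks a non-recursive mutex: deadlock/UB
}

void reseed_good() {
  std::lock_guard<std::mutex> lock(m);
  // work on the state directly, via the raw (non-locking) accessor
}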
