Skip to content

Commit

Permalink
fix curand (dmlc#3077)
Browse files Browse the repository at this point in the history
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
BarclayII and jermainewang authored Jul 2, 2021
1 parent a0390dd commit 4e74dc8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/random/random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
#pragma omp parallel for
for (int i = 0; i < omp_get_max_threads(); ++i) {
RandomEngine::ThreadLocal()->SetSeed(seed);
}
#ifdef DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed + GetThreadId())));
#endif // DGL_USE_CUDA
auto* thr_entry = CUDAThreadEntry::ThreadLocal();
if (!thr_entry->curand_gen) {
CURAND_CALL(curandCreateGenerator(&thr_entry->curand_gen, CURAND_RNG_PSEUDO_DEFAULT));
}
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(
thr_entry->curand_gen,
static_cast<uint64_t>(seed)));
#endif // DGL_USE_CUDA
});

DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
Expand Down

0 comments on commit 4e74dc8

Please sign in to comment.