Skip to content

Commit

Permalink
Use smaller internal struct to store population RNG (XORWOW) in CUDAHIP backend
Browse files Browse the repository at this point in the history

* Write struct to definitions
* Use new class for allocating memory and struct fields
* Reimplemented population RNG preamble and postamble as ``BackendSIMT::addPopulationRNG``, using the new destructor mechanism to copy from and to the internal struct
  • Loading branch information
neworderofjamie committed Jan 15, 2025
1 parent 90a90e2 commit b1357c8
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 68 deletions.
10 changes: 3 additions & 7 deletions include/genn/backends/cuda/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -191,10 +188,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -234,6 +227,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get the internal type that the population RNG is loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
10 changes: 3 additions & 7 deletions include/genn/backends/hip/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
AtomicOperation op = AtomicOperation::ADD,
AtomicMemSpace memSpace = AtomicMemSpace::GLOBAL) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -182,10 +179,6 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
virtual std::unique_ptr<Runtime::ArrayBase> createArray(const Type::ResolvedType &type, size_t count,
VarLocation location, bool uninitialized) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code to allocate variable with a size known at runtime
virtual void genLazyVariableDynamicAllocation(CodeStream &os,
const Type::ResolvedType &type, const std::string &name, VarLocation loc,
Expand Down Expand Up @@ -225,6 +218,9 @@ class BACKEND_EXPORT Backend : public BackendCUDAHIP
return m_ChosenDevice.totalConstMem - getPreferences<Preferences>().constantCacheOverhead;
}

//! Get the internal type that the population RNG is loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const final;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const final;

Expand Down
24 changes: 13 additions & 11 deletions include/genn/genn/code_generator/backendCUDAHIP.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
m_RandPrefix(randPrefix), m_CCLPrefix(cclPrefix)
{}

//--------------------------------------------------------------------------
// Declared virtuals
//--------------------------------------------------------------------------

//--------------------------------------------------------------------------
// CodeGenerator::BackendSIMT virtuals
//--------------------------------------------------------------------------
Expand All @@ -87,16 +83,15 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const final;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const final;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const final;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const final;

//! Get type of population RNG
virtual Type::ResolvedType getPopulationRNGType() const final;

//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void addPopulationRNG(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const final;

//--------------------------------------------------------------------------
// CodeGenerator::BackendBase virtuals
//--------------------------------------------------------------------------
Expand All @@ -118,6 +113,10 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
virtual void genFreeMemPreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;
virtual void genStepTimeFinalisePreamble(CodeStream &os, const ModelSpecMerged &modelMerged) const final;

//! Create array of backend-specific population RNGs (if they are initialised on host this will occur here)
/*! \param count number of RNGs required*/
virtual std::unique_ptr<GeNN::Runtime::ArrayBase> createPopulationRNG(size_t count) const final;

//! Generate code for pushing a variable with a size known at runtime to the 'device'
virtual void genLazyVariableDynamicPush(CodeStream &os,
const Type::ResolvedType &type, const std::string &name,
Expand Down Expand Up @@ -173,6 +172,9 @@ class GENN_EXPORT BackendCUDAHIP : public BackendSIMT
//! Get the safe amount of constant cache we can use
virtual size_t getChosenDeviceSafeConstMemBytes() const = 0;

//! Get the internal type that the population RNG is loaded into
virtual Type::ResolvedType getPopulationRNGInternalType() const = 0;

//! Get library of RNG functions to use
virtual const EnvironmentLibrary::Library &getRNGFunctions(const Type::ResolvedType &precision) const = 0;

Expand Down
8 changes: 2 additions & 6 deletions include/genn/genn/code_generator/backendSIMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,9 @@ class GENN_EXPORT BackendSIMT : public BackendBase
//! For SIMT backends which initialize RNGs on device, initialize population RNG with specified seed and sequence
virtual void genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const = 0;

//! Generate a preamble to add substitution name for population RNG
virtual std::string genPopulationRNGPreamble(CodeStream &os, const std::string &globalRNG) const = 0;
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
virtual void addPopulationRNG(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const = 0;

//! If required, generate a postamble for population RNG
/*! For example, in OpenCL, this is used to write local RNG state back to global memory*/
virtual void genPopulationRNGPostamble(CodeStream &os, const std::string &globalRNG) const = 0;

//! Generate code to skip ahead local copy of global RNG
virtual std::string genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const = 0;

Expand Down
15 changes: 5 additions & 10 deletions src/genn/backends/cuda/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,6 @@ std::string Backend::getAtomic(const Type::ResolvedType &type, AtomicOperation o
}
}
//--------------------------------------------------------------------------
Type::ResolvedType Backend::getPopulationRNGType() const
{
return CURandState;
}
//--------------------------------------------------------------------------
std::unique_ptr<GeNN::Runtime::StateBase> Backend::createState(const Runtime::Runtime &runtime) const
{
return std::make_unique<State>(runtime);
Expand All @@ -369,11 +364,6 @@ std::unique_ptr<Runtime::ArrayBase> Backend::createArray(const Type::ResolvedTyp
return std::make_unique<Array>(type, count, location, uninitialized);
}
//--------------------------------------------------------------------------
std::unique_ptr<Runtime::ArrayBase> Backend::createPopulationRNG(size_t count) const
{
return createArray(CURandState, count, VarLocation::DEVICE, false);
}
//--------------------------------------------------------------------------
void Backend::genLazyVariableDynamicAllocation(CodeStream &os, const Type::ResolvedType &type, const std::string &name,
VarLocation loc, const std::string &countVarName) const
{
Expand Down Expand Up @@ -550,6 +540,11 @@ std::string Backend::getNVCCFlags() const
return nvccFlags;
}
//--------------------------------------------------------------------------
//! Get the internal type that the population RNG is loaded into
/*! For the CUDA backend this is CURandState, i.e. the type given to the
    local RNG state object that $(_rng) refers to in generated code. */
Type::ResolvedType Backend::getPopulationRNGInternalType() const
{
return CURandState;
}
//--------------------------------------------------------------------------
const EnvironmentLibrary::Library &Backend::getRNGFunctions(const Type::ResolvedType &precision) const
{
if(precision == Type::Float) {
Expand Down
15 changes: 5 additions & 10 deletions src/genn/backends/hip/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,6 @@ std::string Backend::getAtomic(const Type::ResolvedType &type, AtomicOperation o
}
}
//--------------------------------------------------------------------------
Type::ResolvedType Backend::getPopulationRNGType() const
{
return HIPRandState;
}
//--------------------------------------------------------------------------
std::unique_ptr<GeNN::Runtime::StateBase> Backend::createState(const Runtime::Runtime &runtime) const
{
return std::make_unique<State>(runtime);
Expand All @@ -361,11 +356,6 @@ std::unique_ptr<Runtime::ArrayBase> Backend::createArray(const Type::ResolvedTyp
return std::make_unique<Array>(type, count, location, uninitialized);
}
//--------------------------------------------------------------------------
std::unique_ptr<Runtime::ArrayBase> Backend::createPopulationRNG(size_t count) const
{
return createArray(HIPRandState, count, VarLocation::DEVICE, false);
}
//--------------------------------------------------------------------------
void Backend::genLazyVariableDynamicAllocation(CodeStream &os, const Type::ResolvedType &type, const std::string &name,
VarLocation loc, const std::string &countVarName) const
{
Expand Down Expand Up @@ -550,6 +540,11 @@ std::string Backend::getHIPCCFlags() const
#endif
}
//--------------------------------------------------------------------------
//! Get the internal type that the population RNG is loaded into
/*! For the HIP backend this is HIPRandState, i.e. the type given to the
    local RNG state object that $(_rng) refers to in generated code. */
Type::ResolvedType Backend::getPopulationRNGInternalType() const
{
return HIPRandState;
}
//--------------------------------------------------------------------------
const EnvironmentLibrary::Library &Backend::getRNGFunctions(const Type::ResolvedType &precision) const
{
if(precision == Type::Float) {
Expand Down
68 changes: 59 additions & 9 deletions src/genn/genn/code_generator/backendCUDAHIP.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ const EnvironmentLibrary::Library backendFunctions = {
{"atomic_or", {Type::ResolvedType::createFunction(Type::Void, {Type::Uint32.createPointer(), Type::Uint32}), "atomicOr($(0), $(1))"}},
};

// Compact type used to store population RNG state. Its name matches the
// XORWowStateInternal struct written by genDefinitionsPreamble, which holds
// 'unsigned int d' and 'unsigned int v[5]'.
// NOTE(review): 24 is presumably the size in bytes (6 x 4-byte unsigned int) -
// confirm against Type::ResolvedType::createValue and keep in sync with that struct.
const Type::ResolvedType XORWowStateInternal = Type::ResolvedType::createValue("XORWowStateInternal", 24, false, nullptr, true);

//--------------------------------------------------------------------------
// Timer
//--------------------------------------------------------------------------
Expand Down Expand Up @@ -234,24 +236,58 @@ void BackendCUDAHIP::genSharedMemBarrier(CodeStream &os) const
//--------------------------------------------------------------------------
void BackendCUDAHIP::genPopulationRNGInit(CodeStream &os, const std::string &globalRNG, const std::string &seed, const std::string &sequence) const
{
os << getRandPrefix() << "_init(" << seed << ", " << sequence << ", 0, &" << globalRNG << ");" << std::endl;
// Initialise full curandState/hiprandState object
os << getRandPrefix() << "State rngState;" << std::endl;
os << getRandPrefix() << "_init(" << seed << ", " << sequence << ", 0, &rngState);" << std::endl;

// Copy useful components into internal object
os << globalRNG << ".d = rngState.d;" << std::endl;
for(int i = 0; i < 5; i++) {
os << globalRNG << ".v[" << i << "] = rngState.v[" << i << "];" << std::endl;
}
}
//--------------------------------------------------------------------------
std::string BackendCUDAHIP::genPopulationRNGPreamble(CodeStream &, const std::string &globalRNG) const
std::string BackendCUDAHIP::genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const
{
return "&" + globalRNG;
// Skipahead RNG
os << getRandPrefix() << "StatePhilox4_32_10_t localRNG = d_rng;" << std::endl;
os << "skipahead_sequence((unsigned long long)" << sequence << ", &localRNG);" << std::endl;
return "localRNG";
}
//--------------------------------------------------------------------------
void BackendCUDAHIP::genPopulationRNGPostamble(CodeStream&, const std::string&) const
//! Get type of population RNG
/*! Population RNG state is held as the compact XORWowStateInternal type
    rather than as a full curandState/hiprandState object. */
Type::ResolvedType BackendCUDAHIP::getPopulationRNGType() const
{
return XORWowStateInternal;
}
//--------------------------------------------------------------------------
std::string BackendCUDAHIP::genGlobalRNGSkipAhead(CodeStream &os, const std::string &sequence) const
//! Add $(_rng) to environment based on $(_rng_internal) field with any initialisers and destructors required
/*! The generated initialiser declares a local <randPrefix>State object called
    rngState and fills it from the compact $(_rng_internal) state; the generated
    destructor copies the advanced state back again. $(_rng) is added as an
    alias for rngState.
    \param env  neuron update group environment which must already contain $(_rng_internal) */
void BackendCUDAHIP::addPopulationRNG(EnvironmentGroupMergedField<NeuronUpdateGroupMerged> &env) const
{
    // Number of elements in the 'v' array copied alongside 'd'
    constexpr int numStateWords = 5;

    // Generate initialiser code to create full RNG state from internal RNG state
    std::stringstream init;
    init << getRandPrefix() << "State rngState;" << std::endl;

    // Copy useful components into full object
    init << "rngState.d = $(_rng_internal).d;" << std::endl;
    for(int i = 0; i < numStateWords; i++) {
        init << "rngState.v[" << i << "] = $(_rng_internal).v[" << i << "];" << std::endl;
    }

    // Zero BOTH Box-Muller flags. rngState is a freshly declared local so any
    // field not written here is indeterminate; leaving the double-precision
    // flag unset could make the double-precision normal generator return a
    // garbage "cached" sample.
    // NOTE(review): field names follow cuRAND's curandStateXORWOW layout -
    // confirm hipRAND's state exposes the same members.
    init << "rngState.boxmuller_flag = 0;" << std::endl;
    init << "rngState.boxmuller_flag_double = 0;" << std::endl;

    // Generate destructor code to copy full RNG state back into internal RNG state
    std::stringstream destroy;
    destroy << "$(_rng_internal).d = rngState.d;" << std::endl;
    for(int i = 0; i < numStateWords; i++) {
        destroy << "$(_rng_internal).v[" << i << "] = rngState.v[" << i << "];" << std::endl;
    }

    // Add alias with initialiser and destructor statements
    env.add(getPopulationRNGInternalType(), "_rng", "rngState",
            {env.addInitialiser(init.str())},
            {env.addDestructor(destroy.str())});
}
//--------------------------------------------------------------------------
void BackendCUDAHIP::genNeuronUpdate(CodeStream &os, ModelSpecMerged &modelMerged, BackendBase::MemorySpaces &memorySpaces,
Expand Down Expand Up @@ -1033,6 +1069,15 @@ void BackendCUDAHIP::genDefinitionsPreamble(CodeStream &os, const ModelSpecMerge
// b) Just calls abort_, not killing kernels with correct exit code
// c) undefs the assert in <cassert> which actually works
os << "#include <cassert>" << std::endl;
os << std::endl;

os << "struct XORWowStateInternal" << std::endl;
{
CodeStream::Scope b(os);
os << "unsigned int d;" << std::endl;
os << "unsigned int v[5];" << std::endl;
}
os << ";" << std::endl;

os << std::endl;
os << "template<typename RNG>" << std::endl;
Expand Down Expand Up @@ -1349,6 +1394,11 @@ void BackendCUDAHIP::genStepTimeFinalisePreamble(CodeStream &os, const ModelSpec
}
}
//--------------------------------------------------------------------------
//! Create array of backend-specific population RNGs
/*! The array holds elements of the compact XORWowStateInternal type and is
    requested via createArray with VarLocation::DEVICE and 'uninitialized' set to false.
    \param count number of RNGs required */
std::unique_ptr<Runtime::ArrayBase> BackendCUDAHIP::createPopulationRNG(size_t count) const
{
return createArray(XORWowStateInternal, count, VarLocation::DEVICE, false);
}
//--------------------------------------------------------------------------
void BackendCUDAHIP::genLazyVariableDynamicPush(CodeStream &os,
const Type::ResolvedType &type, const std::string &name,
VarLocation loc, const std::string &countVarName) const
Expand Down
13 changes: 5 additions & 8 deletions src/genn/genn/code_generator/backendSIMT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -498,11 +498,14 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM
CodeStream::Scope b(groupEnv.getStream());

// Add population RNG field
groupEnv.addField(getPopulationRNGType().createPointer(), "_rng", "rng",
groupEnv.addField(getPopulationRNGType().createPointer(), "_rng_internal", "rng",
[](const auto &runtime, const auto &g, size_t) { return runtime.getArray(g, "rng"); },
ng.getVarIndex(batchSize, VarAccessDim::BATCH | VarAccessDim::ELEMENT, "$(id)"));
// **TODO** for OCL do genPopulationRNGPreamble(os, popSubs, "group->rng[" + ng.getVarIndex(batchSize, VarAccessDuplication::DUPLICATE, "$(id)") + "]") in initialiser

// Add population RNG preamble to initialise _rng from _rng_internal
addPopulationRNG(groupEnv);

// Generate neuron update
ng.generateNeuronUpdate(
*this, groupEnv, batchSize,
// Emit true spikes
Expand All @@ -515,12 +518,6 @@ void BackendSIMT::genNeuronUpdateKernel(EnvironmentExternalBase &env, ModelSpecM
{
genEmitEvent(env, ng, sg.getIndex(), false);
});

// Copy local stream back to local
// **TODO** postamble for OCL
//if(ng.getArchetype().isSimRNGRequired()) {
// genPopulationRNGPostamble(neuronEnv.getStream(), rng);
//}
}

genSharedMemBarrier(groupEnv.getStream());
Expand Down

0 comments on commit b1357c8

Please sign in to comment.