ROCm warp size fix [AMD official] (state-spaces#405)
* ROCM conditional compilation fix

* compile flag for warp size

* use #define to set warp size

* fix brace bug

* minor style fix

---------

Co-authored-by: root <[email protected]>
amoskvic and ajassani authored Jun 19, 2024
1 parent c71f86c commit c809b2b
Showing 3 changed files with 27 additions and 15 deletions.
13 changes: 5 additions & 8 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
100644 → 100755
@@ -518,7 +518,7 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
     #else
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
-        std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
+        std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
     #endif
 
 }
@@ -536,12 +536,12 @@ template<typename input_t, typename weight_t>
 void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
 
     #ifndef USE_ROCM
-        constexpr int warp_size = 32;
+        #define warp_size 32
     #else
-        constexpr int warp_size = rocprim::warp_size();
+        #define warp_size ROCM_WARP_SIZE
     #endif
 
-    if (warp_size == 32) {
+    #if warp_size == 32
         if (params.seqlen <= 128) {
             selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
         } else if (params.seqlen <= 256) {
@@ -553,9 +553,7 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
         } else {
             selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
         }
-    }
-    #ifdef USE_ROCM
-    else {
+    #else
         if (params.seqlen <= 256) {
             selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
         } else if (params.seqlen <= 512) {
@@ -565,6 +563,5 @@ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
         } else {
             selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
         }
-    }
     #endif
 }
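Pieced together from the hunks above (the forward kernel in the next file gets the identical treatment), the change swaps a constexpr warp_size plus a runtime if for a #define plus a preprocessor #if, so only the launch configurations that match the build target's warp size are compiled at all. A minimal sketch of the pattern under the same assumptions as the diff — warp_size fixed to 32 on CUDA builds and supplied as ROCM_WARP_SIZE by setup.py on ROCm builds — with a stand-in launch<>() in place of selective_scan_bwd_launch<>():

```cpp
// Sketch (not the verbatim file) of the compile-time warp-size dispatch.
// ROCM_WARP_SIZE is assumed to arrive via -DROCM_WARP_SIZE=<32|64> from setup.py;
// launch<> stands in for selective_scan_bwd_launch<>.
#include <cstdio>

#ifndef USE_ROCM
#define warp_size 32              // NVIDIA warps are always 32 lanes wide
#else
#define warp_size ROCM_WARP_SIZE  // AMD: 32 on RDNA (Radeon), 64 on CDNA (Instinct)
#endif

template <int kNThreads, int kNItems>
void launch(int seqlen) {
    std::printf("seqlen=%d -> kNThreads=%d, kNItems=%d\n", seqlen, kNThreads, kNItems);
}

void dispatch(int seqlen) {
#if warp_size == 32
    if (seqlen <= 128) { launch<32, 4>(seqlen); }
    // ... intermediate seqlen tiers elided, as in the hunks above ...
    else               { launch<128, 16>(seqlen); }
#else
    if (seqlen <= 256) { launch<64, 4>(seqlen); }
    // ... intermediate seqlen tiers elided, as in the hunks above ...
    else               { launch<128, 16>(seqlen); }
#endif
}

int main() {
    dispatch(100);
    dispatch(4096);
    return 0;
}
```

Because warp_size is now a macro, the branch is resolved by the preprocessor; the earlier constexpr initialized from rocprim::warp_size() could not appear in an #if, which is why the previous code had to keep both branches and pick one at runtime.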
11 changes: 4 additions & 7 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
100644 → 100755
@@ -351,12 +351,12 @@ template<typename input_t, typename weight_t>
 void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
 
     #ifndef USE_ROCM
-        constexpr int warp_size = 32;
+        #define warp_size 32
     #else
-        constexpr int warp_size = rocprim::warp_size();
+        #define warp_size ROCM_WARP_SIZE
     #endif
 
-    if (warp_size == 32) {
+    #if warp_size == 32
         if (params.seqlen <= 128) {
             selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
         } else if (params.seqlen <= 256) {
@@ -368,9 +368,7 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
         } else {
             selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
         }
-    }
-    #ifdef USE_ROCM
-    else {
+    #else
         if (params.seqlen <= 256) {
             selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
         } else if (params.seqlen <= 512) {
@@ -380,6 +378,5 @@ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
         } else {
             selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
         }
-    }
     #endif
 }
18 changes: 18 additions & 0 deletions setup.py
@@ -199,6 +199,23 @@ def append_nvcc_threads(nvcc_extra_args):
 
 if HIP_BUILD:
 
+    try:
+        # set warp size based on gcn architecure
+        gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
+        if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
+            # radeon
+            warp_size = 32
+        else:
+            # instinct
+            warp_size = 64
+    except AttributeError as e:
+        # fall back to crude method to set warp size
+        device_name = torch.cuda.get_device_properties(0).name
+        if 'instinct' in device_name.lower():
+            warp_size = 64
+        else:
+            warp_size = 32
+
     extra_compile_args = {
         "cxx": ["-O3", "-std=c++17"],
         "nvcc": [
@@ -209,6 +226,7 @@ def append_nvcc_threads(nvcc_extra_args):
             "-U__CUDA_NO_HALF_CONVERSIONS__",
             "-DCK_FMHA_FWD_FAST_EXP2=1",
             "-fgpu-flush-denormals-to-zero",
+            f"-DROCM_WARP_SIZE={warp_size}"
         ]
         + cc_flag,
     }
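The detection above keys off torch.cuda.get_device_properties(0).gcnArchName: gfx10*/gfx11* parts (RDNA, consumer Radeon) get warp_size = 32, everything else is treated as a 64-wide Instinct part, with a cruder name-based fallback for PyTorch builds whose device properties lack gcnArchName. The chosen value reaches the kernels only through the -DROCM_WARP_SIZE compile flag, so nothing re-checks it when the extension later runs on a possibly different GPU; a hedged, illustrative sketch of such a check (not part of this commit, assuming a ROCm/hipcc build) might look like:

```cpp
// Illustrative only: confirm that the warp size baked in at build time via
// -DROCM_WARP_SIZE matches the device the extension actually runs on.
#include <hip/hip_runtime.h>
#include <cstdio>

#ifndef ROCM_WARP_SIZE
#define ROCM_WARP_SIZE 64  // assumed default for an Instinct build
#endif

int main() {
    hipDeviceProp_t prop{};
    if (hipGetDeviceProperties(&prop, 0) != hipSuccess) {
        std::fprintf(stderr, "no HIP device found\n");
        return 1;
    }
    // gcnArchName is the same property setup.py inspects (e.g. "gfx90a" or "gfx1100").
    std::printf("%s (%s): warpSize=%d, built with ROCM_WARP_SIZE=%d\n",
                prop.name, prop.gcnArchName, prop.warpSize, ROCM_WARP_SIZE);
    return prop.warpSize == ROCM_WARP_SIZE ? 0 : 2;
}
```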
