Fix Build Issue (pytorch#2694)
Summary:
Pull Request resolved: pytorch#2694

Guard the FP8 dequantization branches (`dequantize_packed_fp8`) in `gqa_attn_splitk.cu` behind `#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)` so that builds against pre-12.0 CUDA toolkits no longer compile the FP8 path.

Failing build: https://www.internalfb.com/intern/testinfra/diagnostics/4785074838821245.844425001247266.1717686112

Task: https://www.internalfb.com/tasks?t=191623363

Reviewed By: jianyuh, amylittleyang

Differential Revision: D58224260

fbshipit-source-id: 67455e6085f5d2136f2a06b44f7eefcadc8a5edd
ayaIbrah authored and facebook-github-bot committed Jun 6, 2024
1 parent 900b05b commit 1c3e3bf
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions fbgemm_gpu/experimental/gen_ai/src/attention/gqa_attn_splitk.cu
@@ -664,20 +664,23 @@ __global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1)
       for (int vec = 0; vec < KV_NUM_VECS; ++vec) {
         auto* smem_s = reinterpret_cast<__nv_bfloat162*>(
             smem_staging_ + vec * KV_NUM_ELS_PER_DEQ);
-        if (USE_FP8) {
-          const auto k_deq = dequantize_packed_fp8(k_vals_[vec], k_scales);
+        if (!USE_FP8) {
+          const auto k_deq =
+              dequantize_permuted_int4(k_vals_[vec], k_scales);
 #pragma unroll
           for (int i = 0; i < KV_NUM_ELS_PER_DEQ / 2; i++) {
             smem_s[i] = k_deq.vals[i];
           }
-        } else {
-          const auto k_deq =
-              dequantize_permuted_int4(k_vals_[vec], k_scales);
+        }
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+        else {
+          const auto k_deq = dequantize_packed_fp8(k_vals_[vec], k_scales);
 #pragma unroll
           for (int i = 0; i < KV_NUM_ELS_PER_DEQ / 2; i++) {
             smem_s[i] = k_deq.vals[i];
           }
         }
+#endif
       }
     }
       // Load BF16 values to K fragment
@@ -1047,18 +1050,22 @@ __global__ void __launch_bounds__(kThreadsPerWarp* kSplitKWarpsPerBlock, 1)
         const auto v_vals_ = reinterpret_cast<uint32_t*>(v_vals)[vec];
         auto* smem_s = reinterpret_cast<__nv_bfloat162*>(
             smem_staging_ + smem_d * KV_NUM_ELS_PER_DEQ);
-        if (USE_FP8) {
-          const auto v_deq = dequantize_packed_fp8(v_vals_, v_scales);
+        if (!USE_FP8) {
+          const auto v_deq = dequantize_permuted_int4(v_vals_, v_scales);
 #pragma unroll
           for (int i = 0; i < KV_NUM_ELS_PER_DEQ / 2; i++) {
             smem_s[i] = v_deq.vals[i];
           }
-        } else {
-          const auto v_deq = dequantize_permuted_int4(v_vals_, v_scales);
+        }
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+        else {
+          const auto v_deq = dequantize_packed_fp8(v_vals_, v_scales);
 #pragma unroll
           for (int i = 0; i < KV_NUM_ELS_PER_DEQ / 2; i++) {
             smem_s[i] = v_deq.vals[i];
           }
         }
+#endif
       }
     } else {
       // Need to fill zeros to avoid nan
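For context, here is a minimal, self-contained sketch of the guard pattern this commit applies: a runtime `if (!USE_FP8)` branch whose FP8 arm exists in the translation unit only when compiling against CUDA 12.0 or newer. Everything below (the kernel name, the single-element layout, and the stand-in dequantization arithmetic) is an illustrative assumption, not FBGEMM's actual implementation, which operates on packed K/V vectors via `dequantize_permuted_int4` and `dequantize_packed_fp8`.

// Illustrative sketch only (hypothetical kernel): demonstrates the
// CUDA_VERSION guard pattern from this commit with simplified logic.
#include <cuda_runtime.h> // pulls in cuda.h, which defines CUDA_VERSION
#include <cuda_bf16.h>
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
#include <cuda_fp8.h> // guarded to match the commit's CUDA 12.0 cutoff
#endif

template <bool USE_FP8>
__global__ void dequant_sketch(
    const uint32_t* in, __nv_bfloat16* out, float scale, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  if (!USE_FP8) {
    // INT4 arm: compiles on every supported toolkit (stand-in for
    // dequantize_permuted_int4; the real kernel handles permuted packing).
    out[i] = __float2bfloat16(static_cast<float>(in[i] & 0xF) * scale);
  }
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
  else {
    // FP8 arm: excluded from the translation unit on CUDA < 12, so older
    // toolkits never see FP8 types; this exclusion is what fixes the build.
    __nv_fp8_e4m3 v;
    v.__x = static_cast<__nv_fp8_storage_t>(in[i] & 0xFF);
    out[i] = __float2bfloat16(static_cast<float>(v) * scale);
  }
#endif
}

Note the placement: the `else` keyword itself sits inside the `#if`, so on a pre-12.0 toolchain the statement degenerates to a plain `if` with no else arm, and `dequant_sketch<true>` compiles to a no-op body. Host-side dispatch is presumably expected not to select the FP8 variant on such toolchains.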
