From adf2474b105e826684d60823f2e0bfaa140dbb03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 22 Sep 2024 09:34:52 +0200 Subject: [PATCH] CUDA: enable Gemma FA for HIP/Pascal (llama/9581) --- ggml/src/ggml-cuda.cu | 16 ++++++++-------- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index f940511980e..bf21c643d3a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV: return true; - case GGML_OP_FLASH_ATTN_EXT: -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; -#else + case GGML_OP_FLASH_ATTN_EXT: { + if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { + return true; + } if (op->src[0]->ne[0] == 128) { return true; } - if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { + if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } - return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && - op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc; + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + } case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index f28a19d40b3..83e5589a1cc 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);