diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp index 210a834300e0f..6ca1c9e490787 100644 --- a/ggml/src/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute.cpp @@ -1421,6 +1421,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) { } static bool ggml_vk_supports_op(const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + return true; // noop -> dst type does not matter + } + switch (op->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: @@ -1428,36 +1437,35 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { case GGML_TYPE_Q4_1: break; default: - return false; + return false; // dst type not supported } switch (op->op) { + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: + return true; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: - return ggml_is_contiguous(op->src[0]); + return ggml_nelements(op) % 4 == 0 && ggml_is_contiguous(op->src[0]); default: ; } break; - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_SCALE: case GGML_OP_SOFT_MAX: - case GGML_OP_RMS_NORM: - case GGML_OP_NORM: + float max_bias; + memcpy(&max_bias, (const float *)op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; case GGML_OP_ROPE: - return true; - case GGML_OP_DUP: - case GGML_OP_CPY: + return op->src[2] == nullptr; case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: