kompute : update ggml_vk_supports_op to fix false pos/neg
test-backend-ops hit assertion failures in ggml_vk_graph_compute because
of ops we do not yet support. Some of the checks have to be made more
restrictive because of features that were added to llama.cpp.

We also claimed to not support no-op operations on certain data types,
even though they are actually supported on all data types. There are now
243 passing tests, instead of 150 without the fixes for false
negatives. This also fixes complaints during LLM inference about
unsupported NONE operations for the output tensor.
cebtenzzre committed Sep 26, 2024
1 parent 132f701 commit 8078783
Showing 1 changed file with 23 additions and 15 deletions.
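
The false negatives came from check ordering: the old ggml_vk_supports_op consulted op->type before asking whether the op writes any data at all, so a no-op such as GGML_OP_NONE on a dst type outside the four supported ones was rejected. A minimal standalone sketch of that failure mode (it assumes ggml.h from llama.cpp is on the include path; the Q6_K example is illustrative, not taken from the commit):

    #include "ggml.h"

    // old ordering: dst type is checked before the op kind is considered
    static bool supports_op_old(const struct ggml_tensor * op) {
        switch (op->type) {
            case GGML_TYPE_F16:
            case GGML_TYPE_F32:
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q4_1:
                break;
            default:
                // a GGML_OP_NONE view of, say, a Q6_K tensor is rejected
                // here even though a no-op never touches its data; this is
                // the false negative the commit fixes by handling the no-op
                // cases (NONE/RESHAPE/VIEW/TRANSPOSE/PERMUTE) up front
                return false;
        }
        return true; // per-op checks follow in the real function
    }
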
38 changes: 23 additions & 15 deletions ggml/src/ggml-kompute.cpp
@@ -1421,43 +1421,51 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
 }

 static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+            return true; // noop -> dst type does not matter
+    }
+
     switch (op->type) {
         case GGML_TYPE_F16:
         case GGML_TYPE_F32:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
             break;
         default:
-            return false;
+            return false; // dst type not supported
     }

     switch (op->op) {
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SCALE:
+            return true;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
-                    return ggml_is_contiguous(op->src[0]);
+                    return ggml_nelements(op) % 4 == 0 && ggml_is_contiguous(op->src[0]);
                 default:
                     ;
             }
             break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_ADD:
-        case GGML_OP_MUL:
-        case GGML_OP_SCALE:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_NORM:
+            float max_bias;
+            memcpy(&max_bias, (const float *)op->op_params + 1, sizeof(float));
+            return max_bias == 0.0f;
         case GGML_OP_ROPE:
-            return true;
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
+            return op->src[2] == nullptr;
         case GGML_OP_CONT:
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
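
Two of the tightened checks guard llama.cpp features the Kompute shaders do not implement. For GGML_OP_SOFT_MAX, ggml_soft_max_ext() packs its two float parameters into op->op_params as { scale, max_bias }, and a nonzero max_bias enables ALiBi; for GGML_OP_ROPE, src[2] carries the optional freq_factors tensor. A hedged sketch of both guards in one place (the ALiBi and freq_factors interpretation follows ggml's definitions of those ops, not anything stated in this commit):

    #include <cstring>
    #include "ggml.h"

    static bool guards_sketch(const struct ggml_tensor * op) {
        if (op->op == GGML_OP_SOFT_MAX) {
            // op_params holds { scale, max_bias }; max_bias != 0.0f
            // requests ALiBi, which the Kompute soft_max shader lacks
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
            return max_bias == 0.0f;
        }
        if (op->op == GGML_OP_ROPE) {
            // src[2] is RoPE's optional freq_factors tensor; only the
            // case without it is supported here
            return op->src[2] == nullptr;
        }
        return true; // other ops: see the full function in the diff
    }

The remaining new restriction, ggml_nelements(op) % 4 == 0 for RELU/GELU/SILU, presumably reflects the Kompute unary shaders processing elements in groups of four (an inference, not stated in the commit).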
