Faster MoE inference (#112)
* multi_add: WIP

* multi_add: CPU works

* multi_add: CUDA

* multi_add: simplify

* multi_add: Metal

* Metal: speed up mul_mat_id

For the Granite-1B MoE model, PP-512 goes from 156 t/s to 890 t/s, nearly a 6X speedup!

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
ikawrakow and Kawrakow authored Oct 31, 2024
1 parent 5ad6439 commit 52874c5
Showing 8 changed files with 333 additions and 34 deletions.
6 changes: 6 additions & 0 deletions ggml/include/ggml.h
@@ -494,6 +494,7 @@ extern "C" {
GGML_OP_GROUP_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,
GGML_OP_MULTI_ADD,

GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -930,6 +931,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_multi_add(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_experts);

// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
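The new ggml_multi_add operator produces, for every output row of length ne0, the element-wise sum of n_experts consecutive blocks of ne0 floats read from the corresponding source row; judging by the parameter name and the commit title, this is how the per-expert contributions are combined during MoE inference. A plain C sketch of these semantics, mirrored from the CUDA and Metal kernels below (the function name and the explicit element strides are illustrative, not part of the ggml API):

#include <stdint.h>

// Reference semantics of GGML_OP_MULTI_ADD on float data.
// Row strides are given in float elements here, not bytes.
static void multi_add_ref(const float * src, float * dst,
                          int n_experts, int64_t ne0, int64_t ne1,
                          int64_t src_row_stride, int64_t dst_row_stride) {
    for (int64_t i1 = 0; i1 < ne1; ++i1) {
        const float * s = src + i1*src_row_stride; // n_experts blocks of ne0 floats
        float       * d = dst + i1*dst_row_stride; // one block of ne0 floats
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            float sum = 0.0f;
            for (int j = 0; j < n_experts; ++j) {
                sum += s[(int64_t)j*ne0 + i0];
            }
            d[i0] = sum;
        }
    }
}
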
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD:
ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_MULTI_ADD:
ggml_cuda_op_multi_add(ctx, dst);
break;
case GGML_OP_ACC:
ggml_cuda_op_acc(ctx, dst);
break;
@@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
// disable CUDA graphs for batch size > 1 for now.
// Changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}

if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
@@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
case GGML_OP_MULTI_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
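The launch geometry of the new multi_add kernel depends on the tensor shape, which is why CUDA graph capture is disabled above for MULTI_ADD nodes with ne[1] > 1: a graph recorded for one batch size would replay the grid dimensions of another. A minimal, self-contained illustration of how the block count changes with batch size (the FFN width 1536 is an assumption made only for this example; the block size matches CUDA_MULTI_ADD_BLOCK_SIZE defined in unary.cuh below):

#include <stdint.h>
#include <stdio.h>

#define CUDA_MULTI_ADD_BLOCK_SIZE 256

// Block count used by multi_add_f32_cuda: one thread per output element.
static int multi_add_num_blocks(int64_t ne0, int64_t ne1) {
    int64_t k = ne0*ne1;
    return (int)((k + CUDA_MULTI_ADD_BLOCK_SIZE - 1)/CUDA_MULTI_ADD_BLOCK_SIZE);
}

int main(void) {
    printf("batch 1:   %d blocks\n", multi_add_num_blocks(1536, 1));   // 6
    printf("batch 512: %d blocks\n", multi_add_num_blocks(1536, 512)); // 3072
    return 0;
}
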
36 changes: 36 additions & 0 deletions ggml/src/ggml-cuda/unary.cu
@@ -52,6 +52,25 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}

static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
    // one thread per destination element: each dst row of ne0 floats is the
    // element-wise sum of nused consecutive blocks of ne0 floats in the
    // corresponding src0 row (nb1 and nb01 are byte strides between rows)
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    int64_t k = ne0*ne1;
    if (i >= k) {
        return;
    }
    int i1 = i / ne0; // destination row
    int i0 = i % ne0; // column within the row
    float * result = (float *)(dst + i1*nb1);
    const float * s = (const float *)(src0 + i1*nb01) + i0;
    if (nused == 1) {
        // a single block: nothing to add, just copy
        result[i0] = s[0];
    } else {
        float sum = s[0] + s[ne0];
        for (int j = 2; j < nused; ++j) sum += s[j*ne0];
        result[i0] = sum;
    }
}

static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

@@ -218,6 +237,23 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
int64_t k = ne0 * ne1;
const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
}

void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
GGML_ASSERT(dst->nb[0] == sizeof(float));
int nused = dst->op_params[0];
GGML_ASSERT(nused >= 1);
const char * src0 = (const char *)dst->src[0]->data;
cudaStream_t stream = ctx.stream();
multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
}

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/unary.cuh
@@ -9,6 +9,7 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
#define CUDA_MULTI_ADD_BLOCK_SIZE 256

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
35 changes: 35 additions & 0 deletions ggml/src/ggml-metal.m
@@ -39,6 +39,8 @@
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_4,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
GGML_METAL_KERNEL_TYPE_MULTI_ADD,
GGML_METAL_KERNEL_TYPE_MULTI_ADD_4,
GGML_METAL_KERNEL_TYPE_MUL,
GGML_METAL_KERNEL_TYPE_MUL_4,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
@@ -577,6 +579,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD, multi_add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD_4, multi_add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
@@ -932,6 +936,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
case GGML_OP_MULTI_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
@@ -1349,6 +1354,36 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
case GGML_OP_MULTI_ADD:
{
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(dstt == GGML_TYPE_F32);
GGML_ASSERT(ne02 == 1 && ne03 == 1);
GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

int n_expert = dst->op_params[0];
GGML_ASSERT(n_expert >= 2);

id<MTLComputePipelineState> pipeline = nil;
int64_t n = ne0*ne1;
if (ne0%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
n /= 4;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
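A note on the dispatch above: the MULTI_ADD kernels are launched as n threadgroups of a single thread, so n is the number of output elements ne0*ne1, divided by 4 when ne0 is a multiple of 4 and the float4 variant is selected. Inside the kernels added in ggml-metal.metal below, the byte strides nb1 and nb01 are converted to element offsets accordingly (divided by 16 for float4, by 4 for float). As a purely hypothetical example, ne0 = 1536 and ne1 = 512 would dispatch 1536*512/4 = 196,608 threadgroups of the vectorized kernel.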
131 changes: 116 additions & 15 deletions ggml/src/ggml-metal.metal
@@ -479,6 +479,44 @@ kernel void kernel_sqr(
dst[tpig] = src0[tpig] * src0[tpig];
}

kernel void kernel_multi_add_4(
device const float4 * src0,
device float4 * dst,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant int64_t & nb01,
constant int & n_expert,
uint tpig[[thread_position_in_grid]]) {

int64_t i0 = tpig % (ne0/4);
int64_t i1 = tpig / (ne0/4);
device float4 * dst_ptr = dst + i1*(nb1/16) + i0;
device const float4 * src_ptr = src0 + i1*(nb01/16) + i0;
float4 sum = src_ptr[0] + src_ptr[ne0/4];
for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0/4];
dst_ptr[0] = sum;
}

kernel void kernel_multi_add(
device const float * src0,
device float * dst,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant int64_t & nb01,
constant int & n_expert,
uint tpig[[thread_position_in_grid]]) {

int64_t i0 = tpig % ne0;
int64_t i1 = tpig / ne0;
device float * dst_ptr = dst + i1*nb1/4 + i0;
device const float * src_ptr = src0 + i1*nb01/4 + i0;
float sum = src_ptr[0] + src_ptr[ne0];
for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0];
dst_ptr[0] = sum;
}

kernel void kernel_sum_rows(
device const float * src0,
device float * dst,
@@ -8197,32 +8235,95 @@ kernel void kernel_mul_mm_id(
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 ntg3[[threads_per_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

const int32_t i02 = tgpig.z;
tgpig.z = 0;

device const uchar * src0 = src0s + i02*nb02;

// row indices
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
uint ntg = ntg3.x * ntg3.y * ntg3.z;
uint n = nei0*nei1;

// TODO: parallelize this loop
int64_t _ne1 = 0;
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
if (id == i02) {
//if (tiitg == 0) {
rowids[_ne1] = ushort2(ii0, ii1);
//}
_ne1++;
}
}
}
//uint npt = (n + ntg - 1) / ntg;
//uint first = tiitg * npt;
//uint last = first + npt <= n ? first + npt : n;

//uint nhave = 0;
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) ++nhave;
//}
//threadgroup uint * nums = (threadgroup uint *)shared_memory;
//nums[tiitg] = nhave;
//threadgroup_barrier(mem_flags::mem_threadgroup);

//uint nprev = 0;
//for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
//int64_t _ne1 = nprev;
//for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];

//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
//for (uint i = first; i < last; ++i) {
// uint ii0 = i % nei0;
// uint ii1 = i / nei0;
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
//}

//threadgroup_barrier(mem_flags::mem_threadgroup);

//
// The following is slightly faster than the commented out version above
//
uint nhave = 0;
for (uint i = tiitg; i < n; i += ntg) {
uint ii0 = i % nei0;
uint ii1 = i / nei0;
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
if (id == i02) ++nhave;
}
threadgroup uint * nums = (threadgroup uint *)shared_memory;
nums[tiitg] = nhave;
threadgroup_barrier(mem_flags::mem_threadgroup);

uint nprev = 0;
for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
int64_t _ne1 = nprev;
for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];

threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
for (uint i = tiitg; i < n; i += ntg) {
uint ii0 = i % nei0;
uint ii1 = i / nei0;
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
}
threadgroup_barrier(mem_flags::mem_threadgroup);

// This is the original version that is ridiculously slow.
//// row indices
//threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);

//// TODO: parallelize this loop
//int64_t _ne1 = 0;
//for (ushort ii1 = 0; ii1 < nei1; ii1++) {
// for (ushort ii0 = 0; ii0 < nei0; ii0++) {
// int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
// if (id == i02) {
// //if (tiitg == 0) {
// rowids[_ne1] = ushort2(ii0, ii1);
// //}
// _ne1++;
// }
// }
//}

//threadgroup_barrier(mem_flags::mem_threadgroup);

kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,
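For context on the mul_mat_id speedup: in the old code (kept above as the commented-out "original version that is ridiculously slow"), every thread of the threadgroup walked all nei0*nei1 expert ids serially to build the rowids list. The new code turns this into a small parallel compaction: each thread counts the matches in a strided slice of the id matrix, the per-thread counts are exchanged through threadgroup memory, every thread derives its write offset (an exclusive prefix sum) and the total _ne1, and then scatters its matches. The C sketch below simulates that count / prefix / scatter pattern serially on a dense row-major id matrix; it is illustrative only, and the resulting order of entries differs from the old serial scan (each entry still carries its own (ii0, ii1) coordinates).

#include <stdint.h>
#include <stdlib.h>

// Serial simulation of the count / prefix-sum / scatter compaction used by the
// updated kernel_mul_mm_id to collect the (ii0, ii1) pairs whose expert id is i02.
typedef struct { uint16_t ii0, ii1; } rowid_t;

static int64_t build_rowids(const int32_t * ids, int nei0, int nei1, int32_t i02,
                            int ntg, rowid_t * rowids) {
    const int64_t n = (int64_t)nei0*nei1;

    // Phase 1: each simulated "thread" t counts matches in its strided slice.
    int * counts = calloc(ntg, sizeof(int));
    for (int t = 0; t < ntg; ++t) {
        for (int64_t i = t; i < n; i += ntg) {
            if (ids[i] == i02) counts[t]++;
        }
    }

    // Phase 2: exclusive prefix sum gives each thread its write offset;
    // the grand total is _ne1, the number of rows routed to this expert.
    int * offsets = calloc(ntg, sizeof(int));
    int64_t total = 0;
    for (int t = 0; t < ntg; ++t) { offsets[t] = (int)total; total += counts[t]; }

    // Phase 3: each "thread" scatters its matches starting at its offset.
    for (int t = 0; t < ntg; ++t) {
        int pos = offsets[t];
        for (int64_t i = t; i < n; i += ntg) {
            if (ids[i] == i02) {
                rowids[pos].ii0 = (uint16_t)(i % nei0);
                rowids[pos].ii1 = (uint16_t)(i / nei0);
                pos++;
            }
        }
    }

    free(counts);
    free(offsets);
    return total;
}
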
(Diffs for the remaining 2 changed files are not loaded here.)
