Skip to content

Commit

Permalink
fix bugs where kernel read mat and vec
Browse files Browse the repository at this point in the history
  • Loading branch information
RussWong committed Mar 5, 2024
1 parent be97230 commit 905cf7c
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions 15_gemv/15_gemv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ __global__ void gemv(float* matrix, float* vector, float* res, int cols) {

float thread_local_sum = 0.0f;
for(int i = 0; i < VECS_PER_THREAD; i++) {
float4* mat4 = reinterpret_cast<float4*>(&matrix[bid * cols + i * VEC_SIZE * blockDim.x + tid * VEC_SIZE]); // 4 * half2
float4* vec4 = reinterpret_cast<float4*>(&vector[tid * VEC_SIZE]);
thread_local_sum += mat4[i].x * vec4[i].x;
thread_local_sum += mat4[i].y * vec4[i].y;
thread_local_sum += mat4[i].z * vec4[i].z;
thread_local_sum += mat4[i].w * vec4[i].w;
float4 mat4 = reinterpret_cast<float4*>(matrix)[bid * (cols / VECS_SIZE) + i * blockDim.x + tid]; // 1 * float4
float4 vec4 = reinterpret_cast<float4*>(vector)[i * blockDim.x + tid];
thread_local_sum += mat4.x * vec4.x;
thread_local_sum += mat4.y * vec4.y;
thread_local_sum += mat4.z * vec4.z;
thread_local_sum += mat4.w * vec4.w;
}
//reduce to get the final val
float reduce_res = blockReduce<SumOp, float>(thread_local_sum);
Expand All @@ -104,16 +104,16 @@ __global__ void gemv(half* matrix, half* vector, half* res, int cols) {
//float thread_local_sum = 0.0f;
half thread_local_sum = 0;
for(int i = 0; i < VECS_PER_THREAD; i++) {
float4* mat4 = reinterpret_cast<float4*>(&matrix[bid * cols + i * VEC_SIZE * blockDim.x + tid * VEC_SIZE]); // 4 * half2
float4* vec4 = reinterpret_cast<float4*>(&vector[tid * VEC_SIZE]);
half2* vec_h1 = (half2*)&vec4[i].x;
half2* vec_h2 = (half2*)&vec4[i].y;
half2* vec_h3 = (half2*)&vec4[i].z;
half2* vec_h4 = (half2*)&vec4[i].w;
half2* mat_h1 = (half2*)&mat4[i].x;
half2* mat_h2 = (half2*)&mat4[i].y;
half2* mat_h3 = (half2*)&mat4[i].z;
half2* mat_h4 = (half2*)&mat4[i].w;
float4 mat4 = reinterpret_cast<float4*>(matrix)[bid * (cols / VECS_SIZE) + i * blockDim.x + tid]; // 4 * half2
float4 vec4 = reinterpret_cast<float4*>(vector)[i * blockDim.x + tid];
half2* vec_h1 = (half2*)&vec4.x;
half2* vec_h2 = (half2*)&vec4.y;
half2* vec_h3 = (half2*)&vec4.z;
half2* vec_h4 = (half2*)&vec4.w;
half2* mat_h1 = (half2*)&mat4.x;
half2* mat_h2 = (half2*)&mat4.y;
half2* mat_h3 = (half2*)&mat4.z;
half2* mat_h4 = (half2*)&mat4.w;
half2 res1 = __hmul2(*mat_h1, *vec_h1);
half2 res2 = __hmul2(*mat_h2, *vec_h2);
half2 res3 = __hmul2(*mat_h3, *vec_h3);
Expand All @@ -132,19 +132,16 @@ __global__ void gemv(half* matrix, half* vector, half* res, int cols) {
// thread_local_sum += res3.y;
// thread_local_sum += res4.x;
// thread_local_sum += res4.y;
if(i == 0 && tid == 0 && bid == 0) {
printf("thread sum = %f\n", (float)thread_local_sum); // 8
//if(i == 0 && tid == 0 && bid == 0) {
//printf("thread sum = %f\n", (float)thread_local_sum); // 8
// printf("res1.x = %f\n", res1.x); // 1
// printf("res1.y = %f\n", res1.y);
}
//}
}
//reduce to get the final val
half reduce_res = blockReduce<SumOp, half>(thread_local_sum);
// float reduce_res = blockReduce<SumOp, float>(thread_local_sum);
//store to gmem
if(tid == 0) {
printf("block reduce_res = %f\n", (float)reduce_res);
// res[blockIdx.x] = __float2half(reduce_res);
res[blockIdx.x] = reduce_res;
}
__syncthreads();
Expand Down

0 comments on commit 905cf7c

Please sign in to comment.