Commit
add some notes in reduce_v4 and 14_quantize
RussWong committed May 21, 2024
1 parent 0790ecb commit 731487f
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions 14_quantize.cu
@@ -19,7 +19,7 @@
// 1. printf the cpu result and gpu result of each kernel
// 2. use if(tid==0) to print the gpu output and key variables of a single thread
// 3. use a grid-stride loop so you can conveniently debug by launching a single thread

+// Note: illustrations are being added
bool CheckResult(float *out, float* groudtruth, int nums){
for (int i = 0; i < nums; i++){
if (groudtruth[i] != out[i]) {
@@ -187,7 +187,7 @@ __global__ void ReduceMaxMinPerTensor(const T* input_ptr, const int nums, T* max
int gid = blockDim.x * blockIdx.x + tid;
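// editor's note: FLT_MIN is the smallest *positive* normalized float, so -FLT_MAX is the safer identity for a running max when inputs can be negative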
shared_max[tid] = FLT_MIN;
shared_min[tid] = FLT_MAX;

+// 1. the number of blocks may not cover the whole data set, so first use a grid-stride loop over total_thread_num threads to fold in both the in-range and out-of-range elements
for (int i = gid; i < nums; i += total_thread_num) {
shared_max[tid] = max(shared_max[tid], input_ptr[i]);
shared_min[tid] = min(shared_min[tid], input_ptr[i]);
@@ -197,15 +197,15 @@ __global__ void ReduceMaxMinPerTensor(const T* input_ptr, const int nums, T* max
//}
}
__syncthreads();

+// 2. at this point the blocks together have covered the whole data set, so now compare within each block, i.e. the intra-block reduction
for (int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s && gid < nums) {
shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);
shared_min[tid] = min(shared_min[tid], shared_min[tid + s]);
}
__syncthreads();
}

+// 3. finally, position 0 of each block's shared memory holds that block's max/min; use atomics to compare the results across all blocks
if (tid == 0) {
atomicMax(max_ptr, shared_max[0]);
atomicMin(min_ptr, shared_min[0]);
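Note that CUDA's built-in atomicMax/atomicMin only cover integer types, so for a float tensor the overloads called above must be defined elsewhere in 14_quantize.cu (not shown in this hunk). A minimal sketch of the usual atomicCAS-based approach — atomicMaxFloat is a hypothetical name here, not necessarily what the file uses:

// Sketch of a float atomicMax built on atomicCAS (assumed helper, not from this commit)
__device__ float atomicMaxFloat(float* addr, float val) {
    unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(addr);
    unsigned int old = *addr_as_ui;
    unsigned int assumed;
    do {
        assumed = old;
        if (__uint_as_float(assumed) >= val) break;  // stored value already large enough
        old = atomicCAS(addr_as_ui, assumed, __float_as_uint(val)); // try to install val
    } while (old != assumed);                        // lost a race: retry against the new value
    return __uint_as_float(old);
}

The CAS loop retries only while another thread has raced in with a different value, and exits early once the stored value is already >= val; an atomicMinFloat would be symmetric.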
4 changes: 2 additions & 2 deletions 5_reduce/reduce_v4.cu
@@ -6,8 +6,8 @@
//latency: 0.694ms
__device__ void WarpSharedMemReduce(volatile float* smem, int tid){
// CUDA does not guarantee that all shared memory reads complete before the corresponding writes, so there is a race condition that can corrupt the result
-// e.g. smem[tid] += smem[tid + 16] => smem[3] += smem[16], smem[16] += smem[32]
-// the order of the read and the write of smem[16] on L9 is then undefined, so since the Volta architecture an intermediate register (L11) is used together with syncwarp to enforce the read-write dependency
+// e.g. smem[tid] += smem[tid + 16] => smem[0] += smem[16], smem[16] += smem[32]
+// the order of the read and the write of smem[16] on L9 is then undefined, so since the Volta architecture an intermediate register (L11) is used together with syncwarp and volatile (which forces every access to go to shared memory, so a thread never misses other threads' updates to smem) to enforce the read-write dependency
float x = smem[tid];
if (blockDim.x >= 64) {
x += smem[tid + 32]; __syncwarp();
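The hunk is cut off here. For context, the rest of WarpSharedMemReduce repeats the same read-into-register / sync / publish / sync step at each stride; a sketch of the full idiom follows (my reconstruction under that assumption, not a verbatim quote of the file):

// Each step: accumulate the partner lane's value into register x, __syncwarp(),
// publish x back to shared memory, __syncwarp() again. The register plus the
// warp barriers is exactly the read/write-ordering fix described in the comments.
__device__ void WarpSharedMemReduce(volatile float* smem, int tid) {
    float x = smem[tid];
    if (blockDim.x >= 64) {
        x += smem[tid + 32]; __syncwarp();
        smem[tid] = x; __syncwarp();
    }
    x += smem[tid + 16]; __syncwarp();
    smem[tid] = x; __syncwarp();
    x += smem[tid + 8]; __syncwarp();
    smem[tid] = x; __syncwarp();
    x += smem[tid + 4]; __syncwarp();
    smem[tid] = x; __syncwarp();
    x += smem[tid + 2]; __syncwarp();
    smem[tid] = x; __syncwarp();
    x += smem[tid + 1]; __syncwarp();
    smem[tid] = x; __syncwarp();
}

Since Volta's independent thread scheduling, the lanes of a warp are no longer guaranteed to execute in lockstep, which is why volatile alone is not enough and each shared memory hand-off needs an explicit __syncwarp().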
