add notes for ep 10
RussWong committed Apr 23, 2024
1 parent a7c447e commit 46b3117
Showing 2 changed files with 16 additions and 11 deletions.
@@ -3,14 +3,16 @@
#include "cuda_runtime.h"
typedef __half half;
typedef __half2 half2;

// Notes
// 1. This fused operator is still fairly simple: it mainly fuses a few element-wise operators (at L15 and L36), keeping each sub-operator's output in registers and feeding it directly to the next operator, so nothing has to be written back to global memory in between.
// 2. The next version will add fusedDropout, a slightly more complex fused operator, to further illustrate how fused operators are developed; overall, the idea is the same as in this section.
template<typename T>
struct MaskScaleAndElementwiseAddFunctor {
  MaskScaleAndElementwiseAddFunctor(const uint8_t* mask, const T* add_val, float scale)
      : mask(mask), add_val(add_val), scale(scale) {}
  __device__ T Compute(T x, int64_t i) const {
    // mask and scale are multiplied first, the result is then multiplied with x, and finally add_val is added element-wise
    return x * static_cast<T>(static_cast<bool>(mask[i]) * scale) + add_val[i];
  }
  const uint8_t* mask;
  const T* add_val;
  float scale;
@@ -20,15 +22,17 @@ template<>
struct MaskScaleAndElementwiseAddFunctor<half> {
  MaskScaleAndElementwiseAddFunctor(const uint8_t* mask, const half* add_val, float scale)
      : mask(mask), add_val(add_val), scale(scale) {}
  // Scalar half version of MaskScaleAndElementwiseAdd; it differs little from L15. Note: on some GPUs, under some nvcc/CUDA versions, there is no overloaded half*half multiply; in that case use __hmul(half, half) instead, or cast both halves to fp32, multiply, and convert back, e.g. (half)((float)x * (float)y). (A small compatibility sketch follows after this struct.)
  __device__ half Compute(half x, int64_t i) const {
    return x * static_cast<half>(static_cast<bool>(mask[i]) * scale) + add_val[i];
  }
  // Vectorized half2 version of MaskScaleAndElementwiseAdd: it supports not only the vectorized loads shown at L31 and L32, but also the vectorized arithmetic shown at L38, which is different from fp32 vectorization; see the CUDA Math API documentation for the exact intrinsics.
  __device__ half2 ComputeHalf2(half2 x, int64_t i) const {
    const char2* mask_c2 = reinterpret_cast<const char2*>(mask);
    const half2* add_val_h2 = reinterpret_cast<const half2*>(add_val);
    char2 mask_val = mask_c2[i];               // vectorized load
    half2 one_or_zero_h2;                      // vectorized load
    half2 h2_scale = __float2half2_rn(scale);  // float -> half2, e.g. 1.0 => (1.0, 1.0)
    one_or_zero_h2.x = mask_val.x;
    one_or_zero_h2.y = mask_val.y;
    return __hadd2(__hmul2(__hmul2(x, one_or_zero_h2), h2_scale), add_val_h2[i]);
@@ -38,19 +42,20 @@ struct MaskScaleAndElementwiseAddFunctor<half> {
  float scale;
};
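// (Illustrative sketch, not from this commit.) One way to realize the half*half workaround
// mentioned in the note above; the helper name hmul_compat is an assumption for illustration.
__device__ __forceinline__ half hmul_compat(half a, half b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  return __hmul(a, b);                                      // native half multiply intrinsic
#else
  return __float2half(__half2float(a) * __half2float(b));   // fp32 round trip, as in (half)((float)x * (float)y)
#endif
}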


// biasAdd has two inputs, x.shape = {rows, cols} and bias.shape = {cols}, so at L58 the cols bias values are read cyclically via the modulo index
template<typename FUNCTOR>
__global__ void FusedBiasAddCUDAKernelHalf2(FUNCTOR functor, const int elem_cnt,
                                            const int bias_size, const half* x, const half* bias,
                                            half* y) {
  const int h2_elem_cnt = elem_cnt / 2;  // loads are now half2-granular, so the element count is naturally halved
  const int h2_bias_size = bias_size / 2;
  const auto* x_h2 = reinterpret_cast<const half2*>(x);  // cast to a vector pointer, then read at L57
  const auto* bias_h2 = reinterpret_cast<const half2*>(bias);
  auto* y_h2 = reinterpret_cast<half2*>(y);
  // grid-stride loop: a limited number of threads processes all the data
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < h2_elem_cnt;
       i += blockDim.x * gridDim.x) {
    half2 x_i = __hadd2(x_h2[i], bias_h2[i % h2_bias_size]);  // bias add, reading bias cyclically via the modulo index
    y_h2[i] = functor.ComputeHalf2(x_i, i);
  }
}
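Per half2 element, the fused pipeline computes the bias add followed by the mask, scale, and elementwise add, all in registers. Below is a minimal host-side launch sketch; it is not part of this commit, and the LaunchFusedBiasAddHalf2 name, the d_* pointers, and the block size of 256 are assumptions for illustration. elem_cnt and bias_size are assumed to be even so the half2 casts line up.

void LaunchFusedBiasAddHalf2(const half* d_x, const half* d_bias, const uint8_t* d_mask,
                             const half* d_add_val, half* d_y,
                             int elem_cnt, int bias_size, float scale, cudaStream_t stream) {
  MaskScaleAndElementwiseAddFunctor<half> functor(d_mask, d_add_val, scale);
  const int block = 256;
  // one thread per half2 element; the grid-stride loop also tolerates a smaller grid
  const int grid = (elem_cnt / 2 + block - 1) / block;
  FusedBiasAddCUDAKernelHalf2<<<grid, block, 0, stream>>>(functor, elem_cnt, bias_size,
                                                          d_x, d_bias, d_y);
}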
@@ -1,7 +1,7 @@
#include <stdio.h>
#include <cuda.h>
#include "cuda_runtime.h"

// Note: for comments on this kernel, see the fp16 version of fused bias mask scale and add
template<typename T>
struct MaskScaleAndElementwiseAddFunctor {
  MaskScaleAndElementwiseAddFunctor(const uint8_t* mask, const T* add_val, float scale)
