Refactor groupNormPlugin code

Signed-off-by: Rajeev Rao <[email protected]>

rajeevsrao committed Dec 7, 2022
1 parent fa3f4b4 commit a02fc46

Showing 6 changed files with 234 additions and 166 deletions.
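Three of the six changed files are reproduced below (the CMake file, the CUDA kernels, and the kernel header). The refactor is mechanical but thorough: snake_case identifiers such as Group_norm_nhwc_params, group_norm_nhwc_sum, and sum_sq become GroupNormNHWCParams, groupNormNHWCSum, and sumSq per the TensorRT camelCase convention; bare assert calls give way to the plugin's PLUGIN_ASSERT and PLUGIN_FAIL error-handling macros; float literals are normalized to the 1.F style; and the header's includes are regrouped.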
2 changes: 2 additions & 0 deletions plugin/groupNormPlugin/CMakeLists.txt
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 file(GLOB SRCS *.cpp)
 set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
 set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
+
 file(GLOB CU_SRCS *.cu)
 set(STABLE_DIFFUSION_CU_SOURCES ${STABLE_DIFFUSION_CU_SOURCES} ${CU_SRCS})
 set(STABLE_DIFFUSION_CU_SOURCES ${STABLE_DIFFUSION_CU_SOURCES} PARENT_SCOPE)
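A note on the double set() pattern above: set(... PARENT_SCOPE) writes the variable only into the parent directory's scope and leaves the current scope unchanged, so each list is assigned twice — once plainly, so the remainder of this file sees the appended sources, and once with PARENT_SCOPE, so the enclosing plugin build (which presumably adds this directory via add_subdirectory) can collect them. Keeping the .cu files in a separate STABLE_DIFFUSION_CU_SOURCES list presumably lets the parent build compile them with CUDA-specific options.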
152 changes: 77 additions & 75 deletions plugin/groupNormPlugin/groupNormKernel.cu
@@ -15,66 +15,68 @@
  * limitations under the License.
  */

 #include "common/common.cuh"
 #include "groupNormKernel.h"

 #include <cub/cub.cuh>

 static inline __device__ __host__ float sigmoid(float x)
 {
-    return 1.f / (1.f + expf(-x));
+    return 1.F / (1.F + expf(-x));
 }

-struct Group_sums
+struct GroupSums
 {
     // Is it the 1st element of the group?
     int32_t flag;
     // The sum.
     float sum;
     // The sum of squares.
-    float sum_sq;
+    float sumSq;
 };

-struct Group_sums_op
+struct GroupSumsOp
 {
-    inline __device__ Group_sums operator()(Group_sums const& a, Group_sums const& b)
+    inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b)
     {
-        Group_sums dst;
+        GroupSums dst;
         dst.sum = b.flag ? b.sum : (a.sum + b.sum);
-        dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
+        dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
         dst.flag = a.flag + b.flag;
         return dst;
     }
 };

-template <int32_t THREADS_PER_BLOCK>
-__global__ void group_norm_nhwc_sum_kernel(Group_norm_nhwc_params params)
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params)
 {
     // The object in charge of doing the sums for the different blocks.
-    typedef cub::BlockScan<Group_sums, THREADS_PER_BLOCK> Block_scan;
+    typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;

-    // Allocate shared memory for Block_scan.
-    __shared__ typename Block_scan::TempStorage temp_storage;
+    // Allocate shared memory for BlockScan.
+    __shared__ typename BlockScan::TempStorage tempStorage;
     // Allocate shared memory for the groups. We could reduce the amount of shared
     // memory reserved.
-    __shared__ float2 smem[THREADS_PER_BLOCK];
+    __shared__ float2 smem[tTHREADS_PER_BLOCK];

     // The instance in the batch.
     int32_t ni = blockIdx.z;
     // The channel loaded by that thread (2 channels per thread for F16x2).
-    int32_t ci = blockIdx.x * params.c_per_block + threadIdx.x * 2;
+    int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

     // The first activation loaded by that block.
-    int32_t hw_begin = blockIdx.y * params.hw_per_block;
+    int32_t hwBegin = blockIdx.y * params.hwPerBlock;
     // The last activation loaded by that block.
-    int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
+    int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

     // The sums.
-    float sum = 0.f, sum_sq = 0.f;
+    float sum = 0.F;
+    float sumSq = 0.F;

     // Iterate over the activations to compute the sums.
-    for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi)
+    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi)
     {
         // The offset.
-        int64_t offset = (int64_t) ni * params.hwc + hwi * params.c + ci;
+        int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;

         // Fetch two channels per thread.
         __half2 h2(0, 0);
@@ -89,35 +91,35 @@ __global__ void group_norm_nhwc_sum_kernel(Group_norm_nhwc_params params)
         // Update the sum.
         sum += f2.x + f2.y;
         // Update the sum of squares.
-        sum_sq += f2.x * f2.x + f2.y * f2.y;
+        sumSq += f2.x * f2.x + f2.y * f2.y;
     }

     // The group that thread works on and the channel in the group (modulus).
-    int32_t gi = threadIdx.x * 2 / params.c_per_group;
-    int32_t cj = threadIdx.x * 2 - params.c_per_group * gi;
+    int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+    int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;

     // The data for the summations.
-    Group_sums inp{cj == 0 ? 1 : 0, sum, sum_sq};
+    GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};

     // Do the segmented scan.
-    Group_sums out;
-    Block_scan(temp_storage).InclusiveScan(inp, out, Group_sums_op());
+    GroupSums out;
+    BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());

     // Store the results for the groups in shared memory (to produce coalesced
     // stores later).
-    if (cj == params.c_per_group - 2 /* 2 channels per thread */)
+    if (cj == params.cPerGroup - 2 /* 2 channels per thread */)
     {
-        smem[gi] = make_float2(out.sum, out.sum_sq);
+        smem[gi] = make_float2(out.sum, out.sumSq);
     }

     // Make sure the data is in shared memory.
     __syncthreads();

     // The global group index.
-    int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x;
+    int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;

     // Threads that have nothing left to do, exit.
-    if (threadIdx.x >= params.groups_per_block || gj >= params.groups)
+    if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups)
     {
         return;
     }
@@ -126,78 +128,78 @@ __global__ void group_norm_nhwc_sum_kernel(Group_norm_nhwc_params params)
     float2 sums = smem[threadIdx.x];

     // Store to global memory.
-    atomicAdd(&params.red_buffer[(2 * ni + 0) * params.groups + gj], sums.x);
-    atomicAdd(&params.red_buffer[(2 * ni + 1) * params.groups + gj], sums.y);
+    atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+    atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
 }

-void group_norm_nhwc_sum(Group_norm_nhwc_params const& params, cudaStream_t stream)
+void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream)
 {
     // Make sure the values are as we expect.
-    assert(params.c % params.c_per_block == 0 && params.hw % params.hw_per_block == 0);
+    PLUGIN_ASSERT(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0);
     // Make sure a group does not span multiple blocks.
-    assert(params.c_per_block % params.c_per_group == 0);
+    PLUGIN_ASSERT(params.cPerBlock % params.cPerGroup == 0);

     dim3 grid;

     // The number of blocks to compute all the channels.
-    grid.x = params.c / params.c_per_block;
+    grid.x = params.c / params.cPerBlock;
     // The number of blocks to compute all the activations in a given instance.
-    grid.y = divUp(params.hw, params.hw_per_block);
+    grid.y = divUp(params.hw, params.hwPerBlock);
     // The number of instances.
     grid.z = params.n;

-    switch (params.c_per_block)
+    switch (params.cPerBlock)
     {
-    case 320: group_norm_nhwc_sum_kernel<160><<<grid, 160, 0, stream>>>(params); break;
-    case 480: group_norm_nhwc_sum_kernel<256><<<grid, 256, 0, stream>>>(params); break;
-    case 256: group_norm_nhwc_sum_kernel<128><<<grid, 128, 0, stream>>>(params); break;
-    case 128: group_norm_nhwc_sum_kernel<64><<<grid, 64, 0, stream>>>(params); break;
-    default: assert(false); // Not implemented!
+    case 320: groupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params); break;
+    case 480: groupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params); break;
+    case 256: groupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params); break;
+    case 128: groupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params); break;
+    default: PLUGIN_FAIL("Not implemented");
     }

     PLUGIN_CUASSERT(cudaGetLastError());
 }

-template <int32_t THREADS_PER_BLOCK>
-__global__ void group_norm_nhwc_scale_kernel(Group_norm_nhwc_params params)
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params)
 {
     // The instance in the batch.
     int32_t ni = blockIdx.z;
     // The channel loaded by that thread (2 channels per thread for F16x2).
-    int32_t ci = blockIdx.x * params.c_per_block + threadIdx.x * 2;
+    int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
     // The group that thread works on and the channel in the group (modulus).
-    int32_t gi = ci / params.c_per_group;
+    int32_t gi = ci / params.cPerGroup;

     // Load the sum and sum of squares for the group.
-    float sum = 0.f, sum_sq = 0.f;
+    float sum = 0.F, sumSq = 0.F;
     if (gi < params.groups)
     {
-        sum = params.red_buffer[(2 * ni + 0) * params.groups + gi];
-        sum_sq = params.red_buffer[(2 * ni + 1) * params.groups + gi];
+        sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+        sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
     }

     // Load gamma/beta.
-    float2 gamma_f2, beta_f2;
+    float2 gammaF2, betaF2;
     if (ci < params.c)
     {
-        gamma_f2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
-        beta_f2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
+        gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
+        betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
     }

     // Compute the mean.
-    float mean = sum * params.inv_hwc;
+    float mean = sum * params.invHWC;
     // Compute the variance.
-    float var = sum_sq * params.inv_hwc - (mean * mean);
+    float var = sumSq * params.invHWC - (mean * mean);
     // Compute the inverse of the stddev.
-    float inv_stddev = var <= 0.f ? 1.f : rsqrtf(var);
+    float invStdDev = var <= 0.F ? 1.F : rsqrtf(var);

     // The first activation loaded by that block.
-    int32_t hw_begin = blockIdx.y * params.hw_per_block;
+    int32_t hwBegin = blockIdx.y * params.hwPerBlock;
     // The last activation loaded by that block.
-    int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
+    int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

     // Iterate over the activations to compute the sums.
-    for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi)
+    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi)
     {
         // The src/dst offset.
         int64_t offset = (int64_t) ni * params.hwc + hwi * params.c + ci;
@@ -213,15 +215,15 @@ __global__ void group_norm_nhwc_scale_kernel(Group_norm_nhwc_params params)
         float2 f2 = __half22float2(h2);

         // Normalize the channels.
-        f2.x = (f2.x - mean) * inv_stddev;
-        f2.y = (f2.y - mean) * inv_stddev;
+        f2.x = (f2.x - mean) * invStdDev;
+        f2.y = (f2.y - mean) * invStdDev;

         // Scale by gamma and add beta.
-        f2.x = gamma_f2.x * f2.x + beta_f2.x;
-        f2.y = gamma_f2.y * f2.y + beta_f2.y;
+        f2.x = gammaF2.x * f2.x + betaF2.x;
+        f2.y = gammaF2.y * f2.y + betaF2.y;

         // Apply Swish if needed.
-        if (params.with_swish)
+        if (params.withSwish)
         {
             f2.x = f2.x * sigmoid(f2.x);
             f2.y = f2.y * sigmoid(f2.y);
@@ -235,29 +237,29 @@ __global__ void group_norm_nhwc_scale_kernel(Group_norm_nhwc_params params)
     }
 }

-void group_norm_nhwc_scale(Group_norm_nhwc_params const& params, cudaStream_t stream)
+void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream)
 {
     // Make sure the dimensions are aligned with what we expect.
-    assert(params.c % params.c_per_block == 0);
+    PLUGIN_ASSERT(params.c % params.cPerBlock == 0);
     // Make sure a group does not span multiple blocks.
-    assert(params.c_per_block % params.c_per_group == 0);
+    PLUGIN_ASSERT(params.cPerBlock % params.cPerGroup == 0);

     dim3 grid;

     // The number of blocks to compute all the channels.
-    grid.x = params.c / params.c_per_block;
+    grid.x = params.c / params.cPerBlock;
     // The number of blocks to compute all the activations in a given instance.
-    grid.y = divUp(params.hw, params.hw_per_block);
+    grid.y = divUp(params.hw, params.hwPerBlock);
     // The number of instances.
     grid.z = params.n;

-    switch (params.c_per_block)
+    switch (params.cPerBlock)
     {
-    case 320: group_norm_nhwc_scale_kernel<160><<<grid, 160, 0, stream>>>(params); break;
-    case 480: group_norm_nhwc_scale_kernel<256><<<grid, 256, 0, stream>>>(params); break;
-    case 256: group_norm_nhwc_scale_kernel<128><<<grid, 128, 0, stream>>>(params); break;
-    case 128: group_norm_nhwc_scale_kernel<64><<<grid, 64, 0, stream>>>(params); break;
-    default: assert(false); // Not implemented!
+    case 320: groupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params); break;
+    case 480: groupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params); break;
+    case 256: groupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params); break;
+    case 128: groupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params); break;
+    default: PLUGIN_FAIL("Not implemented");
     }

     PLUGIN_CUASSERT(cudaGetLastError());
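A note for readers following groupNormNHWCSumKernel: the subtle part is the segmented scan. GroupSumsOp is the associative combine operator handed to cub::BlockScan::InclusiveScan, and the flag field (set to 1 on the first channel pair of each group) keeps running totals from leaking across group boundaries, so after the scan the last thread of each group holds that group's complete sum and sum of squares — which is exactly why the kernel stores to smem at cj == params.cPerGroup - 2. A tiny sequential model of the same combine rule (illustrative only, not part of the commit) makes the behavior concrete:

    #include <cstdint>
    #include <cstdio>

    struct GroupSums { int32_t flag; float sum; float sumSq; };

    // Same combine rule as the kernel's GroupSumsOp: when b starts a new
    // segment (b.flag != 0), the totals carried in a are discarded.
    static GroupSums combine(GroupSums const& a, GroupSums const& b)
    {
        GroupSums dst;
        dst.sum = b.flag ? b.sum : (a.sum + b.sum);
        dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
        dst.flag = a.flag + b.flag;
        return dst;
    }

    int main()
    {
        // Four "threads", two groups; flag marks the first element of each group.
        GroupSums in[] = {{1, 1.F, 1.F}, {0, 2.F, 4.F}, {1, 3.F, 9.F}, {0, 4.F, 16.F}};
        GroupSums run = in[0];
        std::printf("sum=%g sumSq=%g\n", run.sum, run.sumSq);
        for (int32_t i = 1; i < 4; ++i)
        {
            run = combine(run, in[i]);
            std::printf("sum=%g sumSq=%g\n", run.sum, run.sumSq);
        }
        // Prints 1/1, then 3/5 (group 0 complete), then 3/9 (reset at the start
        // of group 1), then 7/25 (group 1 complete).
        return 0;
    }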
14 changes: 7 additions & 7 deletions plugin/groupNormPlugin/groupNormKernel.h
@@ -17,23 +17,23 @@
 #ifndef TRT_GROUPNORM_KERNEL_H
 #define TRT_GROUPNORM_KERNEL_H

-#include "cuda_fp16.h"
-#include <cuda.h>
-#include <cuda_runtime_api.h>
-#include <stdint.h>
+#include "common/checkMacrosPlugin.h"
+#include "groupNormPluginCommon.h"
+
+#include <cstdint>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>

 using half = __half;

 static inline int32_t divUp(int32_t m, int32_t n)
 {
     return (m + n - 1) / n;
 }

-void group_norm_nhwc_sum(Group_norm_nhwc_params const& params, cudaStream_t stream);
+void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream);

-void group_norm_nhwc_scale(Group_norm_nhwc_params const& params, cudaStream_t stream);
+void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream);

 #endif // TRT_GROUPNORM_KERNEL_H
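Taken together, the two launchers implement group norm in two passes over an NHWC tensor: groupNormNHWCSum accumulates per-group sum and sum-of-squares into params.redBuffer (atomicAdd across blocks), then groupNormNHWCScale recomputes mean = sum * invHWC and var = sumSq * invHWC - mean * mean, and applies y = (x - mean) * rsqrt(var) * gamma + beta, optionally followed by Swish (x * sigmoid(x)). The sketch below is a hypothetical host-side float reference — not part of the commit, and assuming invHWC = 1 / (hw * cPerGroup), as the variance formula implies — that could be used to validate kernel output:

    #include <cmath>
    #include <cstdint>

    // Hypothetical reference for NHWC group norm with optional Swish.
    // x, y: n * hw * c floats in NHWC order; gamma, beta: c floats each.
    // Mirrors the kernels' math, assuming c % groups == 0.
    void groupNormNHWCRef(float const* x, float* y, float const* gamma, float const* beta,
        int32_t n, int32_t hw, int32_t c, int32_t groups, bool withSwish)
    {
        int32_t const cPerGroup = c / groups;
        float const invHWC = 1.F / (static_cast<float>(hw) * static_cast<float>(cPerGroup));
        for (int32_t ni = 0; ni < n; ++ni)
        {
            for (int32_t gi = 0; gi < groups; ++gi)
            {
                // Pass 1: sum and sum of squares over the group's hw * cPerGroup values.
                float sum = 0.F;
                float sumSq = 0.F;
                for (int32_t hwi = 0; hwi < hw; ++hwi)
                {
                    for (int32_t cj = 0; cj < cPerGroup; ++cj)
                    {
                        int64_t const idx
                            = (static_cast<int64_t>(ni) * hw + hwi) * c + gi * cPerGroup + cj;
                        sum += x[idx];
                        sumSq += x[idx] * x[idx];
                    }
                }
                // Pass 2: normalize, apply gamma/beta, optionally Swish.
                float const mean = sum * invHWC;
                float const var = sumSq * invHWC - mean * mean;
                float const invStdDev = var <= 0.F ? 1.F : 1.F / std::sqrt(var);
                for (int32_t hwi = 0; hwi < hw; ++hwi)
                {
                    for (int32_t cj = 0; cj < cPerGroup; ++cj)
                    {
                        int32_t const ci = gi * cPerGroup + cj;
                        int64_t const idx = (static_cast<int64_t>(ni) * hw + hwi) * c + ci;
                        float const v = (x[idx] - mean) * invStdDev * gamma[ci] + beta[ci];
                        y[idx] = withSwish ? v / (1.F + std::exp(-v)) : v;
                    }
                }
            }
        }
    }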