Move permute_pooled_embs_kernel into fbgemm_gpu (pytorch#569)
Summary:
Pull Request resolved: pytorch#569

As title

Reviewed By: evhunter

Differential Revision: D27222442

fbshipit-source-id: 48d57d53a97277b548322e9ca003917e5bf260c5
jianyuh authored and facebook-github-bot committed Mar 23, 2021
1 parent f3abcb8 commit 3657619
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/layout_transform_ops.cuh
@@ -56,3 +56,57 @@ __global__ void recat_copy_async_kernel(
}
}
}

// Kernel for the permute pooled embedding op.
// Each warp copies the D elements of one (batch row, table) pair.
template <typename scalar_t>
__global__ void permute_pooled_embs_kernel(
    const scalar_t* __restrict__ go, // 2D, B x sum(mixed_D)
    const int64_t* __restrict__ offset_dim_list, // 1D, T
    const int64_t* __restrict__ permute_list, // 1D, T
    const int64_t* __restrict__ inv_offset_dim_list, // 1D, T+1
    scalar_t* __restrict__ sgo, // 2D, B x sum(mixed_D)
    const int64_t B,
    const int64_t T,
    const int64_t dim_sum) {
  // One warp per (table, batch row) pair: warps within a block tile the
  // table dimension; the batch dimension is split across gridDim.y/z.
  int32_t t = blockIdx.x * (blockDim.x / warpSize) + threadIdx.x / warpSize;
  int32_t b = blockIdx.y + gridDim.y * blockIdx.z;
  int32_t idx = threadIdx.x % warpSize; // lane id within the warp
  int32_t blk = warpSize; // per-lane stride for the copy loops
  if (b >= B) {
    return;
  }
  if (t >= T) {
    return;
  }
  // Output slot t is filled from input table permute_list[t]; look up that
  // table's column range in the input row.
  int64_t permute_idx = permute_list[t];
  int64_t input_dim_start = offset_dim_list[permute_idx];
  int64_t input_dim_end = offset_dim_list[permute_idx + 1];
  int64_t cur_dim = input_dim_end - input_dim_start;
  if (idx >= cur_dim) {
    return;
  }
  // Advance both pointers to row b of the B x sum(mixed_D) tensors.
  go += b * dim_sum;
  sgo += b * dim_sum;
  // Column at which output slot t starts in the permuted row.
  int64_t sgo_offset = inv_offset_dim_list[t];
  // Use the vectorized path only if both pointers are Vec4T-aligned.
  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(&sgo[sgo_offset]) &&
      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
          &go[input_dim_start])) {
    const int32_t vec_size = 4;
    // Round cur_dim down to a multiple of vec_size for the vector loop.
    int32_t loop_end = cur_dim / (vec_size) * (vec_size);
    for (int32_t i = idx * vec_size; i < loop_end; i += blk * vec_size) {
      fbgemm_gpu::Vec4T<scalar_t>::copy(
          &go[input_dim_start + i], &sgo[sgo_offset + i]);
    }
    // Use elementwise access for the trailing, incomplete vector.
    for (int32_t i = loop_end + idx; i < cur_dim; i += blk) {
      sgo[sgo_offset + i] = go[input_dim_start + i];
    }
  } else { // Fallback if not aligned.
    for (int32_t i = idx; i < cur_dim; i += blk) {
      sgo[sgo_offset + i] = go[input_dim_start + i];
    }
  }
}
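
For reference, a minimal host-side launch sketch showing how grid dimensions can be chosen to match the kernel's index math. This is not part of the commit: the launcher name, the kWarpsPerBlock constant, and the 65535 per-dimension grid limit are illustrative assumptions, not fbgemm_gpu's actual host wrapper.

// Minimal launch sketch -- NOT part of this commit. The launcher name,
// kWarpsPerBlock, and the 65535 per-dimension grid limit are assumptions.
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

constexpr int32_t kWarpsPerBlock = 8;

template <typename scalar_t>
void launch_permute_pooled_embs(
    const scalar_t* go,
    const int64_t* offset_dim_list,
    const int64_t* permute_list,
    const int64_t* inv_offset_dim_list,
    scalar_t* sgo,
    int64_t B,
    int64_t T,
    int64_t dim_sum,
    cudaStream_t stream) {
  if (B == 0 || T == 0) {
    return;
  }
  // blockDim.x packs kWarpsPerBlock warps, so gridDim.x covers all T tables.
  // B is split across gridDim.y and gridDim.z to respect the per-dimension
  // grid limit, matching b = blockIdx.y + gridDim.y * blockIdx.z above.
  const int64_t kMaxGridY = 65535;
  const dim3 threads(kWarpsPerBlock * 32);
  const dim3 blocks(
      (T + kWarpsPerBlock - 1) / kWarpsPerBlock,
      std::min<int64_t>(B, kMaxGridY),
      (B + kMaxGridY - 1) / kMaxGridY);
  permute_pooled_embs_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
      go, offset_dim_list, permute_list, inv_offset_dim_list,
      sgo, B, T, dim_sum);
}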
