Skip to content

Commit

Permalink
row-wise sequence embeddings
Browse files Browse the repository at this point in the history
Summary:
Design doc: https://fb.quip.com/6lgwApu6q46w
For row-wise partition, the sequence embeddings pipeline is
```
(T_g, W, B_local, L_bucket x D) same bucket of local batches → permute → (W, T_g, B_local, L_bucket x D) → a2a on bucketized lengths →
(W, T_g, B_local, L_bucket x D) all buckets of local batches → debucketize → (T_g, B_local, L_batch x D)
```
To enable this pipeline, we
- use `torch.index_select` to permute the embeddings.
- produced the permute mapping for bucketize and unbucketize embeddings and its gradients.
- added these metadata in the rw partition workflow.

Differential Revision: D27570196

fbshipit-source-id: b061c0c92b65f710c598d28d441bc287a74a217c
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Apr 14, 2021
1 parent c609f3c commit 0b3f25e
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ __global__ void _block_bucketize_sparse_features_cuda_kernel1(
// (sparse_feature is partitioned continuously along the sparse dimension into
// my_size blocks)
template <
bool sequence,
bool has_weight,
bool bucketize_pos,
typename index_t,
Expand All @@ -178,7 +179,8 @@ __global__ void _block_bucketize_sparse_features_cuda_kernel2(
index_t* __restrict__ new_offsets_data,
index_t* __restrict__ new_indices_data,
scalar_t* __restrict__ new_weights_data,
index_t* __restrict__ new_pos_data) {
index_t* __restrict__ new_pos_data,
index_t* __restrict__ unbucketize_permute_data) {
int32_t b_t_start = (int32_t)blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
for (int b_t = b_t_start; b_t < lengths_size; b_t += stride) {
Expand All @@ -193,6 +195,9 @@ __global__ void _block_bucketize_sparse_features_cuda_kernel2(
index_t pos = new_offsets_data[p * lengths_size + b_t];
new_indices_data[pos] = new_idx;
new_offsets_data[p * lengths_size + b_t]++;
if (sequence) {
unbucketize_permute_data[i] = pos;
}
if (has_weight) {
new_weights_data[pos] = weights_data[i];
}
Expand Down Expand Up @@ -287,15 +292,15 @@ __global__ void permute_lengths_kernel(
}
}

// Kernel for permuting the indices and weights. Used for permutation of table-wise partitioned sequence embeddings

// Kernel for permuting the indices and weights. Used for permutation of sparse
// features.
template <typename index_t, typename scalar_t>
__global__ void permute_embeddings_kernel(
int32_t len,
int32_t T,
int32_t B,
const scalar_t* __restrict__ embeddings,
// bag level permute
const int32_t* __restrict__ permute,
const index_t* __restrict__ input_offsets,
const index_t* __restrict__ output_offsets,
Expand Down

0 comments on commit 0b3f25e

Please sign in to comment.