Skip to content

Commit

Permalink
use SOA in radix sort (pytorch#756)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#756

Using a structure of arrays (SoA) instead of an array of structures (AoS) improves spatial locality and enables vectorization.

Reviewed By: jianyuh

Differential Revision: D32350500

fbshipit-source-id: 64be138503b3cf07eca6f4f4a215e04a9ecda9ce
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Nov 12, 2021
1 parent eb23a75 commit 8db978a
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 65 deletions.
63 changes: 37 additions & 26 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,14 @@ void batched_csr2csc(
batched_csr_offsets[table_to_feature_offset[t] * B];
int num_non_empty_segments = 0;
if (!batched_csc.weights) {
Radix_Sort_Pair<int>* tmpBuf =
(Radix_Sort_Pair<int>*)fbgemm::fbgemmAlignedAlloc(
64, (NS) * sizeof(Radix_Sort_Pair<int>));
Radix_Sort_Pair<int>* tmpBuf1 =
(Radix_Sort_Pair<int>*)fbgemm::fbgemmAlignedAlloc(
64, (NS) * sizeof(Radix_Sort_Pair<int>));
int* tmpBufKeys = reinterpret_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, 4 * NS * sizeof(int)));
int* tmpBufValues = reinterpret_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
int* tmpBuf1Keys = reinterpret_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
int* tmpBuf1Values = reinterpret_cast<int*>(
fbgemm::fbgemmAlignedAlloc(64, NS * sizeof(int)));
const auto FBo = batched_csr_offsets[table_to_feature_offset[t] * B];
for (int feature = table_to_feature_offset[t];
feature < table_to_feature_offset[t + 1];
Expand All @@ -376,14 +378,22 @@ void batched_csr2csc(
int64_t pool_begin = batched_csr_offsets[FBb];
int64_t pool_end = batched_csr_offsets[FBb + 1];
for (int64_t p = pool_begin; p < pool_end; ++p) {
tmpBuf[p - FBo].first = batched_csr_indices[p];
tmpBuf[p - FBo].second = FBs + b;
tmpBufKeys[p - FBo] = batched_csr_indices[p];
tmpBufValues[p - FBo] = FBs + b;
}
}
}

Radix_Sort_Pair<int>* sorted_col_row_index_pairs =
radix_sort_parallel<int>(&tmpBuf[0], &tmpBuf1[0], NS, num_embeddings);
int* sorted_col_row_index_keys = nullptr;
int* sorted_col_row_index_values = nullptr;
std::tie(sorted_col_row_index_keys, sorted_col_row_index_values) =
radix_sort_parallel<int>(
tmpBufKeys,
tmpBufValues,
tmpBuf1Keys,
tmpBuf1Values,
NS,
num_embeddings);

int max_thds = omp_get_max_threads();
int num_uniq[max_thds][64];
Expand All @@ -396,9 +406,10 @@ void batched_csr2csc(
num_uniq[tid][0] = 0;
#pragma omp for schedule(static)
for (int i = 1; i < NS; i++) {
if (sorted_col_row_index_pairs[i].first !=
sorted_col_row_index_pairs[i - 1].first)
if (sorted_col_row_index_keys[i] !=
sorted_col_row_index_keys[i - 1]) {
num_uniq[tid][0]++;
}
}
}
num_uniq[0][0] += 1;
Expand All @@ -415,10 +426,9 @@ void batched_csr2csc(
(int64_t*)fbgemm::fbgemmAlignedAlloc(64, nnz * sizeof(int64_t));

batched_csc.column_segment_ptr[0] = 0;
batched_csc.row_indices[0] = sorted_col_row_index_pairs[0].second % B;
batched_csc.column_segment_indices[0] = sorted_col_row_index_pairs[0].first;
batched_csc.column_segment_ids[0] =
sorted_col_row_index_pairs[0].second / B;
batched_csc.row_indices[0] = sorted_col_row_index_values[0] % B;
batched_csc.column_segment_indices[0] = sorted_col_row_index_keys[0];
batched_csc.column_segment_ids[0] = sorted_col_row_index_values[0] / B;
#pragma omp parallel
{
int tid = omp_get_thread_num();
Expand All @@ -436,20 +446,19 @@ void batched_csr2csc(
#endif

#pragma omp for schedule(static)
for (int i = 1; i < NS; i++) {
for (int i = 1; i < NS; ++i) {
#ifdef FBCODE_CAFFE2
batched_csc.column_segment_ids[i] =
sorted_col_row_index_pairs[i].second / divisor;
batched_csc.row_indices[i] = sorted_col_row_index_pairs[i].second -
sorted_col_row_index_values[i] / divisor;
batched_csc.row_indices[i] = sorted_col_row_index_values[i] -
batched_csc.column_segment_ids[i] * B;
#else
batched_csc.row_indices[i] = sorted_col_row_index_pairs[i].second % B;
batched_csc.row_indices[i] = sorted_col_row_index_values[i] % B;
batched_csc.column_segment_ids[i] =
sorted_col_row_index_pairs[i].second / B;
sorted_col_row_index_values[i] / B;
#endif
if (sorted_col_row_index_pairs[i].first !=
sorted_col_row_index_pairs[i - 1].first) {
*tstart = sorted_col_row_index_pairs[i].first;
if (sorted_col_row_index_keys[i] != sorted_col_row_index_keys[i - 1]) {
*tstart = sorted_col_row_index_keys[i];
*t_offs = i;
tstart++;
t_offs++;
Expand All @@ -464,8 +473,10 @@ void batched_csr2csc(
batched_csc.table_ptr[t + 1] = batched_csc.table_ptr[t] + U;
batched_csc.column_segment_ptr[U] = NS;
column_ptr_curr += NS;
fbgemm::fbgemmAlignedFree(tmpBuf);
fbgemm::fbgemmAlignedFree(tmpBuf1);
fbgemm::fbgemmAlignedFree(tmpBufKeys);
fbgemm::fbgemmAlignedFree(tmpBufValues);
fbgemm::fbgemmAlignedFree(tmpBuf1Keys);
fbgemm::fbgemmAlignedFree(tmpBuf1Values);
} else {
// !batched_csc.weights.empty()
#ifdef FBCODE_CAFFE2
Expand Down
92 changes: 53 additions & 39 deletions fbgemm_gpu/include/fbgemm_gpu/cpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@
#include <cstdint>
#include <utility>

template <typename T>
using Radix_Sort_Pair = std::pair<T, T>;
// histogram size per thread
const int RDX_HIST_SIZE = 256;
constexpr int RDX_HIST_SIZE = 256;
template <typename T>
Radix_Sort_Pair<T>* radix_sort_parallel(
Radix_Sort_Pair<T>* inp_buf,
Radix_Sort_Pair<T>* tmp_buf,
std::pair<T*, T*> radix_sort_parallel(
T* inp_key_buf,
T* inp_value_buf,
T* tmp_key_buf,
T* tmp_value_buf,
int64_t elements_count,
int64_t max_value) {
int maxthreads = omp_get_max_threads();
alignas(64) int histogram[RDX_HIST_SIZE * maxthreads],
histogram_ps[RDX_HIST_SIZE * maxthreads + 1];
if (max_value == 0)
return inp_buf;
if (max_value == 0) {
return std::make_pair(inp_key_buf, inp_value_buf);
}
int num_bits = sizeof(T) * 8 - __builtin_clz(max_value);
unsigned int num_passes = (num_bits + 7) / 8;

Expand All @@ -37,8 +38,10 @@ Radix_Sort_Pair<T>* radix_sort_parallel(
int* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];
int elements_count_4 = elements_count / 4 * 4;
Radix_Sort_Pair<T>* input = inp_buf;
Radix_Sort_Pair<T>* output = tmp_buf;
T* input_keys = inp_key_buf;
T* input_values = inp_value_buf;
T* output_keys = tmp_key_buf;
T* output_values = tmp_value_buf;

for (unsigned int pass = 0; pass < num_passes; pass++) {
// Step 1: compute histogram
Expand All @@ -47,20 +50,20 @@ Radix_Sort_Pair<T>* radix_sort_parallel(

#pragma omp for schedule(static)
for (int64_t i = 0; i < elements_count_4; i += 4) {
T val_1 = input[i].first;
T val_2 = input[i + 1].first;
T val_3 = input[i + 2].first;
T val_4 = input[i + 3].first;
T key_1 = input_keys[i];
T key_2 = input_keys[i + 1];
T key_3 = input_keys[i + 2];
T key_4 = input_keys[i + 3];

local_histogram[(val_1 >> (pass * 8)) & 0xFF]++;
local_histogram[(val_2 >> (pass * 8)) & 0xFF]++;
local_histogram[(val_3 >> (pass * 8)) & 0xFF]++;
local_histogram[(val_4 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_1 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_2 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_3 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_4 >> (pass * 8)) & 0xFF]++;
}
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; i++) {
T val = input[i].first;
local_histogram[(val >> (pass * 8)) & 0xFF]++;
T key = input_keys[i];
local_histogram[(key >> (pass * 8)) & 0xFF]++;
}
}
#pragma omp barrier
Expand All @@ -82,37 +85,48 @@ Radix_Sort_Pair<T>* radix_sort_parallel(
// Step 3: scatter
#pragma omp for schedule(static)
for (int64_t i = 0; i < elements_count_4; i += 4) {
T val_1 = input[i].first;
T val_2 = input[i + 1].first;
T val_3 = input[i + 2].first;
T val_4 = input[i + 3].first;
T bin_1 = (val_1 >> (pass * 8)) & 0xFF;
T bin_2 = (val_2 >> (pass * 8)) & 0xFF;
T bin_3 = (val_3 >> (pass * 8)) & 0xFF;
T bin_4 = (val_4 >> (pass * 8)) & 0xFF;
T key_1 = input_keys[i];
T key_2 = input_keys[i + 1];
T key_3 = input_keys[i + 2];
T key_4 = input_keys[i + 3];
T bin_1 = (key_1 >> (pass * 8)) & 0xFF;
T bin_2 = (key_2 >> (pass * 8)) & 0xFF;
T bin_3 = (key_3 >> (pass * 8)) & 0xFF;
T bin_4 = (key_4 >> (pass * 8)) & 0xFF;
int pos;
pos = local_histogram_ps[bin_1]++;
output[pos] = input[i];
output_keys[pos] = key_1;
output_values[pos] = input_values[i];
pos = local_histogram_ps[bin_2]++;
output[pos] = input[i + 1];
output_keys[pos] = key_2;
output_values[pos] = input_values[i + 1];
pos = local_histogram_ps[bin_3]++;
output[pos] = input[i + 2];
output_keys[pos] = key_3;
output_values[pos] = input_values[i + 2];
pos = local_histogram_ps[bin_4]++;
output[pos] = input[i + 3];
output_keys[pos] = key_4;
output_values[pos] = input_values[i + 3];
}
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; i++) {
T val = input[i].first;
int pos = local_histogram_ps[(val >> (pass * 8)) & 0xFF]++;
output[pos] = input[i];
T key = input_keys[i];
int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;
output_keys[pos] = key;
output_values[pos] = input_values[i];
}
}

Radix_Sort_Pair<T>* temp = input;
input = output;
output = temp;
T* temp = input_keys;
input_keys = output_keys;
output_keys = temp;

temp = input_values;
input_values = output_values;
output_values = temp;
#pragma omp barrier
}
}
return (num_passes % 2 == 0 ? inp_buf : tmp_buf);
return (
num_passes % 2 == 0 ? std::make_pair(inp_key_buf, inp_value_buf)
: std::make_pair(tmp_key_buf, tmp_value_buf));
}

0 comments on commit 8db978a

Please sign in to comment.