Commit
optimize exact row-wise sparse adagrad on cpu (pytorch#589)
Summary:
Pull Request resolved: pytorch#589

Use the JIT'ed FBGEMM embedding SpMDM kernel for gradient coalescing, and also use the JIT'ed row-wise sparse Adagrad kernel for the optimizer update.

Reviewed By: jiyuanzFB

Differential Revision: D27562367

fbshipit-source-id: 7c0b6f4f2628e19706772e786c8f3ac253eac21b
jspark1105 authored and facebook-github-bot committed Apr 6, 2021
1 parent 1dfb0c3 commit a98ad84
Showing 3 changed files with 265 additions and 73 deletions.
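The summary above describes two JIT'ed FBGEMM kernels: an embedding SpMDM kernel that coalesces gradients per embedding row, and a row-wise sparse Adagrad kernel that applies the update. The standalone sketch below (not part of this commit) mirrors the call sites in embedding_backward_split_cpu_template.cpp to show how the two kernels could be generated and chained; the sizes, indices, hyperparameters, and the float-only weight type are illustrative assumptions.

#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmEmbedding.h"

int main() {
  // Illustrative sizes (assumptions, not taken from the diff).
  const int D = 64;           // embedding dimension
  const int B = 8;            // batch size
  const int grad_stride = D;  // row stride of grad_output
  const std::int64_t table_rows = 1000;

  // Two column segments: embedding rows 7 and 42, each referenced by the
  // listed grad_output samples (this mimics batched_cscs[t] in the diff).
  std::vector<std::int32_t> row_indices = {0, 3, 1, 5};
  std::vector<std::int32_t> col_segment_ptr = {0, 2, 4};
  std::vector<std::int64_t> col_segment_indices = {7, 42};
  std::vector<float> grad_output(B * grad_stride, 0.1f);

  // Step 1: gradient coalescing. The SpMDM kernel sums the grad_output rows
  // belonging to each segment into one D-wide row.
  auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
      float, /*IndexType=*/std::int32_t, /*OffsetType=*/std::int32_t>(
      D,
      /*has_weight=*/false,
      /*normalize_by_lengths=*/false,
      /*prefetch=*/16,
      /*is_weight_positional=*/false,
      /*use_offsets=*/true,
      /*output_stride=*/-1,
      /*input_stride=*/grad_stride);
  std::vector<float> coalesced_grads(col_segment_indices.size() * D);
  bool ok = spmdm_kernel(
      /*output_size=*/col_segment_indices.size(),
      /*index_size=*/row_indices.size(),
      /*data_size=*/B,
      grad_output.data(),
      row_indices.data(),
      col_segment_ptr.data(),
      /*weights=*/nullptr,
      coalesced_grads.data());

  // Step 2: row-wise sparse Adagrad applied only to the touched rows.
  std::vector<float> weights(table_rows * D, 0.0f);
  std::vector<float> momentum(table_rows, 0.0f);  // one value per row (rowwise)
  auto rowwise_adagrad_kernel =
      fbgemm::GenerateSparseAdaGrad</*IndexType=*/std::int64_t>(
          D, /*rowwise=*/true);
  int processed = rowwise_adagrad_kernel(
      /*num_rows=*/col_segment_indices.size(),
      /*param_size=*/table_rows * D,
      weights.data(),
      coalesced_grads.data(),
      momentum.data(),
      col_segment_indices.data(),
      /*epsilon=*/1e-8f,
      /*lr=*/-0.01f,  // the diff passes -learning_rate for gradient descent
      /*weight_decay=*/0.0f,
      /*counter=*/nullptr,
      /*counter_halflife=*/0);
  return (ok && processed == static_cast<int>(col_segment_indices.size()))
      ? 0
      : 1;
}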
180 changes: 147 additions & 33 deletions fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp
@@ -12,16 +12,31 @@
#include <ATen/AccumulateType.h>

#include "codegen/embedding_forward_split_cpu.h"
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"

using namespace at;

namespace internal {
template <typename T>
struct half2float16 {
using type = T;
};

template <>
struct half2float16<at::Half> {
using type = fbgemm::float16;
};
} // namespace internal

namespace {
template <typename scalar_t>
void split_embedding_backward_exact_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
const TensorAccessor<int64_t, 1> weights_offsets_data,
const TensorAccessor<int, 1> D_offsets_data,
Tensor hash_size_cumsum,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
@@ -37,50 +37,146 @@ void split_embedding_backward_exact_cpu_kernel(
{% endif %}
{{ args.split_cpu_kernel_args | join(", ") }}) {
using grad_t = acc_type<scalar_t, true>;
::internal::BatchedHyperCompressedSparseColumn batched_csc;
::internal::batched_csr2csc(
batched_csc,
num_tables,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
indice_weights.defined()
? indice_weights.accessor<grad_t, 1>()
: TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
pooling_mode,
table_to_feature_offset);
std::vector<int>& table_ptr = batched_csc.table_ptr;
std::vector<int>& column_ptr = batched_csc.column_ptr;

auto grad_output_data = grad_output.accessor<grad_t, 2>();
// const auto grad_output_accessor = grad_output.accessor<grad_t, 2>();
const grad_t* grad_output_data = grad_output.data_ptr<grad_t>();
auto host_weights_data = host_weights.accessor<scalar_t, 1>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();

const bool has_weights = indice_weights.defined();
auto grad_stride = grad_output.size(1);

const bool has_weights = !batched_csc.weights.empty();
std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs(
num_tables);

at::parallel_for(0, num_tables, 0, [&](int64_t t_begin, int64_t t_end) {
for (int t = t_begin; t < t_end; ++t) {
int feature_begin = table_to_feature_offset[t];

::internal::batched_csr2csc(
batched_cscs[t],
1,
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
indice_weights.defined()
? indice_weights.accessor<grad_t, 1>()
: TensorAccessor<grad_t, 1>(nullptr, nullptr, nullptr),
pooling_mode,
table_to_feature_offset + t);
}
});

for (int t = 0; t < num_tables; ++t) {
int feature_begin = table_to_feature_offset[t];

int c_begin = batched_cscs[t].table_ptr[0];
int c_end = batched_cscs[t].table_ptr[1];
std::vector<int>& col_segment_ptr = batched_cscs[t].column_segment_ptr;
std::vector<int64_t>& col_segment_indices =
batched_cscs[t].column_segment_indices;

int64_t hash_size;
int t_temp = feature_begin + 1;
do {
hash_size =
hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[feature_begin];
++t_temp;
} while (hash_size == 0);

const auto D_begin = D_offsets_data[feature_begin];
const auto D =
D_offsets_data[feature_begin + 1] - D_offsets_data[feature_begin];
const auto table_begin = weights_offsets_data[feature_begin];
grad_t grad_buffer[D];
for (int c = table_ptr[t]; c < table_ptr[t + 1]; ++c) {
memset(grad_buffer, 0, D * sizeof(grad_t));
int idx = batched_csc.column_indices[c];
const int64_t embedding_begin = table_begin + idx * D;
for (int r = column_ptr[c]; r < column_ptr[c + 1]; ++r) {
int f_times_b = batched_csc.row_indices[r];
int feature = f_times_b / B;
int b = f_times_b % B;
int D_offset = D_begin + (feature - feature_begin) * D;
for (int64_t d = 0; d < D; ++d) {
grad_buffer[d] += has_weights
? grad_output_data[b][D_offset + d] * batched_csc.weights[r]
: grad_output_data[b][D_offset + d];

{% if optimizer == "rowwise_adagrad" %}
constexpr bool use_fbgemm = std::is_same<scalar_t, float>::value;
// || std::is_same<scalar_t, at::Half>::value;
if (use_fbgemm &&
table_to_feature_offset[t + 1] == table_to_feature_offset[t] + 1) {
// fbgemm handles the common case of no shared table
using fbgemm_weight_t = typename ::internal::half2float16<scalar_t>::type;
auto spmdm_kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
fbgemm_weight_t,
/*IndexType=*/int32_t,
/*OffsetType=*/int32_t>(
D,
!batched_cscs[t].weights.empty(),
/*normalize_by_lengths=*/false,
/*prefetch=*/16,
/*is_weight_positional=*/false,
/*use_offsets=*/true,
/*output_stride=*/-1,
/*input_stride=*/grad_stride);
auto rowwise_adagrad_kernel =
fbgemm::GenerateSparseAdaGrad</*IndexType=*/int64_t>(
D, /*rowwise=*/true);

constexpr int C_BLOCK = 64;
at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) {
grad_t grad_blocked_buffer[C_BLOCK * D];
for (int64_t c = c0; c < c1; c += C_BLOCK) {
const int* offsets_begin_ptr = col_segment_ptr.data() + c;
int64_t c_block_end = std::min(c + C_BLOCK, c1);
bool success = spmdm_kernel(
c_block_end - c,
col_segment_ptr[c_block_end] - *offsets_begin_ptr,
B,
reinterpret_cast<const fbgemm_weight_t*>(
grad_output_data + D_begin),
batched_cscs[t].row_indices.data() + *offsets_begin_ptr,
offsets_begin_ptr,
batched_cscs[t].weights.empty()
? nullptr
: batched_cscs[t].weights.data() + *offsets_begin_ptr,
reinterpret_cast<float*>(grad_blocked_buffer));
// TODO: more friendly error msg.
TORCH_CHECK(success);
int num_rows_processed = rowwise_adagrad_kernel(
c_block_end - c,
hash_size * D,
reinterpret_cast<float*>(&host_weights_data[table_begin]),
reinterpret_cast<const float*>(grad_blocked_buffer),
reinterpret_cast<float*>(
&momentum1_host[momentum1_offsets_data[feature_begin]]),
col_segment_indices.data() + c,
eps,
-learning_rate,
/*weight_decay=*/0,
/*counter=*/nullptr,
/*counter_halflife=*/0);
// TODO: more friendly error msg.
TORCH_CHECK(num_rows_processed == c_block_end - c);
} // for each c
}); // parallel for
} else
{% endif %}
{
// no fbgemm
// TODO: to parallelize, we should easily identify segments belonging to
// the same column.
grad_t grad_buffer[D];
for (int c = c_begin; c < c_end; ++c) {
int64_t idx = col_segment_indices[c];
if (c == c_begin || col_segment_indices[c - 1] != idx) {
memset(grad_buffer, 0, D * sizeof(grad_t));
}
}
{{ split_weight_update_cpu }}
}
const int64_t embedding_begin = table_begin + idx * D;
int D_offset = D_begin + batched_cscs[t].column_segment_ids[c] * D;
for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
int b = batched_cscs[t].row_indices[r];
for (int64_t d = 0; d < D; ++d) {
grad_buffer[d] += !batched_cscs[t].weights.empty()
? grad_output_data[b * grad_stride + D_offset + d] *
batched_cscs[t].weights[r]
: grad_output_data[b * grad_stride + D_offset + d];
}
}
if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
{{ split_weight_update_cpu }}
}
} // for each c
} // no fbgemm
} // for each table
}

@@ -200,13 +311,16 @@ void split_embedding_backward_exact_cpu_dense_kernel(
const auto momentum2_offsets_data = momentum2_offsets.accessor<int64_t, 1>();
{% endif %}

grad_output = grad_output.contiguous();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&]() {
split_embedding_backward_exact_cpu_kernel<scalar_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
148 changes: 111 additions & 37 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -346,53 +346,127 @@ void batched_csr2csc(
const int* table_to_feature_offset) {
batched_csc.num_tables = num_tables;
batched_csc.table_ptr.resize(num_tables + 1);
int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B];
int64_t nnz = batched_csr_offsets[table_to_feature_offset[num_tables] * B] -
batched_csr_offsets[table_to_feature_offset[0] * B];
batched_csc.row_indices.resize(nnz);
bool has_weights = batched_csr_weights.data() != nullptr;
if (has_weights || pooling_mode == MEAN) {
batched_csc.weights.resize(nnz);
}

batched_csc.table_ptr.push_back(0);
batched_csc.column_ptr.push_back(0);
batched_csc.table_ptr[0] = 0;
batched_csc.column_segment_ptr.push_back(0);
int column_ptr_curr = 0;
for (int t = 0; t < num_tables; ++t) {
std::unordered_map<int64_t, std::vector<std::pair<int, scalar_t>>>
non_empty_columns;
for (int feature = table_to_feature_offset[t];
feature < table_to_feature_offset[t + 1];
++feature) {
for (int b = 0; b < B; ++b) {
int64_t pool_begin = batched_csr_offsets[feature * B + b];
int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
int64_t L = pool_end - pool_begin;
// MEAN pooling will not work with indice_weights!
double scale_factor =
(pooling_mode == MEAN && !has_weights && L > 0) ? 1.0 / L : 1.0;
for (int64_t p = pool_begin; p < pool_end; ++p) {
non_empty_columns[batched_csr_indices[p]].emplace_back(
feature * B + b,
scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
int num_non_empty_segments = 0;
if (batched_csc.weights.empty()) {
std::unordered_map<int64_t, std::vector<std::vector<int>>>
non_empty_columns;
int f_begin = table_to_feature_offset[t];
int f_end = table_to_feature_offset[t + 1];

for (int feature = f_begin; feature < f_end; ++feature) {
for (int b = 0; b < B; ++b) {
int64_t pool_begin = batched_csr_offsets[feature * B + b];
int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
for (int64_t p = pool_begin; p < pool_end; ++p) {
auto itr = non_empty_columns.find(batched_csr_indices[p]);
if (itr == non_empty_columns.end()) {
itr = non_empty_columns
.emplace(
batched_csr_indices[p],
std::vector<std::vector<int>>(f_end - f_begin))
.first;
}
if (itr->second[feature - f_begin].empty()) {
++num_non_empty_segments;
}
itr->second[feature - f_begin].push_back(b);
}
}
}
} // for each feature

batched_csc.table_ptr[t + 1] =
batched_csc.table_ptr[t] + non_empty_columns.size();
batched_csc.column_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
batched_csc.column_indices.reserve(batched_csc.table_ptr[t + 1]);
for (auto const& column : non_empty_columns) {
batched_csc.column_ptr.push_back(column_ptr_curr + column.second.size());
batched_csc.column_indices.push_back(column.first);

for (auto const& non_zero : column.second) {
batched_csc.row_indices[column_ptr_curr] = non_zero.first;
if (!batched_csc.weights.empty()) {
batched_csc.weights[column_ptr_curr] = non_zero.second;
} // for each feature

batched_csc.table_ptr[t + 1] =
batched_csc.table_ptr[t] + num_non_empty_segments;
batched_csc.column_segment_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
batched_csc.column_segment_indices.reserve(batched_csc.table_ptr[t + 1]);
batched_csc.column_segment_ids.reserve(batched_csc.table_ptr[t + 1]);
for (auto const& column : non_empty_columns) {
int feature = f_begin;
for (auto const& column_segment : column.second) {
if (!column_segment.empty()) {
batched_csc.column_segment_ptr.push_back(
column_ptr_curr + column_segment.size());
batched_csc.column_segment_indices.push_back(column.first);
batched_csc.column_segment_ids.push_back(feature - f_begin);
memcpy(
&batched_csc.row_indices[column_ptr_curr],
column_segment.data(),
column_segment.size() * sizeof(int));
column_ptr_curr += column_segment.size();
}
++feature;
} // for each column segment
} // for each column
} else {
// !batched_csc.weights.empty()
std::unordered_map<
int64_t,
std::vector<std::vector<std::pair<int, scalar_t>>>>
non_empty_columns;
int f_begin = table_to_feature_offset[t];
int f_end = table_to_feature_offset[t + 1];
for (int feature = f_begin; feature < f_end; ++feature) {
for (int b = 0; b < B; ++b) {
int64_t pool_begin = batched_csr_offsets[feature * B + b];
int64_t pool_end = batched_csr_offsets[feature * B + b + 1];
int64_t L = pool_end - pool_begin;
// MEAN pooling will not work with indice_weights!
double scale_factor =
(pooling_mode == MEAN && !has_weights && L > 0) ? 1.0 / L : 1.0;
for (int64_t p = pool_begin; p < pool_end; ++p) {
auto itr = non_empty_columns.find(batched_csr_indices[p]);
if (itr == non_empty_columns.end()) {
itr = non_empty_columns
.emplace(
batched_csr_indices[p],
std::vector<std::vector<std::pair<int, scalar_t>>>(
f_end - f_begin))
.first;
}
if (itr->second[feature - f_begin].empty()) {
++num_non_empty_segments;
}
itr->second[feature - f_begin].emplace_back(
b,
scale_factor * (has_weights ? batched_csr_weights[p] : 1.0f));
}
}
++column_ptr_curr;
}
}
} // for each feature

batched_csc.table_ptr[t + 1] =
batched_csc.table_ptr[t] + num_non_empty_segments;
batched_csc.column_segment_ptr.reserve(batched_csc.table_ptr[t + 1] + 1);
batched_csc.column_segment_indices.reserve(batched_csc.table_ptr[t + 1]);
batched_csc.column_segment_ids.reserve(batched_csc.table_ptr[t + 1]);
for (auto const& column : non_empty_columns) {
int feature = f_begin;
for (auto const& column_segment : column.second) {
if (!column_segment.empty()) {
batched_csc.column_segment_ptr.push_back(
column_ptr_curr + column_segment.size());
batched_csc.column_segment_indices.push_back(column.first);
batched_csc.column_segment_ids.push_back(feature - f_begin);
for (auto const& non_zero : column_segment) {
batched_csc.row_indices[column_ptr_curr] = non_zero.first;
batched_csc.weights[column_ptr_curr] = non_zero.second;
++column_ptr_curr;
}
}
++feature;
} // for each column segment
} // for each column
} // !batched_csc.weights.empty()
} // for each matrix (table)

assert(column_ptr_curr == nnz);
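
As a reading aid, here is a simplified, hypothetical sketch (not part of this commit) of the CSR-to-column-segment transposition that batched_csr2csc performs for one table in the unweighted case: occurrences of each embedding row are grouped per feature into segments, which is what lets the backward pass coalesce gradients before the Adagrad update. Names, types, and the use of std::map are illustrative; the real code uses different containers and handles weights and MEAN pooling as well.

#include <cstdint>
#include <map>
#include <vector>

// Output layout, loosely modeled on BatchedHyperCompressedSparseColumn.
struct ColumnSegments {
  std::vector<int> segment_ptr{0};             // prefix sum of segment sizes
  std::vector<std::int64_t> segment_indices;   // embedding row per segment
  std::vector<int> segment_ids;                // feature id (within table) per segment
  std::vector<int> row_indices;                // flattened sample ids
};

ColumnSegments csr_to_column_segments(
    const std::vector<std::int64_t>& offsets,  // CSR offsets, length num_features * B + 1
    const std::vector<std::int64_t>& indices,  // CSR indices (embedding rows)
    int num_features,
    int B) {
  // embedding row -> feature -> sample ids that referenced it
  std::map<std::int64_t, std::map<int, std::vector<int>>> columns;
  for (int f = 0; f < num_features; ++f) {
    for (int b = 0; b < B; ++b) {
      for (std::int64_t p = offsets[f * B + b]; p < offsets[f * B + b + 1]; ++p) {
        columns[indices[p]][f].push_back(b);
      }
    }
  }
  // Emit one segment per (embedding row, feature) pair that is non-empty.
  ColumnSegments csc;
  for (const auto& [row, per_feature] : columns) {
    for (const auto& [f, samples] : per_feature) {
      csc.segment_indices.push_back(row);
      csc.segment_ids.push_back(f);
      csc.row_indices.insert(
          csc.row_indices.end(), samples.begin(), samples.end());
      csc.segment_ptr.push_back(static_cast<int>(csc.row_indices.size()));
    }
  }
  return csc;
}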