BatchedHyperCompressedSparseColumn -> HyperCompressedSparseColumn (pytorch#1137)

Summary:
Pull Request resolved: pytorch#1137

Code refactoring reflecting that BatchedHyperCompressedSparseColumn is no longer used in batched form, since we're no longer parallelizing CSR2CSC across tables.

Reviewed By: jiyuanzFB

Differential Revision: D36618687

fbshipit-source-id: 87bb90dc5577e8c3117d32fbdfca25da423bd51a
jspark1105 authored and facebook-github-bot committed May 27, 2022
1 parent 60e438d commit 99a6a5f
Showing 3 changed files with 121 additions and 137 deletions.
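
As a reading aid before the diff: the per-table CSC structure is accessed through the fields below. This is a minimal sketch of the interface implied by those usages, not the actual fbgemm_gpu definition; the field names come from the diff, while the types and comments are assumptions.

// Sketch only: interface implied by the usages in the diff below; the real
// ::internal::HyperCompressedSparseColumn lives in fbgemm_gpu and may differ.
struct HyperCompressedSparseColumnSketch {
  int num_non_zero_columns = 0;           // replaces the old table_ptr[0]/table_ptr[1] range
  int* column_segment_ptr = nullptr;      // per-column segment offsets, length num_non_zero_columns + 1
  int* column_segment_indices = nullptr;  // embedding row id of each non-empty column
  int* column_segment_ids = nullptr;      // feature id within a shared table
  int* row_indices = nullptr;             // sample (batch) index of each nonzero
  float* weights = nullptr;               // optional per-index weights; nullptr when unweighted
};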
40 changes: 19 additions & 21 deletions fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp
@@ -61,8 +61,7 @@ void split_embedding_backward_exact_cpu_kernel(
const bool has_weights = indice_weights.defined();
auto grad_stride = grad_output.size(1);

-std::vector<::internal::BatchedHyperCompressedSparseColumn> batched_cscs(
-    num_tables);
+std::vector<::internal::HyperCompressedSparseColumn> cscs(num_tables);

auto get_hash_size = [&hash_size_cumsum_data](int feature_begin) {
int64_t hash_size;
@@ -83,8 +82,8 @@ void split_embedding_backward_exact_cpu_kernel(
int feature_begin = table_to_feature_offset[t];
int64_t hash_size = get_hash_size(feature_begin);

-::internal::batched_csr2csc(
-    batched_cscs[t],
+::internal::csr2csc(
+    cscs[t],
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
@@ -99,10 +98,9 @@ void split_embedding_backward_exact_cpu_kernel(
for (int t = 0; t < num_tables; ++t) {
int feature_begin = table_to_feature_offset[t];

-int c_begin = batched_cscs[t].table_ptr[0];
-int c_end = batched_cscs[t].table_ptr[1];
-int* col_segment_ptr = batched_cscs[t].column_segment_ptr;
-int* col_segment_indices = batched_cscs[t].column_segment_indices;
+int num_non_zero_columns = cscs[t].num_non_zero_columns;
+int* col_segment_ptr = cscs[t].column_segment_ptr;
+int* col_segment_indices = cscs[t].column_segment_indices;

auto hash_size = get_hash_size(feature_begin);

@@ -125,7 +123,7 @@ void split_embedding_backward_exact_cpu_kernel(
/*IndexType=*/int32_t,
/*OffsetType=*/int32_t>(
D,
-batched_cscs[t].weights != nullptr,
+cscs[t].weights != nullptr,
/*normalize_by_lengths=*/false,
/*prefetch=*/16,
/*is_weight_positional=*/false,
@@ -136,7 +134,7 @@ void split_embedding_backward_exact_cpu_kernel(
fbgemm::GenerateSparseAdaGrad</*IndexType=*/int>(D, /*rowwise=*/true);

constexpr int C_BLOCK = 64;
-at::parallel_for(c_begin, c_end, C_BLOCK, [&](int64_t c0, int64_t c1) {
+at::parallel_for(0, num_non_zero_columns, C_BLOCK, [&](int64_t c0, int64_t c1) {
grad_t grad_blocked_buffer[C_BLOCK * D];
for (int64_t c = c0; c < c1; c += C_BLOCK) {
const int* offsets_begin_ptr = col_segment_ptr + c;
@@ -147,11 +145,11 @@ void split_embedding_backward_exact_cpu_kernel(
B,
reinterpret_cast<const fbgemm_weight_t*>(
grad_output_data + D_begin),
-batched_cscs[t].row_indices + *offsets_begin_ptr,
+cscs[t].row_indices + *offsets_begin_ptr,
offsets_begin_ptr,
-batched_cscs[t].weights == nullptr
+cscs[t].weights == nullptr
    ? nullptr
-    : batched_cscs[t].weights + *offsets_begin_ptr,
+    : cscs[t].weights + *offsets_begin_ptr,
reinterpret_cast<float*>(grad_blocked_buffer));

if (!success) {
@@ -161,7 +159,7 @@ void split_embedding_backward_exact_cpu_kernel(
c,
c_block_end,
col_segment_ptr,
-batched_cscs[t].row_indices,
+cscs[t].row_indices,
hash_size,
/*allow_minus_one=*/false);
}
@@ -193,28 +191,28 @@ void split_embedding_backward_exact_cpu_kernel(
// TODO: to parallelize, we should easily identify segments belong to
// the same column.
at::acc_type<grad_t, true> grad_buffer[D];
-for (int c = c_begin; c < c_end; ++c) {
+for (int c = 0; c < num_non_zero_columns; ++c) {
int64_t idx = col_segment_indices[c];
-if (c == c_begin || col_segment_indices[c - 1] != idx) {
+if (c == 0 || col_segment_indices[c - 1] != idx) {
memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
}
const int64_t embedding_begin = table_begin + idx * D;
for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
int D_offset = D_begin;
if (is_shared_table) {
-D_offset += batched_cscs[t].column_segment_ids[r] * D;
+D_offset += cscs[t].column_segment_ids[r] * D;
}
-int b = batched_cscs[t].row_indices[r];
+int b = cscs[t].row_indices[r];
for (int64_t d = 0; d < D; ++d) {
-if (batched_cscs[t].weights != nullptr) {
+if (cscs[t].weights != nullptr) {
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
-    batched_cscs[t].weights[r];
+    cscs[t].weights[r];
} else {
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d];
}
}
}
-if (c == c_end - 1 || col_segment_indices[c + 1] != idx) {
+if (c == num_non_zero_columns - 1 || col_segment_indices[c + 1] != idx) {
{{ split_weight_update_cpu }}
}
} // for each c
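
For readers unfamiliar with the CSR2CSC step named in the summary: conceptually, csr2csc regroups the batch's (sample, embedding row) pairs from the CSR layout of offsets/indices into per-column segments, so the backward pass above can accumulate gradients one embedding row at a time. Below is a self-contained conceptual sketch of that regrouping under assumed names and types; it is not the fbgemm_gpu ::internal::csr2csc implementation.

// Conceptual CSR -> CSC regrouping, compressed to non-empty columns only.
// NOT the fbgemm_gpu implementation; names, types and layout are assumed.
#include <cstdint>
#include <map>
#include <vector>

struct CscSketch {
  int num_non_zero_columns = 0;
  std::vector<int> column_segment_ptr;          // length num_non_zero_columns + 1
  std::vector<int64_t> column_segment_indices;  // embedding row of each kept column
  std::vector<int> row_indices;                 // sample index of each nonzero
};

CscSketch csr2csc_sketch(
    int B,                                   // batch size (number of CSR rows)
    const std::vector<int64_t>& offsets,     // length B + 1
    const std::vector<int64_t>& indices) {   // embedding rows, length offsets[B]
  // Group the sample index of every lookup by its embedding row (CSC column).
  std::map<int64_t, std::vector<int>> by_column;
  for (int b = 0; b < B; ++b) {
    for (int64_t i = offsets[b]; i < offsets[b + 1]; ++i) {
      by_column[indices[i]].push_back(b);
    }
  }
  // Flatten into the compressed column-segment layout used by the kernel.
  CscSketch csc;
  csc.num_non_zero_columns = static_cast<int>(by_column.size());
  csc.column_segment_ptr.push_back(0);
  for (const auto& [col, rows] : by_column) {
    csc.column_segment_indices.push_back(col);
    csc.row_indices.insert(csc.row_indices.end(), rows.begin(), rows.end());
    csc.column_segment_ptr.push_back(static_cast<int>(csc.row_indices.size()));
  }
  return csc;
}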
