Skip to content

Commit

Permalink
avoid using direct tensor access (pytorch#970)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#970

Should use data_ptr or accessor instead

Reviewed By: jasonjk-park

Differential Revision: D34743721

fbshipit-source-id: e137e5756f3b360c8c238bbee5c5e318434678a5
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Mar 9, 2022
1 parent acc2db3 commit 1baf483
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) {
return output;
}

template <class T>
template <typename index_t, typename scalar_t>
void reorder_batched_ad_lengths_(
const Tensor& cat_ad_lengths,
const Tensor& batch_offsets,
Expand All @@ -813,7 +813,9 @@ void reorder_batched_ad_lengths_(
const int64_t nB = batch_offsets.numel() - 1;
const int64_t nT = cat_ad_lengths.numel() / num_ads_in_batch;

const auto* batch_offsets_data = batch_offsets.data_ptr<T>();
const auto* batch_offsets_data = batch_offsets.data_ptr<index_t>();
const auto* cat_ad_lengths_data = cat_ad_lengths.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();

for (auto b = 0; b < nB; b++) {
const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b];
Expand All @@ -823,8 +825,8 @@ void reorder_batched_ad_lengths_(
const int32_t output_segment_start =
t * num_ads_in_batch + batch_offsets_data[b];
for (auto i = 0; i < num_ads_b; i++) {
output[output_segment_start + i] =
cat_ad_lengths[input_segment_start + i];
output_data[output_segment_start + i] =
cat_ad_lengths_data[input_segment_start + i];
}
}
}
Expand All @@ -838,14 +840,19 @@ Tensor reorder_batched_ad_lengths_cpu(
TENSOR_ON_CPU(batch_offsets);

Tensor reordered_cat_ad_lengths = at::empty_like(cat_ad_lengths);
AT_DISPATCH_ALL_TYPES(
batch_offsets.type(), "reorder_batched_ad_lengths_cpu_kernel", ([&] {
reorder_batched_ad_lengths_<scalar_t>(
cat_ad_lengths,
batch_offsets,
num_ads_in_batch,
reordered_cat_ad_lengths);
}));
AT_DISPATCH_INDEX_TYPES(
batch_offsets.type(), "reorder_batched_ad_lengths_cpu_kernel1", [&] {
AT_DISPATCH_ALL_TYPES(
cat_ad_lengths.type(),
"reorder_batched_ad_lengths_cpu_kernel2",
[&] {
reorder_batched_ad_lengths_<index_t, scalar_t>(
cat_ad_lengths,
batch_offsets,
num_ads_in_batch,
reordered_cat_ad_lengths);
});
});

return reordered_cat_ad_lengths;
}
Expand All @@ -866,6 +873,7 @@ void reorder_batched_ad_indices_cpu_(
const auto* reordered_cat_ad_offsets_data =
reordered_cat_ad_offsets.data_ptr<int32_t>();
const auto* cat_ad_indices_data = cat_ad_indices.data_ptr<T>();
auto* output_data = output.data_ptr<T>();

for (auto b = 0; b < nB; b++) {
const auto num_ads_b = batch_offsets_data[b + 1] - batch_offsets_data[b];
Expand All @@ -886,7 +894,7 @@ void reorder_batched_ad_indices_cpu_(
reordered_cat_ad_offsets_data[output_segment_offset_start];

for (auto i = 0; i < input_segment_end - input_segment_start; i++) {
output[output_segment_start + i] =
output_data[output_segment_start + i] =
cat_ad_indices_data[input_segment_start + i];
}
}
Expand Down

0 comments on commit 1baf483

Please sign in to comment.