Implement some custom fb op out variant kernels (pytorch#2793)
Summary:
X-link: facebookresearch/FBGEMM#24

Pull Request resolved: pytorch#2793

Implement out variant versions of tbe_input_combine_with_length, offsets_to_lengths, and lengths_to_offsets, and add a skeleton for custom fb op static kernel dispatch in sigmoid.

Also start adding native kernels (i.e., kernels with no out variant, for which we can go directly to the native implementation rather than first going through the torch dispatcher).
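
Note (illustration, not part of this diff): the "out variant" pattern used here is the usual ATen layering, where the kernel writes into caller-provided tensors instead of allocating its own result, so a static kernel dispatch runtime can preallocate and reuse output buffers. A minimal sketch of that layering, using a hypothetical `foo` op:

```cpp
#include <ATen/ATen.h>

// Out variant: writes into a caller-owned tensor, resizing as needed.
// A static runtime can allocate `out` once and reuse it across calls.
void foo_out(at::Tensor& out, const at::Tensor& x) {
  out.resize_(x.sizes());
  out.copy_(x);
}

// Functional variant: allocates a fresh output, then delegates to the
// out variant. The kernels in this diff follow the same layering.
at::Tensor foo(const at::Tensor& x) {
  auto out = at::empty({0}, x.options());
  foo_out(out, x);
  return out;
}
```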

Reviewed By: sryap

Differential Revision: D57453462

fbshipit-source-id: aa62ca7340ea3c03304f6c0f65f17b3a15d6d049
qxy11 authored and facebook-github-bot committed Jul 11, 2024
1 parent 98fa998 commit 2bd3222
Showing 2 changed files with 114 additions and 36 deletions.
7 changes: 7 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/input_combine.h
```diff
@@ -46,4 +46,11 @@ tbe_input_combine_with_length_cuda(
     const uint64_t max_list_size,
     const c10::DeviceIndex& device);
 
+void tbe_input_combine_with_length_cpu_out(
+    at::Tensor& combined_indices,
+    at::Tensor& combined_lengths,
+    at::Tensor& combined_per_sample_weights,
+    const std::vector<at::Tensor>& indices_list,
+    const std::vector<at::Tensor>& lengths_list,
+    const std::vector<at::Tensor>& per_sample_weights);
 } // namespace fbgemm_gpu
```
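
Note (illustrative usage, not part of this diff): a caller of the declaration above would preallocate the three outputs with the dtypes the kernels write (int32 indices/lengths, float weights) and let the op resize them on each call. A sketch, assuming CPU inputs:

```cpp
#include <vector>

#include <ATen/ATen.h>

#include "fbgemm_gpu/input_combine.h"

// Hypothetical caller. The out variant resizes the outputs internally,
// so empty tensors with the right dtype/device suffice on first use.
void combine_inputs(
    const std::vector<at::Tensor>& indices_list,
    const std::vector<at::Tensor>& lengths_list,
    const std::vector<at::Tensor>& per_sample_weights) {
  auto combined_indices = at::empty({0}, at::kInt);
  auto combined_lengths = at::empty({0}, at::kInt);
  auto combined_weights = at::empty({0}, at::kFloat);
  fbgemm_gpu::tbe_input_combine_with_length_cpu_out(
      combined_indices,
      combined_lengths,
      combined_weights,
      indices_list,
      lengths_list,
      per_sample_weights);
  // combined_weights is resized to {0} when no list carries
  // per-sample weights.
}
```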
143 changes: 107 additions & 36 deletions fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
```diff
@@ -25,17 +25,11 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-Tensor _cat_int_tensors(
+void _cat_int_tensors_out(
+    Tensor& combined_tensors,
     const std::vector<Tensor>& tensor_list,
-    int64_t total_num,
-    bool use_pin_memory) {
-  auto combined_tensors = at::empty(
-      {total_num},
-      at::TensorOptions()
-          .dtype(c10::kInt)
-          .device(tensor_list[0].device())
-          .pinned_memory(use_pin_memory));
-
+    int64_t total_num) {
+  at::native::resize_(combined_tensors, {total_num});
   auto* combined_tensors_data_ptr =
       combined_tensors.mutable_data_ptr<int32_t>();
   size_t idx = 0;
@@ -53,6 +47,22 @@ Tensor _cat_int_tensors(
           }
         });
   }
+}
+
+Tensor _cat_int_tensors(
+    const std::vector<Tensor>& tensor_list,
+    int64_t total_num,
+    bool use_pin_memory) {
+  // Using int type to maintain original behavior
+  // in https://fburl.com/code/h2lwews2
+  auto combined_tensors = at::empty(
+      {total_num},
+      at::TensorOptions()
+          .dtype(c10::kInt)
+          .device(tensor_list[0].device())
+          .pinned_memory(use_pin_memory));
+
+  _cat_int_tensors_out(combined_tensors, tensor_list, total_num);
   return combined_tensors;
 }
 
@@ -89,29 +99,41 @@ Tensor _cat_int_tensors_with_padding(
   return combined_tensors;
 }
 
-Tensor _cat_per_sample_weights_list(
+void _cat_per_sample_weights_list_out(
+    Tensor& out,
     const std::vector<Tensor>& per_sample_weights,
     const std::vector<Tensor>& indices_list,
-    int64_t total_num,
-    bool use_pin_memory) {
-  auto combined_weights = at::ones(
-      {total_num},
-      at::TensorOptions()
-          .dtype(c10::kFloat)
-          .device(per_sample_weights[0].device())
-          .pinned_memory(use_pin_memory));
-  auto* combined_weights_ptr = combined_weights.mutable_data_ptr<float>();
+    int64_t total_num) {
+  at::native::resize_(out, {total_num});
+  out.fill_(1.);
+
+  auto* out_weights_ptr = out.mutable_data_ptr<float>();
 
   for (size_t i = 0; i < per_sample_weights.size(); i++) {
     auto element_size = per_sample_weights[i].numel();
     if (element_size != 0) {
       memcpy(
-          combined_weights_ptr,
+          out_weights_ptr,
          per_sample_weights[i].data_ptr<float>(),
           element_size * sizeof(float));
     }
-    combined_weights_ptr += indices_list[i].numel();
+    out_weights_ptr += indices_list[i].numel();
   }
-  return combined_weights;
 }
+
+Tensor _cat_per_sample_weights_list(
+    const std::vector<Tensor>& per_sample_weights,
+    const std::vector<Tensor>& indices_list,
+    int64_t total_num,
+    bool use_pin_memory) {
+  auto combined_weights = at::empty(
+      {0},
+      at::TensorOptions()
+          .dtype(c10::kFloat)
+          .device(per_sample_weights[0].device())
+          .pinned_memory(use_pin_memory));
+  _cat_per_sample_weights_list_out(
+      combined_weights, per_sample_weights, indices_list, total_num);
+  return combined_weights;
+}
 
@@ -200,7 +222,10 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
   return {combined_indices, combined_offsets, at::empty({0})};
 }
 
-std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
+void tbe_input_combine_with_length_cpu_out(
+    Tensor& combined_indices,
+    Tensor& combined_lengths,
+    Tensor& combined_per_sample_weights,
     const std::vector<Tensor>& indices_list,
     const std::vector<Tensor>& lengths_list,
     const std::vector<Tensor>& per_sample_weights) {
@@ -210,7 +235,6 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
   int64_t total_indices = 0;
   int64_t total_lengths = 0;
   bool need_weights = false;
-  bool pin_memory = false;
 
   for (size_t i = 0; i < indices_list.size(); i++) {
     TORCH_CHECK(
@@ -234,20 +258,67 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
     }
   }
 
-  auto combined_indices =
-      _cat_int_tensors(indices_list, total_indices, pin_memory);
-
-  auto combined_lengths =
-      _cat_int_tensors(lengths_list, total_lengths, pin_memory);
-
+  _cat_int_tensors_out(combined_indices, indices_list, total_indices);
+  _cat_int_tensors_out(combined_lengths, lengths_list, total_lengths);
   if (need_weights) {
-    return {
-        std::move(combined_indices),
-        std::move(combined_lengths),
-        _cat_per_sample_weights_list(
-            per_sample_weights, indices_list, total_indices, pin_memory)};
+    _cat_per_sample_weights_list_out(
+        combined_per_sample_weights,
+        per_sample_weights,
+        indices_list,
+        total_indices);
+    return;
   }
-  return {combined_indices, combined_lengths, at::empty({0})};
+  combined_per_sample_weights.resize_({0});
+}
+
+std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
+    const std::vector<Tensor>& indices_list,
+    const std::vector<Tensor>& lengths_list,
+    const std::vector<Tensor>& per_sample_weights) {
+  constexpr bool pin_memory = false;
+  const auto num_lists = indices_list.size();
+  TORCH_CHECK_GT(indices_list.size(), 0);
+  TORCH_CHECK_EQ(lengths_list.size(), indices_list.size());
+  TORCH_CHECK_EQ(per_sample_weights.size(), indices_list.size());
+  for (const auto i : c10::irange(num_lists)) {
+    TENSOR_CONTIGUOUS_AND_ON_CPU(indices_list[i]);
+    TENSOR_CONTIGUOUS_AND_ON_CPU(lengths_list[i]);
+    if (per_sample_weights[i].numel() > 0) {
+      TENSOR_CONTIGUOUS_AND_ON_CPU(per_sample_weights[i]);
+    } else {
+      TENSOR_EMPTY_OR_ON_CPU(per_sample_weights[i]);
+    }
+  }
+  // Using int type to maintain original behavior
+  // in https://fburl.com/code/h2lwews2
+  auto combined_indices = at::empty(
+      {0},
+      at::TensorOptions()
+          .dtype(c10::kInt)
+          .device(indices_list[0].device())
+          .pinned_memory(pin_memory));
+  auto combined_lengths = at::empty(
+      {0},
+      at::TensorOptions()
+          .dtype(c10::kInt)
+          .device(lengths_list[0].device())
+          .pinned_memory(pin_memory));
+  // Using float type to maintain original behavior
+  // in https://fburl.com/code/lp6u8j81
+  auto combined_per_sample_weights = at::empty(
+      {0},
+      at::TensorOptions()
+          .dtype(c10::kFloat)
+          .device(per_sample_weights[0].device())
+          .pinned_memory(pin_memory));
+  tbe_input_combine_with_length_cpu_out(
+      combined_indices,
+      combined_lengths,
+      combined_per_sample_weights,
+      indices_list,
+      lengths_list,
+      per_sample_weights);
+  return {combined_indices, combined_lengths, combined_per_sample_weights};
 }
 
 // Similar to tbe_input_combine_cpu, but padding all the offsets
```
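
Note (worked example, values invented for illustration): per the kernels above, the functional wrapper concatenates indices and lengths, casting to int32, and pads lists without per-sample weights with 1.f. A small sketch, assuming the wrapper is declared in fbgemm_gpu/input_combine.h:

```cpp
#include <vector>

#include <ATen/ATen.h>

#include "fbgemm_gpu/input_combine.h"

int main() {
  // Two tables; table 0 has no per-sample weights (empty tensor).
  std::vector<at::Tensor> indices = {
      at::tensor(std::vector<int32_t>{1, 2, 3}, at::kInt),
      at::tensor(std::vector<int32_t>{7, 8}, at::kInt)};
  std::vector<at::Tensor> lengths = {
      at::tensor(std::vector<int32_t>{2, 1}, at::kInt),
      at::tensor(std::vector<int32_t>{1, 1}, at::kInt)};
  std::vector<at::Tensor> weights = {
      at::empty({0}, at::kFloat),
      at::tensor(std::vector<float>{0.5f, 0.25f}, at::kFloat)};

  auto [ci, cl, cw] = fbgemm_gpu::tbe_input_combine_with_length_cpu(
      indices, lengths, weights);
  // ci = [1, 2, 3, 7, 8]       (indices concatenated, int32)
  // cl = [2, 1, 1, 1]          (lengths concatenated, int32)
  // cw = [1, 1, 1, 0.5, 0.25]  (weightless list padded with 1.f)
  return 0;
}
```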
