handle truncation vs. padding case in jagged tensors (pytorch#1013)
Summary:
Pull Request resolved: pytorch#1013

When max_L is smaller than offset[i + 1] - offset[i], jagged -> dense conversion truncates instead of padding.
For jagged + dense -> jagged, the truncated portion has no corresponding element in the dense input tensor, so we pass the jagged input through unchanged, as if the value from the dense input tensor were zero.
We can do the same for multiplication, outputting zero for the truncated portion.

Reviewed By: jasonjk-park

Differential Revision: D35171324

fbshipit-source-id: 38381d029be39b17530c9438d9cb9c94ec5d1bcf
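
To make the intended semantics concrete, here is a standalone sketch (plain C++, illustrative only; the function names and simplified data layout are invented for this example and are not the FBGEMM kernels):

#include <algorithm>
#include <vector>

// Jagged -> dense for a 1-D jagged tensor given as (values, offsets).
// Rows shorter than max_L are padded with padding_value; rows longer
// than max_L are truncated.
std::vector<std::vector<float>> jagged_to_dense_1d(
    const std::vector<float>& values,
    const std::vector<int>& offsets, // size B + 1
    int max_L,
    float padding_value) {
  const int B = static_cast<int>(offsets.size()) - 1;
  std::vector<std::vector<float>> dense(
      B, std::vector<float>(max_L, padding_value));
  for (int b = 0; b < B; ++b) {
    const int length = offsets[b + 1] - offsets[b];
    for (int l = 0; l < std::min(length, max_L); ++l) { // truncation
      dense[b][l] = values[offsets[b] + l];
    }
  }
  return dense;
}

// Jagged + dense -> jagged. Jagged elements past max_L have no
// counterpart in the dense tensor, so they pass through unchanged,
// i.e. the missing dense value is treated as zero.
std::vector<float> jagged_add_dense(
    const std::vector<float>& x_values,
    const std::vector<int>& offsets,
    const std::vector<std::vector<float>>& y) { // [B][max_L]
  std::vector<float> out = x_values; // identity on the truncated portion
  const int B = static_cast<int>(offsets.size()) - 1;
  const int max_L = y.empty() ? 0 : static_cast<int>(y[0].size());
  for (int b = 0; b < B; ++b) {
    const int length = offsets[b + 1] - offsets[b];
    for (int l = 0; l < std::min(length, max_L); ++l) {
      out[offsets[b] + l] += y[b][l];
    }
  }
  return out;
}

For jagged * dense -> jagged the same loop applies, but the output buffer starts from zeros rather than from x_values, so truncated elements come out as zero (x * 0) instead of passing through (x + 0). This is exactly the at::zeros_like vs. x_values.clone() distinction in the diff below.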
jspark1105 authored and facebook-github-bot committed Mar 29, 2022
1 parent 073ea44 commit a0ff6de
Showing 3 changed files with 114 additions and 75 deletions.
44 changes: 20 additions & 24 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -142,11 +142,9 @@ std::tuple<dim3, dim3, Tensor> check_shape_and_partition_(
values.size(-1));
const int jagged_folded_size =
dense_tensor.numel() / (outer_dense_size * inner_dense_size);
const int jagged_innermost_size = dense_tensor.size(-2);

const int threads_x = jagged_innermost_size >= kWarpSize / 2
? kWarpSize
: jagged_innermost_size;
const int threads_x =
inner_dense_size >= kWarpSize / 2 ? kWarpSize : inner_dense_size;
const int threads_y = kMaxThreads / kWarpSize;
const dim3 blocks(
div_round_up(outer_dense_size * jagged_folded_size, threads_y));
@@ -332,18 +330,6 @@ void jagged_dense_elementwise_jagged_output_(
#undef INVOKE_KERNEL_WITH_DIM
}

template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_jagged_output_(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y,
F f) {
Tensor output = at::empty_like(x_values);
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values, x_offsets, y, output, f);
return output;
}

class JaggedToPaddedDenseGPUOp
: public torch::autograd::Function<JaggedToPaddedDenseGPUOp> {
public:
@@ -415,6 +401,8 @@ class JaggedToPaddedDenseGPUOp
device_guard.set_index(grad_padded_values.get_device());

int32_t D = grad_padded_values.size(-1);
// Initialize with zeros so output will be zero for the portion truncated
// in forward.
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -494,7 +482,7 @@ class JaggedDenseAddGPUOp
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());

Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());
Tensor x_values_grad = at::zeros(x_values_shape, grad_outputs[0].options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values_grad.scalar_type(), "jagged_dense_add_backward", [&] {
@@ -540,13 +528,17 @@ class JaggedDenseAddJaggedOutputGPUOp
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
// Initialize with jagged input so output will have the same value as the
// jagged tensor if there's no corresponding value in the dense tensor.
Tensor output = x_values.clone();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_dense_add_forward", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
@@ -714,13 +706,16 @@ class JaggedDenseMulGPUOp
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
// Initialize with zero so output will be zero if there's no corresponding
// value in the dense tensor.
Tensor output = at::zeros_like(x_values);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
@@ -748,15 +743,16 @@ class JaggedDenseMulGPUOp
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());

Tensor x_values_grad;
Tensor x_values_grad = at::zeros_like(grad_outputs[0]);
Tensor y_grad = at::empty_like(y);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
x_values_grad = jagged_dense_elementwise_jagged_output_<scalar_t>(
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_outputs[0],
x_offsets,
y,
x_values_grad,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
@@ -975,7 +971,7 @@ class BatchedDenseVecJagged2DMulGPUOp
const int B = a_offsets.numel() - 1;
const int D = grad_outputs[0].size(-1);
Tensor a_values_grad = at::empty_like(a_values);
Tensor a_values_grad = at::zeros_like(a_values);
Tensor v_grad = at::empty_like(v);
if (B > 0 && D > 0) {
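
A note on the recurring allocation change in this file: buffers that the kernels may only partially write (because segments longer than max_L are truncated) are now created with at::zeros / at::zeros_like instead of at::empty / at::empty_like. A minimal sketch of the difference, assuming libtorch (make_grad_buffer is a hypothetical helper, not part of this diff):

#include <torch/torch.h>

// The backward kernels only write the first min(length, max_L) rows of
// each segment. With torch::empty, rows past max_L would keep whatever
// garbage the allocator left behind; zero-initialization makes the
// truncated rows come out as a correct zero gradient.
torch::Tensor make_grad_buffer(
    int64_t total_L, int64_t D, const torch::TensorOptions& options) {
  return torch::zeros({total_L, D}, options); // was: torch::empty(...)
}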
67 changes: 34 additions & 33 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -162,7 +162,7 @@ void jagged_dense_elementwise_dense_output_kernel_(
const int begin = x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base];
const int end =
x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base + 1];
for (; jiidx < end - begin; ++jiidx) {
for (; jiidx < std::min(end - begin, jagged_innermost_size); ++jiidx) {
int jidx = joidx * jagged_innermost_size + jiidx;
if (NO_INNER_DENSE) {
output_accessor[oidx][jidx][0] =
@@ -303,12 +303,13 @@ void jagged_dense_elementwise_jagged_output_kernel_(

// As a perf optimization, a separate loop level for the inner-most
// jagged dimension.
int jiidx = 0;
if (!is_zero) {
const int begin = x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base];
const int end =
x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base + 1];
for (; jiidx < end - begin; ++jiidx) {
for (int jiidx = 0;
jiidx < std::min(end - begin, jagged_innermost_size);
++jiidx) {
int jidx = joidx * jagged_innermost_size + jiidx;
if (NO_INNER_DENSE) {
output_accessor[begin + jiidx][0] =
@@ -354,18 +355,6 @@ void jagged_dense_elementwise_jagged_output_(
#undef INVOKE_KERNEL_WITH_DIM
}

template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_jagged_output_(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y,
F f) {
Tensor output = at::empty_like(x_values);
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values, x_offsets, y, output, f);
return output;
}

class JaggedToPaddedDenseCPUOp
: public torch::autograd::Function<JaggedToPaddedDenseCPUOp> {
public:
@@ -431,7 +420,9 @@ class JaggedToPaddedDenseCPUOp
auto grad_padded_values = grad_outputs[0];

int32_t D = grad_padded_values.size(-1);
auto grad_values = at::empty({total_L, D}, grad_padded_values.options());
// Initialize with zeros so output will be zero for the portion truncated
// in forward.
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
@@ -455,7 +446,6 @@
}
};

// Almost identical copy of jagged_to_padded_dense in jagged_tensor_ops_cpu.cpp
Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
@@ -495,7 +485,7 @@ class JaggedDenseAddCPUOp
auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());
Tensor x_values_grad = at::zeros(x_values_shape, grad_outputs[0].options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values_grad.scalar_type(), "jagged_scalars", [&] {
@@ -536,13 +526,18 @@ class JaggedDenseJaggedOutputAddCPUOp
ctx->save_for_backward(x_offsets);
ctx->saved_data["y_shape"] = y.sizes();

Tensor output;
// Initialize with jagged input so output will have the same value as the
// jagged tensor if there's no corresponding value in the dense tensor.
Tensor output = x_values.clone();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[](scalar_t x, scalar_t y) -> scalar_t { return x + y; });
});

return {output};
@@ -658,7 +653,7 @@ void jagged_jagged_elementwise_dense_output_kernel_(
const int begin = x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base];
const int end =
x_offsets_accessors[NUM_JAGGED_DIM - 1][offset_base + 1];
for (; jiidx < end - begin; ++jiidx) {
for (; jiidx < std::min(end - begin, jagged_innermost_size); ++jiidx) {
int jidx = joidx * jagged_innermost_size + jiidx;
if (NO_INNER_DENSE) {
output_accessor[oidx][jidx][0] =
@@ -728,13 +723,17 @@ class JaggedDenseMulCPUOp
tensors_to_save.push_back(y);
ctx->save_for_backward(tensors_to_save);

Tensor output;
// Initialize with zero so output will be zero if there's no corresponding
// value in the dense tensor.
Tensor output = at::zeros_like(x_values);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
output,
[](scalar_t x, scalar_t y) -> scalar_t { return x * y; });
});

return {output};
@@ -753,15 +752,16 @@ class JaggedDenseMulCPUOp
const Tensor y = ctx->get_saved_variables().back();
TORCH_CHECK(grad_outputs.size() == 1);

Tensor x_values_grad;
Tensor x_values_grad = at::zeros_like(grad_outputs[0]);
Tensor y_grad = at::empty_like(y);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
x_values_grad = jagged_dense_elementwise_jagged_output_<scalar_t>(
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_outputs[0],
x_offsets,
y,
x_values_grad,
[](scalar_t x, scalar_t y) -> scalar_t { return x * y; });

jagged_jagged_elementwise_dense_output_<scalar_t>(
@@ -873,14 +873,15 @@ void outer_prod_jagged_2d_output(
at::TensorAccessor<scalar_t, 2> output_values) {
const int B = offsets.size(0) - 1;
const int H = x.size(0) / B;
const int max_L = x.size(1);
const int D = y.size(1);

for (int b = 0; b < B; ++b) {
const int row_start = offsets[b];
const int row_end = offsets[b + 1];
const int length = row_end - row_start;
for (int h = 0; h < H; ++h) {
for (int l = 0; l < length; ++l) {
for (int l = 0; l < std::min(length, max_L); ++l) {
for (int d = 0; d < D; ++d) {
output_values[row_start + l][h * D + d] =
x[b * H + h][l] * y[b * H + h][d];
@@ -949,7 +950,7 @@ class BatchedDenseVecJagged2DMulCPUOp

TENSOR_ON_CPU(grad_outputs[0]);

Tensor a_values_grad = at::empty_like(a_values);
Tensor a_values_grad = at::zeros_like(a_values);
Tensor v_grad = at::empty_like(v);

const int B = a_offsets.numel() - 1;
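
The CPU-side kernels all receive the same one-line fix: the inner loop is bounded by std::min(end - begin, max_L) so that segments longer than max_L are truncated rather than indexing past the dense input. A condensed sketch of the outer-product case (plain C++ with simplified types; the real kernel uses at::TensorAccessor):

#include <algorithm>
#include <vector>

// v[b*H + h] holds max_L entries, y[b*H + h] holds D entries; offsets
// gives each batch's true length. Only the first min(length, max_L)
// rows of each segment receive the outer product; rows past max_L keep
// their initial (zero) value.
void outer_prod_jagged(
    const std::vector<std::vector<float>>& v, // [B*H][max_L]
    const std::vector<std::vector<float>>& y, // [B*H][D]
    const std::vector<int>& offsets,          // [B + 1]
    int H,
    std::vector<std::vector<float>>& out) {   // [total_L][H*D], zero-filled
  const int B = static_cast<int>(offsets.size()) - 1;
  const int max_L = v.empty() ? 0 : static_cast<int>(v[0].size());
  const int D = y.empty() ? 0 : static_cast<int>(y[0].size());
  for (int b = 0; b < B; ++b) {
    const int length = offsets[b + 1] - offsets[b];
    for (int h = 0; h < H; ++h) {
      for (int l = 0; l < std::min(length, max_L); ++l) { // truncation
        for (int d = 0; d < D; ++d) {
          out[offsets[b] + l][h * D + d] = v[b * H + h][l] * y[b * H + h][d];
        }
      }
    }
  }
}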