Commit 3eb3790

unfold_backward: Remove stride >= size kernel in favour of copy_ (pytorch#88061)
unfold_backward has a dedicated kernel for the `step >= size` (non-overlapping) case which uses temporary
tensors created by `at::arange` to map the unfolded gradient back onto the folded input.
This PR instead uses `unfold` to view the output, and does a direct copy from the
gradient into the view.

In benchmarks I see either no difference or a marginal speed benefit from
this PR.
Pull Request resolved: pytorch#88061
Approved by: https://github.com/albanD
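For reference, a minimal standalone sketch of the same idea using the public libtorch C++ API (not the ATen-internal code touched by this diff): the helper name `unfold_backward_nonoverlapping` and the self-check in `main` are illustrative only and not part of the commit. When `step >= size` the unfolded windows never overlap, so each input element receives the gradient of at most one window and a plain `copy_` into an `unfold` view of a zero-filled tensor is sufficient; with overlapping windows the writes would collide and the summation kernel is still required.

```cpp
// Sketch assuming an installed libtorch build (link against torch/torch_cpu/c10).
#include <torch/torch.h>

// Hypothetical free function mirroring the fast path added in this commit;
// valid only for non-overlapping windows (step >= size).
at::Tensor unfold_backward_nonoverlapping(const at::Tensor& grad,
                                          at::IntArrayRef input_sizes,
                                          int64_t dim, int64_t size, int64_t step) {
  TORCH_CHECK(step >= size, "windows overlap; a summation kernel is required");
  auto grad_input = at::zeros(input_sizes, grad.options());
  // unfold() returns a strided view, so copy_ writes straight into grad_input.
  grad_input.unfold(dim, size, step).copy_(grad);
  return grad_input;
}

int main() {
  // Compare against autograd's own unfold backward on a small example.
  auto x = torch::randn({10}, torch::requires_grad());
  auto y = x.unfold(/*dim=*/0, /*size=*/2, /*step=*/3);  // step >= size
  auto g = torch::randn_like(y);
  y.backward(g);

  auto manual = unfold_backward_nonoverlapping(g, x.sizes(), /*dim=*/0, /*size=*/2, /*step=*/3);
  TORCH_CHECK(torch::allclose(x.grad(), manual));
}
```

The view-plus-copy route also avoids allocating the `at::arange` index tensors that the removed `_make_unfold_backward_iter_over_grad_in` helper built, which is presumably where the marginal benchmark benefit comes from.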
peterbell10 authored and pytorchmergebot committed Oct 31, 2022
1 parent ceddcf5 commit 3eb3790
Showing 4 changed files with 61 additions and 194 deletions.
6 changes: 6 additions & 0 deletions aten/src/ATen/native/UnfoldBackward.cpp
@@ -5,6 +5,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/unfold_backward_native.h>
#include <ATen/ops/zeros.h>
#endif
@@ -21,6 +22,11 @@ Tensor unfold_backward(
int64_t step
) {
auto grad_input = at::zeros(input_sizes, grad.options());
if (step >= size) {
auto gI_unfolded = grad_input.unfold(dim, size, step);
gI_unfolded.copy_(grad);
return grad_input;
}

unfold_backward_stub(
grad.device().type(),
73 changes: 0 additions & 73 deletions aten/src/ATen/native/UnfoldBackward.h
@@ -107,79 +107,6 @@ static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out(
return iter;
}

static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_in(
Tensor& grad_out,
const Tensor& grad_in,
int64_t dim,
int64_t /*size*/,
int64_t /*step*/
) {
dim = maybe_wrap_dim(dim, grad_out.dim());
// last dim stores the folds
auto last_dim = maybe_wrap_dim(-1, grad_in.dim());

auto grad_in_dim = ensure_nonempty_dim(grad_in.dim());
auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);
auto grad_in_last_dim_size = ensure_nonempty_size(grad_in, last_dim);

/* prepare grad_out for TensorIterator { */
auto grad_out_restrided = grad_out.unsqueeze(-1);

auto grad_out_strides = ensure_nonempty_vec(grad_out_restrided.strides().vec());
auto grad_out_sizes = ensure_nonempty_vec(grad_out_restrided.sizes().vec());

grad_out_strides[dim] = 0;
grad_out_strides[last_dim] = 0;

grad_out_sizes[dim] = grad_in_dim_size;
grad_out_sizes[last_dim] = grad_in_last_dim_size;

grad_out_restrided = grad_out_restrided.as_strided(grad_out_sizes, grad_out_strides);
/* } */

// for each element grad_out[i_1,...,i_dim,...,i_last_dim]
// we have to know i_dim and i_last_dim.
// This information is stored in Tensors
// idx_dim and idx_last_dim
/* prepare idx_dim and idx_last_dim for TensorIterator { */
auto idx_dim = at::arange(
0, grad_in_dim_size, grad_in.options().dtype(at::kLong)
);

auto idx_dim_strides = std::vector<int64_t>(grad_in_dim, 0);
auto idx_dim_sizes = std::vector<int64_t>(grad_in_dim, 1);

idx_dim_strides[dim] = 1;
idx_dim_sizes[dim] = grad_in_dim_size;

auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);

auto idx_last_dim = at::arange(
0, grad_in_last_dim_size, grad_in.options().dtype(at::kLong)
);

auto idx_last_dim_strides = std::vector<int64_t>(grad_in_dim, 0);
auto idx_last_dim_sizes = std::vector<int64_t>(grad_in_dim, 1);

idx_last_dim_strides[last_dim] = 1;
idx_last_dim_sizes[last_dim] = grad_in_last_dim_size;

auto idx_last_dim_restrided = idx_last_dim.as_strided(idx_last_dim_sizes, idx_last_dim_strides);
/* } */

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_owned_output(grad_out_restrided)
.add_owned_input(grad_in)
.add_owned_input(idx_dim_restrided)
.add_owned_input(idx_last_dim_restrided)
.build();

return iter;
}

}

}} // namespace at::native
81 changes: 25 additions & 56 deletions aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp
@@ -66,8 +66,7 @@ void _unfold_backward_internal_kernel(
int64_t grad_in_dim_stride,
int64_t grad_in_last_dim_stride,
int64_t grad_in_dim_size,
int64_t grad_out_dim_stride,
bool is_step_ge_size
int64_t grad_out_dim_stride
) {
if (iter.numel() == 0) {
return;
@@ -78,53 +77,32 @@
auto* RESTRICT grad_in_ptr = data[1];
auto* RESTRICT idx_dim_ptr = data[2];

if (is_step_ge_size) {
auto* RESTRICT idx_last_dim_ptr = data[3];
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);

for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);
auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);

auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);
auto idx_last_dim = *reinterpret_cast<int64_t*>(idx_last_dim_ptr);
// left_fold potentially intersecting with idx_dim
// is either (idx_dim - size) / step or the next integer.
int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
++left_fold_idx;
}

auto grad_out_idx_dim = idx_dim * step + idx_last_dim;
grad_out_data[grad_out_idx_dim * grad_out_dim_stride] = *grad_in_data;
auto right_fold_idx = idx_dim / step;
right_fold_idx = (right_fold_idx >= grad_in_dim_size)
? (grad_in_dim_size - 1) : right_fold_idx;

grad_out_ptr += strides[0];
grad_in_ptr += strides[1];
idx_dim_ptr += strides[2];
idx_last_dim_ptr += strides[3];
}
}
else {
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);

auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);

// left_fold potentially intersecting with idx_dim
// is either (idx_dim - size) / step or the next integer.
int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
++left_fold_idx;
}

auto right_fold_idx = idx_dim / step;
right_fold_idx = (right_fold_idx >= grad_in_dim_size)
? (grad_in_dim_size - 1) : right_fold_idx;

for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
auto idx_last_dim = idx_dim - fold_idx * step;
*grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
+ idx_last_dim * grad_in_last_dim_stride];
}

grad_out_ptr += strides[0];
grad_in_ptr += strides[1];
idx_dim_ptr += strides[2];
for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
auto idx_last_dim = idx_dim - fold_idx * step;
*grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
+ idx_last_dim * grad_in_last_dim_stride];
}

grad_out_ptr += strides[0];
grad_in_ptr += strides[1];
idx_dim_ptr += strides[2];
}
};

@@ -148,16 +126,8 @@ void unfold_backward_cpu_kernel(

auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

auto is_step_ge_size = (step >= size);

TensorIterator iter =
is_step_ge_size ?
_make_unfold_backward_iter_over_grad_in(
grad_out, grad_in, dim, size, step
) :
_make_unfold_backward_iter_over_grad_out(
grad_out, grad_in, dim, size, step
);
TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
grad_out, grad_in, dim, size, step);

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
Expand All @@ -170,8 +140,7 @@ void unfold_backward_cpu_kernel(
grad_in_dim_stride,
grad_in_last_dim_stride,
grad_in_dim_size,
grad_out_dim_stride,
is_step_ge_size
grad_out_dim_stride
);
}
);
95 changes: 30 additions & 65 deletions aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu
@@ -58,8 +58,7 @@ void _unfold_backward_internal_kernel(
int64_t grad_in_dim_stride,
int64_t grad_in_last_dim_stride,
int64_t grad_in_dim_size,
int64_t grad_out_dim_stride,
bool is_step_ge_size
int64_t grad_out_dim_stride
) {
if (iter.numel() == 0) {
return;
@@ -74,8 +73,7 @@
grad_in_dim_stride,
grad_in_last_dim_stride,
grad_in_dim_size,
grad_out_dim_stride,
is_step_ge_size
grad_out_dim_stride
);
}
return;
@@ -85,63 +83,39 @@
char* __restrict__ grad_in_ptr = reinterpret_cast<char*>(iter.data_ptr(1));
char* __restrict__ idx_dim_ptr = reinterpret_cast<char*>(iter.data_ptr(2));

if (is_step_ge_size) {
char* __restrict__ idx_last_dim_ptr = reinterpret_cast<char*>(iter.data_ptr(3));
auto offset_calc = make_offset_calculator<3>(iter);

auto offset_calc = make_offset_calculator<4>(iter);
// The algorithm is: for each index in grad_out find
// the elements contributing to it and sum them up.
// Note: the algorithm does not require any synchronization.
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);

// this loop simply copies the data
// from proper places in grad_out to grad_in
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);
auto* __restrict__ grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr + offsets[0]);
auto* __restrict__ grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr + offsets[1]);

auto* __restrict__ grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr + offsets[0]);
auto* __restrict__ grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr + offsets[1]);
auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr + offsets[2]);

auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr + offsets[2]);
auto idx_last_dim = *reinterpret_cast<int64_t*>(idx_last_dim_ptr + offsets[3]);

auto grad_out_idx_dim = idx_dim * step + idx_last_dim;
grad_out_data[grad_out_idx_dim * grad_out_dim_stride] = *grad_in_data;
};

_launch_unfold_backward_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
else {
auto offset_calc = make_offset_calculator<3>(iter);

// The algorithm is: for each index in grad_out find
// the elements contributing to it and sum them up.
// Note: the algorithm does not require any synchronization.
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);

auto* __restrict__ grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr + offsets[0]);
auto* __restrict__ grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr + offsets[1]);

auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr + offsets[2]);

// left_fold potentially intersecting with idx_dim
// is either (idx_dim - size) / step or the next integer.
int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
++left_fold_idx;
}
// left_fold potentially intersecting with idx_dim
// is either (idx_dim - size) / step or the next integer.
int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
++left_fold_idx;
}

auto right_fold_idx = idx_dim / step;
right_fold_idx = (right_fold_idx >= grad_in_dim_size) ?
(grad_in_dim_size - 1) : right_fold_idx;
auto right_fold_idx = idx_dim / step;
right_fold_idx = (right_fold_idx >= grad_in_dim_size) ?
(grad_in_dim_size - 1) : right_fold_idx;

for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
auto idx_last_dim = idx_dim - fold_idx * step;
*grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
+ idx_last_dim * grad_in_last_dim_stride];
}
for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
auto idx_last_dim = idx_dim - fold_idx * step;
*grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
+ idx_last_dim * grad_in_last_dim_stride];
}

};
};

_launch_unfold_backward_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
_launch_unfold_backward_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}

void unfold_backward_cuda_kernel(
@@ -161,16 +135,8 @@ void unfold_backward_cuda_kernel(

auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

auto is_step_ge_size = (step >= size);

TensorIterator iter =
is_step_ge_size ?
_make_unfold_backward_iter_over_grad_in(
grad_out, grad_in, dim, size, step
) :
_make_unfold_backward_iter_over_grad_out(
grad_out, grad_in, dim, size, step
);
TensorIterator iter = _make_unfold_backward_iter_over_grad_out(
grad_out, grad_in, dim, size, step);

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
Expand All @@ -183,8 +149,7 @@ void unfold_backward_cuda_kernel(
grad_in_dim_stride,
grad_in_last_dim_stride,
grad_in_dim_size,
grad_out_dim_stride,
is_step_ge_size
grad_out_dim_stride
);
}
);
