From 30063347e71bc540737889895ed3212cf18683b3 Mon Sep 17 00:00:00 2001
From: Nik Ved
Date: Wed, 27 May 2020 13:24:14 -0700
Subject: [PATCH] remove serial_exec from scatter/gather kernel (#36181)

Summary:
Since the indexed dimension in `scatter/gather` is traversed inside the kernel, any two writes that target the same memory location are guaranteed to be handled by one and the same thread, so the threads never race with each other. See [this comment](https://github.com/pytorch/pytorch/issues/33389#issuecomment-590017938) for a graphical explanation.

More formal description: suppose we deal with 3D tensors and `dim=0`, hence the `scatter_add` operations are
```
self[index[i][j][k]][j][k] += src[i][j][k],
...
self[index[i'][j'][k']][j'][k'] += src[i'][j'][k'],
...
```
Clearly, two operations write/read the same memory if and only if:
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k'.
```
Since the reduction over `dim=0` happens inside the kernel, the threads partition `dim=1,2`. It means that two distinct threads `t` and `t'` receive disjoint sets of indices
```
I_t  = {(*, j, k) assigned to the thread t},
I_t' = {(*, j', k') assigned to the thread t'},
I_t intersection with I_t' = the empty set.
```
Hence a conflict
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k',
```
can happen if and only if there exists a single thread `t` which receives both `(*, j, k)` and `(*, j', k')`: since `j = j'` and `k = k'`, these are the same column, and every column is assigned to exactly one thread. All conflicting writes are thus executed serially inside that one thread. Therefore it is possible to make `scatter_add` parallel and remove `serial_exec` from the `scatter_gather_base_kernel`.

CC v0dro

Pull Request resolved: https://github.com/pytorch/pytorch/pull/36181

Differential Revision: D21716167

Pulled By: ngimel

fbshipit-source-id: 49aee2de43779a1f0b359c22c8589c0702ee68a2
---
 .../ATen/native/cpu/ScatterGatherKernel.cpp | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index ece24adcd87148..5b26efa10b7401 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -44,8 +44,7 @@ struct cpu_scatter_gather_base_kernel {
     Tensor& self, int64_t dim,
     const Tensor& index, const Tensor& src,
     const std::string& method_name,
-    const func_t& f,
-    bool serial_exec = true
+    const func_t& f
   ) {
     // no-op if index is empty
     if (index.numel() == 0) {
@@ -143,12 +142,7 @@
 
         };
 
-        if (serial_exec) {
-          iter.serial_for_each(loop, {0, iter.numel()});
-        }
-        else {
-          iter.for_each(loop);
-        }
+        iter.for_each(loop);
       }
     );
   }
@@ -159,8 +153,7 @@ void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
     result, dim, index, self,
     "gather_out_cpu", [] (auto* lhs, const auto* rhs) {
       *lhs = *rhs;
-    },
-    /*serial_exec=*/false
+    }
   );
 }
 
@@ -169,8 +162,7 @@ void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
     self, dim, index, src,
     "scatter_cpu_", [] (auto* lhs, const auto* rhs) {
       *lhs = *rhs;
-    },
-    /*serial_exec=*/false
+    }
   );
 }
 
@@ -180,8 +172,7 @@ void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Scalar src) {
     "scatter_fill_cpu_", [src] (auto* lhs, const auto* rhs) {
       using scalar_t = typename std::remove_pointer<decltype(lhs)>::type;
       *lhs = src.to<scalar_t>();
-    },
-    /*serial_exec=*/false
+    }
   );
 }
 
@@ -190,8 +181,7 @@ void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
     self, dim, index, src,
     "scatter_add_", [] (auto* lhs, const auto* rhs) {
       *lhs += *rhs;
-    },
-    /*serial_exec=*/true
+    }
   );
 }
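
The argument in the summary can be made concrete with a small standalone sketch. This is illustrative only, not the actual `TensorIterator`-based kernel above; the function name `scatter_add_dim0`, the flat `std::vector` layout, and the round-robin column partition are assumptions made for the example:

```
#include <cstdint>
#include <thread>
#include <vector>

// self, src, index are D0 x D1 x D2 tensors, flattened row-major.
// index values are assumed to lie in [0, D0).
void scatter_add_dim0(std::vector<float>& self,
                      const std::vector<float>& src,
                      const std::vector<int64_t>& index,
                      int64_t D0, int64_t D1, int64_t D2,
                      int64_t num_threads) {
  auto at = [=](int64_t i, int64_t j, int64_t k) {
    return (i * D1 + j) * D2 + k;
  };
  const int64_t cols = D1 * D2;
  std::vector<std::thread> pool;
  for (int64_t t = 0; t < num_threads; ++t) {
    pool.emplace_back([&, t] {
      // Each thread owns a disjoint set of (j, k) columns (round-robin),
      // i.e. the threads partition dim=1,2.
      for (int64_t c = t; c < cols; c += num_threads) {
        const int64_t j = c / D2;
        const int64_t k = c % D2;
        // The indexed dimension (dim=0) is reduced inside the thread, so
        // every write to self[index[i][j][k]][j][k] for this (j, k) column
        // happens serially here; no other thread can touch these locations.
        for (int64_t i = 0; i < D0; ++i) {
          self[at(index[at(i, j, k)], j, k)] += src[at(i, j, k)];
        }
      }
    });
  }
  for (auto& th : pool) th.join();
}
```

The key property is visible in the loop structure: a thread owns a whole `(j, k)` column and runs the `i` loop serially, so conflicting writes never cross a thread boundary and neither atomics nor the old `serial_exec` fallback are needed. Had the threads partitioned `dim=0` instead, two threads could target the same `(index value, j, k)` location and would have to synchronize.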