remove serial_exec from scatter/gather kernel (pytorch#36181)
Summary:
Since the indexed dimension in `scatter/gather` is traversed inside the kernel, any two writes that could touch the same memory location are always executed by a single thread, so the work assigned to different threads is mutually disjoint.
See [this comment](pytorch#33389 (comment)) for a graphical explanation. A more formal description:
Suppose we deal with 3D tensors and `dim=0`; the `scatter_add` operations are then
```
self[index[i][j][k]][j][k] += src[i][j][k],
...
self[index[i'][j'][k']][j'][k'] += src[i'][j'][k'],
...
```
Clearly, a write/read to the same memory location happens if and only if:
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k'.
```
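
For concreteness, here is a minimal serial sketch of these semantics on plain contiguous arrays (the helper name and the sizes `I`, `J`, `K` are hypothetical; this illustrates the operation, it is not the ATen kernel itself):

```cpp
#include <cstdint>
#include <vector>

// Serial reference semantics of scatter_add over dim=0 for a 3D tensor:
//   self[index[i][j][k]][j][k] += src[i][j][k]
// `self` has shape (I_self, J, K); `index` and `src` have shape (I, J, K).
void scatter_add_dim0(std::vector<float>& self,
                      const std::vector<int64_t>& index,
                      const std::vector<float>& src,
                      int64_t I, int64_t J, int64_t K) {
  for (int64_t i = 0; i < I; ++i) {
    for (int64_t j = 0; j < J; ++j) {
      for (int64_t k = 0; k < K; ++k) {
        const int64_t flat = (i * J + j) * K + k;           // offset into index/src
        const int64_t tgt  = (index[flat] * J + j) * K + k; // offset into self
        self[tgt] += src[flat];
      }
    }
  }
}
```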
Since the reduction over `dim=0` happens inside the kernel, the threads partition the `(j, k)` space spanned by `dim=1,2`. That is, two distinct threads `t` and `t'` receive the index sets
```
I  = {(*, j, k) sent to thread t},
I' = {(*, j', k') sent to thread t'},
I ∩ I' = ∅ (the sets are disjoint).
```

Hence a conflict
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k',
```
occurs if and only if there exists some thread `t` whose index set `K` contains both `(*, j, k)` and `(*, j', k')`. Since `j = j'` and `k = k'`, these two tuples coincide, so any pair of conflicting writes is always executed serially by one and the same thread.

Therefore it is safe to run `scatter_add` in parallel, and `serial_exec` can be removed from `scatter_gather_base_kernel`.
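
The same semantics can be parallelized as sketched below: threads partition the flattened `(j, k)` space and each thread performs the full reduction over `i` for the columns it owns. Since conflicting writes necessarily share `(j, k)`, they always land in the same chunk and hence the same thread. (A minimal sketch with `std::thread` and hypothetical plain-array helpers, not the actual `TensorIterator`-based kernel, which uses `iter.for_each`.)

```cpp
#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// Parallel sketch: threads partition the flattened (j, k) space; the full
// reduction over i happens inside each thread. Two writes can conflict only
// if they share (j, k), i.e. the same flattened column c = j * K + k, and
// each column belongs to exactly one thread -- so there is no data race.
void scatter_add_dim0_parallel(std::vector<float>& self,
                               const std::vector<int64_t>& index,
                               const std::vector<float>& src,
                               int64_t I, int64_t J, int64_t K,
                               int64_t num_threads) {
  const int64_t cols = J * K;  // flattened (j, k) space
  const int64_t chunk = (cols + num_threads - 1) / num_threads;
  std::vector<std::thread> workers;
  for (int64_t t = 0; t < num_threads; ++t) {
    const int64_t begin = t * chunk;
    const int64_t end = std::min(begin + chunk, cols);
    workers.emplace_back([&, begin, end] {
      for (int64_t c = begin; c < end; ++c) {  // columns owned by this thread only
        for (int64_t i = 0; i < I; ++i) {      // reduction over dim=0, in-thread
          const int64_t flat = i * cols + c;
          self[index[flat] * cols + c] += src[flat];
        }
      }
    });
  }
  for (auto& w : workers) w.join();
}
```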

CC v0dro
Pull Request resolved: pytorch#36181

Differential Revision: D21716167

Pulled By: ngimel

fbshipit-source-id: 49aee2de43779a1f0b359c22c8589c0702ee68a2
nikitaved authored and facebook-github-bot committed May 27, 2020
1 parent b636f5e commit 3006334
Showing 1 changed file with 6 additions and 16 deletions: `aten/src/ATen/native/cpu/ScatterGatherKernel.cpp`
```diff
@@ -44,8 +44,7 @@ struct cpu_scatter_gather_base_kernel {
     Tensor& self, int64_t dim,
     const Tensor& index, const Tensor& src,
     const std::string& method_name,
-    const func_t& f,
-    bool serial_exec = true
+    const func_t& f
   ) {
     // no-op if index is empty
     if (index.numel() == 0) {
@@ -143,12 +142,7 @@ struct cpu_scatter_gather_base_kernel {
 
       };
 
-      if (serial_exec) {
-        iter.serial_for_each(loop, {0, iter.numel()});
-      }
-      else {
-        iter.for_each(loop);
-      }
+      iter.for_each(loop);
     }
   );
 }
@@ -159,8 +153,7 @@ void gather_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim, const Te
     result, dim, index, self,
     "gather_out_cpu", [] (auto* lhs, const auto* rhs) {
       *lhs = *rhs;
-    },
-    /*serial_exec=*/false
+    }
   );
 }
@@ -169,8 +162,7 @@ void scatter_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, const Te
     self, dim, index, src,
     "scatter_cpu_", [] (auto* lhs, const auto* rhs) {
       *lhs = *rhs;
-    },
-    /*serial_exec=*/false
+    }
   );
 }
@@ -180,8 +172,7 @@ void scatter_fill_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, Sca
     "scatter_fill_cpu_", [src] (auto* lhs, const auto* rhs) {
       using scalar_t = typename std::remove_pointer<decltype(lhs)>::type;
       *lhs = src.to<scalar_t>();
-    },
-    /*serial_exec=*/false
+    }
   );
 }
@@ -190,8 +181,7 @@ void scatter_add_cpu_kernel(Tensor& self, int64_t dim, const Tensor& index, cons
     self, dim, index, src,
     "scatter_add_", [] (auto* lhs, const auto* rhs) {
       *lhs += *rhs;
-    },
-    /*serial_exec=*/true
+    }
   );
 }
```
