Commit
remove serial_exec from scatter/gather kernel (pytorch#36181)
Summary:
Since the indexed dimension in `scatter/gather` is traversed inside the kernel, the sets of memory locations written by different threads are mutually disjoint, so no two threads ever write to the same location.
See the comment on pytorch#33389 for a graphical explanation. More formal description:

Suppose we deal with 3D tensors and `dim=0`, hence the `scatter_add` operations are
```
self[index[i][j][k]][j][k] += src[i][j][k],
...
self[index[i'][j'][k']][j'][k'] += src[i'][j'][k'],
...
```
Clearly, a write/read to the same memory location happens if and only if:
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k'.
```
Since the reduction over `dim=0` happens inside the kernel, the threads partition `dim=1,2`. That is, any two distinct threads `t` and `t'` receive index sets
```
I = {(*, j, k) sent to the thread t},
I' = {(*, j', k') sent to the thread t'},
I intersection with I' = the empty set.
```
Therefore the conflict condition
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k',
```
holds if and only if there exists a single thread whose index set `K` contains both `(*,j,k)` and `(*,j',k')`. Conflicting updates are thus always performed by the same thread, one after another.

Therefore it is possible to make `scatter_add` parallel and remove `serial_exec` from the `scatter_gather_base_kernel`.

CC v0dro

Pull Request resolved: pytorch#36181

Differential Revision: D21716167

Pulled By: ngimel

fbshipit-source-id: 49aee2de43779a1f0b359c22c8589c0702ee68a2
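The disjointness argument can be checked on a small example. The following is a minimal pure-Python sketch, not the actual kernel: it assumes that one hypothetical work item handles each `(j, k)` pair and loops over `dim=0` itself. The real partitioning in `scatter_gather_base_kernel` may group several `(j, k)` pairs per thread. The helper names `scatter_add_dim0` and `write_sets_disjoint` are illustrative only.

```python
from itertools import product


def scatter_add_dim0(self_t, index, src):
    """Scatter-add along dim=0 on nested lists, in place.

    Each (j, k) pair stands in for one hypothetical thread (an assumption
    made for illustration). Returns the set of locations each one wrote.
    """
    n_i, n_j, n_k = len(index), len(index[0]), len(index[0][0])
    writes = {}
    for j, k in product(range(n_j), range(n_k)):  # one work item per (j, k)
        touched = set()
        for i in range(n_i):  # dim=0 is traversed inside the work item
            row = index[i][j][k]
            self_t[row][j][k] += src[i][j][k]
            touched.add((row, j, k))
        writes[(j, k)] = touched
    return writes


def write_sets_disjoint(writes):
    """Return True if no location was written by two different work items."""
    seen = set()
    for touched in writes.values():
        if seen & touched:
            return False
        seen |= touched
    return True


# index and src have shape (2, 2, 2); the output has shape (3, 2, 2).
index = [[[0, 1], [2, 0]],
         [[0, 0], [2, 1]]]
src = [[[1, 2], [3, 4]],
       [[5, 6], [7, 8]]]
out = [[[0, 0] for _ in range(2)] for _ in range(3)]

writes = scatter_add_dim0(out, index, src)
print(write_sets_disjoint(writes))
print(out)
```

Because every written location carries the `(j, k)` of the work item that produced it, two different work items can never touch the same element. Repeated indices along `dim=0`, such as the two writes to `out[0][0][0]` here, are accumulated sequentially inside a single work item.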