Adds support for accelerated sorting with x86-simd-sort (pytorch#127936)
Adds x86-simd-sort as a submodule to accelerate sorting of 32-bit and 64-bit data types when AVX2 or AVX512 is available.

For contiguous data, this can yield over a 10x speedup on large arrays; for discontiguous data, over a 4x speedup. The benchmarks below were gathered on a Skylake system (7900X), limited to 8 threads.
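
The commit does not include the benchmark harness itself. For readers who want to reproduce numbers of this shape, here is a minimal sketch; the array size, the strided view used to get a discontiguous sorted dimension, and the timing loop are illustrative assumptions, not the harness used for the tables below.

```
import timeit

import torch

torch.set_num_threads(8)  # mirror the 8-thread limit used for the tables below

n = 1_000_000
contiguous = torch.randn(n)              # float32, normally distributed
discontiguous = torch.randn(n, 2)[:, 0]  # strided view: discontiguous in the sorted dimension

for label, t in [("contiguous", contiguous), ("discontiguous", discontiguous)]:
    seconds = timeit.timeit(lambda: torch.sort(t), number=10) / 10
    print(f"{label}: {seconds * 1e6:.1f} us per sort")
```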

<details>
<summary><b>Contiguous Benchmarks</b></summary>

```
float32, normally distributed (in microseconds)
size           Default        AVX2           AVX512         Default/AVX2   Default/AVX512
16             7.150844336    6.886271477    7.132277489    1.038420335    1.002603214
128            9.208030939    8.478154898    7.846915245    1.086089019    1.173458697
1024           37.79037627    23.60707456    16.44122627    1.600807257    2.298513241
10000          714.7355628    203.9921844    105.5683001    3.503739934    6.770361577
100000         8383.074408    721.6333354    465.3709247    11.61680593    18.01374766
1000000        97124.31945    5632.054572    3920.148401    17.24491803    24.77567416
10000000       1161974.907    86070.48988    71533.82301    13.50027063    16.24371323

int32_t, uniformly distributed (in microseconds)
size           Default        AVX2           AVX512         Default/AVX2   Default/AVX512
16             7.203208685    6.92212224     7.014458179    1.040606975    1.026908779
128            8.972388983    8.195516348    7.592543125    1.094792396    1.18173698
1024           32.77489477    23.6874548     15.36617105    1.383639359    2.132925285
10000          607.8824128    193.3402024    99.25090471    3.144107667    6.124703997
100000         523.9384684    608.1836536    442.3166784    0.861480682    1.184532472
1000000        5211.348627    5271.598405    3518.861883    0.988570871    1.480975611
10000000       133853.6263    81463.05084    67852.97394    1.643120714    1.972700952
```

</details>

Note that for larger arrays the int32_t sort is already accelerated by FBGEMM's radix sort, but that path only handles contiguous data and sorts in a single direction.
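
As an illustration (not part of the commit), the calls below contrast the case the radix path covers with one it cannot; with this change the descending sort can be served by the vectorized kernel instead, subject to the dispatch conditions in SortingKernel.cpp further down.

```
import torch

t = torch.randint(0, 1 << 30, (1_000_000,), dtype=torch.int32)

# Ascending sort of contiguous int32 data: the case FBGEMM's radix sort covers.
ascending, _ = torch.sort(t)

# Descending sort: outside the radix path, since it sorts in one direction only;
# with this change it can be handled by the x86-simd-sort kernel instead.
descending, _ = torch.sort(t, descending=True)
```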

<details>
<summary><b>Discontiguous Benchmarks</b></summary>

```
float32, normally distributed, discontiguous in sorted dimension (in microseconds)
size           Default        AVX2           AVX512         Default/AVX2   Default/AVX512
16             3.836543679    4.011214256    3.84376061     0.956454439    0.99812243
128            5.755310194    5.755723127    4.820394962    0.999928257    1.193949923
1024           49.46946019    24.78790785    15.47874362    1.995709379    3.195960952
10000          665.2505291    236.6165959    143.9490662    2.811512551    4.621429974
100000         4328.002203    1329.001212    818.3516414    3.256582586    5.288682743
1000000        47651.5018     16693.72045    11827.39551    2.854456677    4.028909133
10000000       556655.1288    236252.6258    184215.9828    2.356185998    3.021752621

int32_t, uniformly distributed, discontiguous in sorted dimension (in microseconds)
size           Default        AVX2           AVX512         Default/AVX2   Default/AVX512
16             3.817994356    3.878117442    3.770039797    0.984496837    1.012719908
128            5.578731397    5.577152082    4.716770534    1.000283176    1.182743862
1024           43.3412619     23.61275801    14.55446819    1.835501887    2.977866408
10000          634.3997478    224.4322851    133.9518324    2.826686667    4.736028889
100000         4084.358152    1292.363303    781.7867576    3.16037924     5.22438902
1000000        46262.20465    16608.35284    11367.51817    2.785478192    4.06968381
10000000       541231.9104    235185.1861    180249.9294    2.301301028    3.002674742
```

</details>

Pull Request resolved: pytorch#127936
Approved by: https://github.com/jgong5, https://github.com/peterbell10
sterrettm2 authored and pytorchmergebot committed Sep 20, 2024
1 parent d2455b9 commit 239a9ad
Showing 8 changed files with 207 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -127,3 +127,6 @@
[submodule "third_party/NVTX"]
path = third_party/NVTX
url = https://github.com/NVIDIA/NVTX.git
[submodule "third_party/x86-simd-sort"]
path = third_party/x86-simd-sort
url = https://github.com/intel/x86-simd-sort.git
8 changes: 8 additions & 0 deletions CMakeLists.txt
@@ -262,6 +262,7 @@ else()
cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
endif()
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_X86_SIMD_SORT "Use x86-simd-sort to accelerate sorting and topk for AVX2/AVX512" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
@@ -907,6 +908,13 @@ if(USE_FBGEMM)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
endif()

if(USE_X86_SIMD_SORT)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_X86_SIMD_SORT")
if(USE_XSS_OPENMP)
string(APPEND CMAKE_CXX_FLAGS " -DXSS_USE_OPENMP")
endif()
endif()

if(USE_PYTORCH_QNNPACK)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
endif()
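
The new option defaults to ON. To build without it, the flag needs to reach CMake; the sketch below assumes PyTorch's usual USE_* environment-variable passthrough in setup.py and is illustrative rather than taken from the commit.

```
# Hypothetical: build PyTorch from source with the new path disabled.
# Assumes setup.py forwards USE_* environment variables to CMake.
import os
import subprocess

env = dict(os.environ, USE_X86_SIMD_SORT="0")
subprocess.run(["python", "setup.py", "develop"], env=env, check=True)
```
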
34 changes: 34 additions & 0 deletions NOTICE
@@ -454,3 +454,37 @@ and reference the following license:
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

=======================================================================
x86-simd-sort BSD 3-Clause License
=======================================================================

Code derived from implementations in x86-simd-sort should mention its
derivation and reference the following license:

Copyright (c) 2022, Intel. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
131 changes: 129 additions & 2 deletions aten/src/ATen/native/cpu/SortingKernel.cpp
@@ -15,11 +15,18 @@
#include <ATen/native/CompositeRandomAccessor.h>
#include <ATen/native/TopKImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>

#ifdef USE_FBGEMM
#include <fbgemm/Utils.h>
#endif

#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
#define XSS_COMPILE_TIME_SUPPORTED
#include <src/x86simdsort-static-incl.h>
#endif

namespace at::native {

namespace {
@@ -119,6 +126,7 @@ static void parallel_sort1d_kernel(
std::vector<int64_t> tmp_vals(elements);
const scalar_t* sorted_keys = nullptr;
const int64_t* sorted_vals = nullptr;

std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
keys,
vals,
@@ -167,6 +175,116 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
}
}

#if defined(XSS_COMPILE_TIME_SUPPORTED)

#define AT_DISPATCH_CASE_XSS_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))

static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
// xss_sort is not a stable sort
if (stable) return false;

auto type = values.scalar_type();
if (! (type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false;

return true;
}

static bool xss_sort_preferred(const TensorBase& values, const bool descending) {
#if defined(XSS_USE_OPENMP) || !defined(USE_FBGEMM)
return true;
#else
// Without OpenMP support for x86-simd-sort, fbgemm radix sort is faster when it can be used
return !can_use_radix_sort(values, descending);
#endif
}

static void xss_sort_kernel(
const TensorBase& values,
const TensorBase& indices,
int64_t dim,
bool descending) {
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.resize_outputs(false)
.declare_static_shape(values.sizes(), /*squash_dims=*/dim)
.add_output(values)
.add_output(indices)
.build();

using index_t = int64_t;

AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {

auto values_dim_stride = values.stride(dim);
auto indices_dim_stride = indices.stride(dim);
auto dim_size = values.size(dim);

auto loop = [&](char** data, const int64_t* strides, int64_t n) {
auto* values_data_bytes = data[0];
auto* indices_data_bytes = data[1];

if(values_data_bytes==nullptr || indices_data_bytes==nullptr){
return;
}

if (values_dim_stride == 1 && indices_dim_stride == 1){
for (const auto i C10_UNUSED : c10::irange(n)) {
x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
reinterpret_cast<scalar_t*>(values_data_bytes),
reinterpret_cast<index_t*>(indices_data_bytes),
dim_size,
true,
descending);

values_data_bytes += strides[0];
indices_data_bytes += strides[1];
}
}else{
c10::SmallBuffer<scalar_t, 0> tmp_values(dim_size);
c10::SmallBuffer<index_t, 0> tmp_indices(dim_size);

for (const auto i : c10::irange(n)) {
TensorAccessor<scalar_t, 1> mode_values_acc(
reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
&dim_size, &values_dim_stride);
TensorAccessor<index_t, 1> mode_indices_acc(
reinterpret_cast<index_t*>(data[1] + i * strides[1]),
&dim_size, &indices_dim_stride);

for (const auto j : c10::irange(dim_size)) {
tmp_values[j] = mode_values_acc[j];
tmp_indices[j] = j;
}

x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
tmp_values.data(),
tmp_indices.data(),
dim_size,
true,
descending);

for (const auto j : c10::irange(dim_size)) {
mode_values_acc[j] = tmp_values[j];
mode_indices_acc[j] = tmp_indices[j];
}
}
}
};

int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
iter.for_each(loop, /*grain_size=*/grain_size);

});
}
#endif

static void sort_kernel(
const TensorBase& self,
const TensorBase& values,
@@ -181,6 +299,14 @@ static void sort_kernel(
// https://github.com/pytorch/pytorch/issues/91420
return;
}

#if defined(XSS_COMPILE_TIME_SUPPORTED)
if (can_use_xss_sort(values, indices, dim, stable) && xss_sort_preferred(values, descending)){
xss_sort_kernel(values, indices, dim, descending);
return;
}
#endif

#ifdef USE_FBGEMM
if (can_use_radix_sort(values, descending)) {
parallel_sort1d_kernel(values, indices);
@@ -232,6 +358,7 @@ static void topk_kernel(
int64_t dim,
bool largest,
bool sorted) {

auto sizes = self.sizes();
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
@@ -266,7 +393,7 @@ static void topk_kernel(

} // anonymous namespace

REGISTER_DISPATCH(sort_stub, &sort_kernel);
REGISTER_DISPATCH(topk_stub, &topk_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel);
ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel);

} //at::native
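
Based on can_use_xss_sort and xss_sort_preferred above, the conditions can be summarized from the Python side roughly as follows. Whether a given call actually reaches the vectorized kernel also depends on the build (AVX2/AVX512 support, OpenMP, FBGEMM), so this is a sketch of the dispatch conditions in the diff, not a guarantee.

```
import torch

x = torch.randn(100_000)  # float32 is one of the four supported dtypes (int32/int64/float32/float64)

v, i = torch.sort(x)                   # unstable sort of a supported dtype: eligible for the xss kernel
v, i = torch.sort(x, descending=True)  # descending is also eligible
v, i = torch.sort(x, stable=True)      # stable requested: can_use_xss_sort() is false, falls back
v, i = torch.sort(x.to(torch.int16))   # int16 is not a supported dtype: falls back
```
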
22 changes: 22 additions & 0 deletions cmake/Dependencies.cmake
@@ -1328,6 +1328,28 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
endif()

# --[ x86-simd-sort integration
if(USE_X86_SIMD_SORT)
if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
message(WARNING
"x64 operating system is required for x86-simd-sort. "
"Not compiling with x86-simd-sort. "
"Turn this warning off by USE_X86_SIMD_SORT=OFF.")
set(USE_X86_SIMD_SORT OFF)
endif()

if(USE_X86_SIMD_SORT)
if(USE_OPENMP AND NOT MSVC)
set(USE_XSS_OPENMP ON)
else()
set(USE_XSS_OPENMP OFF)
endif()

set(XSS_SIMD_SORT_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/x86-simd-sort)
include_directories(SYSTEM ${XSS_SIMD_SORT_INCLUDE_DIR})
endif()
endif()

# --[ ATen checks
set(USE_LAPACK 0)

1 change: 1 addition & 0 deletions cmake/Summary.cmake
@@ -133,6 +133,7 @@ function(caffe2_print_configuration_summary)
endif()
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
message(STATUS " USE_X86_SIMD_SORT : ${USE_X86_SIMD_SORT}")
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}")
message(STATUS " USE_KINETO : ${USE_KINETO}")
9 changes: 9 additions & 0 deletions test/inductor/test_torchinductor_opinfo.py
@@ -466,6 +466,9 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0},
# High atol due to precision loss
("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0},
# reference_in_float can cause erroneous failures in sorting tests
"argsort": {"reference_in_float": False},
"sort": {"reference_in_float": False},
}

inductor_override_kwargs["cuda"] = {
@@ -536,6 +539,9 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
("index_reduce.amax", f32): {"check_gradient": False},
("index_reduce.amax", f16): {"check_gradient": False},
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
# reference_in_float can cause erroneous failures in sorting tests
"argsort": {"reference_in_float": False},
"sort": {"reference_in_float": False},
}

inductor_override_kwargs["xpu"] = {
@@ -655,6 +661,9 @@ def wrapper_noop_set_seed(op, *args, **kwargs):
("nn.functional.embedding_bag", f64): {"check_gradient": False},
("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3},
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3},
# reference_in_float can cause erroneous failures in sorting tests
"argsort": {"reference_in_float": False},
"sort": {"reference_in_float": False},
}

# Test with one sample only for following ops
1 change: 1 addition & 0 deletions third_party/x86-simd-sort
Submodule x86-simd-sort added at 9a1b61
