Extending the Pytorch vec backend for SVE (ARM) (pytorch#119571)
**Motivation:**
In PyTorch, ATen vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic Vector (Vec) type that lets the programmer write code packing various primitives (such as floats) into 256-bit and 512-bit registers, and it can easily be extended to support other ISAs by adding more VecISA sub-classes.
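
For context, here is a minimal, hypothetical sketch (the function name and loop structure are illustrative, not part of this PR) of how kernel code typically uses the Vec abstraction; the same source is compiled once per supported ISA:

```cpp
// Illustrative sketch only (not from this PR): element-wise add written
// against the generic Vec API. Vectorized<float> packs 8 or 16 floats
// depending on which CPU_CAPABILITY the translation unit is built for.
#include <ATen/cpu/vec/vec.h>

void add_arrays(const float* a, const float* b, float* out, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  int64_t i = 0;
  // Whole vectors.
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec va = Vec::loadu(a + i);
    Vec vb = Vec::loadu(b + i);
    (va + vb).store(out + i);
  }
  // Scalar tail for the remaining elements.
  for (; i < n; ++i) {
    out[i] = a[i] + b[i];
  }
}
```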

**Reference Link:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec

**This PR:**

* Our goal with this contribution is to add an SVE backend for Vec in the ATen CPU vectorization, which can benefit any Arm CPU that supports SVE.

* More about the SVE ISA for Arm: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions)

* We use the Arm C Language Extensions (ACLE) for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics) to accelerate various operators in the SVE backend for Vec; a minimal intrinsics sketch follows this list.

* Currently we add support only for the SVE ISA with a vector length of 256 bits (SVE256). In the future, we plan to extend SVE support to other vector lengths as well.
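
For illustration, a minimal standalone example of SVE ACLE intrinsics (vector-length-agnostic style shown here; this is not code from the PR — the Vec backend itself uses fixed 256-bit types via `__attribute__((arm_sve_vector_bits(...)))`, as in `sve_helper.h` below):

```cpp
// Illustrative, standalone SVE ACLE sketch (not code from this PR):
// predicated load/add/store over float32. svwhilelt builds a predicate
// covering the remaining lanes, so no scalar tail loop is needed.
#include <arm_sve.h>

void add_arrays_sve(const float* a, const float* b, float* out, int64_t n) {
  // svcntw() = number of 32-bit lanes per SVE vector (8 when VL = 256 bits).
  for (int64_t i = 0; i < n; i += static_cast<int64_t>(svcntw())) {
    svbool_t pg = svwhilelt_b32_s64(i, n);
    svfloat32_t va = svld1_f32(pg, a + i);
    svfloat32_t vb = svld1_f32(pg, b + i);
    svst1_f32(pg, out + i, svadd_f32_x(pg, va, vb));
  }
}
```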

Pull Request resolved: pytorch#119571
Approved by: https://github.com/malfet, https://github.com/snadampal

Co-authored-by: Divya Kotadiya <[email protected]>
2 people authored and pytorchmergebot committed Sep 18, 2024
1 parent bad6904 commit 5a6ddbc
Showing 29 changed files with 2,554 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/CMakeLists.txt
@@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
endif()
EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})

file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")
5 changes: 5 additions & 0 deletions aten/src/ATen/Version.cpp
@@ -105,6 +105,11 @@ std::string get_cpu_capability() {
return "DEFAULT";
case native::CPUCapability::ZVECTOR:
return "Z VECTOR";
#elif defined(HAVE_SVE_CPU_DEFINITION)
case native::CPUCapability::DEFAULT:
return "DEFAULT";
case native::CPUCapability::SVE256:
return "SVE256";
#else
case native::CPUCapability::DEFAULT:
return "NO AVX";
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/functional_base.h
@@ -78,7 +78,7 @@ struct VecReduceAllSIMD<float, Op> {
#endif // defined(CPU_CAPABILITY_AVX512)
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)

#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
8 changes: 8 additions & 0 deletions aten/src/ATen/cpu/vec/intrinsics.h
@@ -5,6 +5,10 @@
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* Clang-compatible compiler, targeting arm neon */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* CLANG-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#elif defined(_MSC_VER)
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
@@ -17,6 +21,10 @@
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
/* GCC-compatible compiler, targeting ARM with NEON */
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
/* GCC-compatible compiler, targeting ARM with SVE */
#include <arm_sve.h>
#endif
#if defined (MISSING_ARM_VLD1)
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
#elif defined (MISSING_ARM_VST1)
63 changes: 63 additions & 0 deletions aten/src/ATen/cpu/vec/sve/sve_helper.h
@@ -0,0 +1,63 @@
#pragma once

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_SVE)

// Define the data types for VLS (vector-length specific) SVE vectors.
typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));

#define ptrue svptrue_b8()
#define ZERO_S8 svdup_n_s8(0)
#define ZERO_S16 svdup_n_s16(0)
#define ZERO_S32 svdup_n_s32(0)
#define ZERO_S64 svdup_n_s64(0)
#define ZERO_U8 svdup_n_u8(0)
#define ZERO_U16 svdup_n_u16(0)
#define ZERO_U32 svdup_n_u32(0)
#define ZERO_U64 svdup_n_u64(0)
#define ZERO_F16 svdup_n_f16(0.f)
#define ZERO_F32 svdup_n_f32(0.f)
#define ZERO_F64 svdup_n_f64(0.0)
#define ONE_S8 svdup_n_s8(1)
#define ONE_S16 svdup_n_s16(1)
#define ONE_S32 svdup_n_s32(1)
#define ONE_S64 svdup_n_s64(1)
#define ONE_U8 svdup_n_u8(1)
#define ONE_U16 svdup_n_u16(1)
#define ONE_U32 svdup_n_u32(1)
#define ONE_U64 svdup_n_u64(1)
#define ONE_F16 svdup_n_f16(1.f)
#define ONE_F32 svdup_n_f32(1.f)
#define ONE_F64 svdup_n_f64(1.0)
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
#define ALL_S8_FALSE_MASK svdup_n_s8(0x0)
#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff)
#define ALL_S16_FALSE_MASK svdup_n_s16(0x0)
#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff)
#define ALL_S32_FALSE_MASK svdup_n_s32(0x0)
#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff)
#define ALL_S64_FALSE_MASK svdup_n_s64(0x0)
#define ALL_U8_TRUE_MASK svdup_n_u8(0x01)
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)

#endif // defined(CPU_CAPABILITY_SVE)
176 changes: 176 additions & 0 deletions aten/src/ATen/cpu/vec/sve/vec_common_sve.h
@@ -0,0 +1,176 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with SVE]

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>

#if defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#endif

namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
return svreinterpret_f32_f64(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
return svreinterpret_f64_f32(src);
}

#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \
template<> \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
return svreinterpret_s##int_bit##_f##float_bit(src); \
} \
template<> \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
return svreinterpret_f##float_bit##_s##int_bit(src); \
}

DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) {
svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
return svld1_gather_s64index_f64(ptrue, base_addr, vindex);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) {
svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
return svld1_gather_s32index_f32(ptrue, base_addr, vindex);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) {
svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
ALL_S64_TRUE_MASK);
svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) {
svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
ALL_S32_TRUE_MASK);
svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000));
return svsub_s64_x(ptrue,
svreinterpret_s64_f64(x),
svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000)));
}

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
return svcvt_s32_f32_x(ptrue, src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, a1, a2, a3}
// b = {b0, b1, b2, b3}
// group cols crossing lanes:
// return {a0, b0, a1, b1}
// {a2, b2, a3, b3}
return std::make_pair(Vectorized<double>(svzip1_f64(a, b)),
Vectorized<double>(svzip2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, a1, a2, a3, a4, a5, a6, a7}
// b = {b0, b1, b2, b3, b4, b5, b6, b7}
// group cols crossing lanes:
// return {a0, b0, a1, b1, a2, b2, a3, b3}
// {a4, b4, a5, b5, a6, b6, a7, b7}
return std::make_pair(Vectorized<float>(svzip1_f32(a, b)),
Vectorized<float>(svzip2_f32(a, b)));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
// inputs:
// a = {a0, b0, a1, b1}
// b = {a2, b2, a3, b3}
// swap lanes:
// return {a0, a1, a2, a3}
// {b0, b1, b2, b3}
return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)),
Vectorized<double>(svuzp2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
// inputs:
// a = {a0, b0, a1, b1, a2, b2, a3, b3}
// b = {a4, b4, a5, b5, a6, b6, a7, b7}
// swap lanes:
// return {a0, a1, a2, a3, a4, a5, a6, a7}
// {b0, b1, b2, b3, b4, b5, b6, b7}
return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)),
Vectorized<float>(svuzp2_f32(a, b)));
}

#endif // defined(CPU_CAPABILITY_SVE)

}}}