Extending the Pytorch vec backend for SVE (ARM) (pytorch#119571)
**Motivation:** In PyTorch, ATen vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic implementation of a Vector (Vec) type that lets the programmer write code packing primitives (such as floats) into 256-bit and 512-bit registers, and it can easily be extended to other ISAs by adding more VecISA subclasses.

**Reference:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec

**This PR:**
* Adds an SVE backend for Vec in the ATen CPU vectorization, benefiting any Arm CPU that supports SVE.
* More about the SVE ISA for Arm: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions)
* Uses the Arm C Language Extensions for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics) to accelerate various operators in the SVE backend for Vec.
* Currently adds support only for the SVE ISA with a vector length of 256 bits (SVE 256); support for other vector lengths is planned as future work.

Pull Request resolved: pytorch#119571
Approved by: https://github.com/malfet, https://github.com/snadampal
Co-authored-by: Divya Kotadiya <[email protected]>
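For illustration (not part of this commit), here is a minimal sketch of how ATen code uses the Vec abstraction described above, assuming the public at::vec::Vectorized API; the same loop compiles to NEON, AVX2, AVX-512 or, with this PR, SVE-256 code depending on the CPU_CAPABILITY flags:

#include <ATen/cpu/vec/vec.h>

// Hypothetical example: element-wise add of two float buffers via Vectorized.
void add_buffers(const float* a, const float* b, float* out, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec va = Vec::loadu(a + i);  // unaligned vector load
    Vec vb = Vec::loadu(b + i);
    (va + vb).store(out + i);    // lane-wise add, then store
  }
  for (; i < n; ++i) {           // scalar tail
    out[i] = a[i] + b[i];
  }
}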
1 parent bad6904 · commit 5a6ddbc · 29 changed files with 2,554 additions and 9 deletions
aten/src/ATen/cpu/vec/sve/sve_helper.h
@@ -0,0 +1,63 @@
#pragma once

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_SVE)

// Define the data types of VLS (vector-length specific) SVE vectors.
// VECTOR_WIDTH is given in bytes, so the attribute takes VECTOR_WIDTH * 8 bits.
typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));

// All-true predicate and broadcast constants for each lane type.
#define ptrue svptrue_b8()
#define ZERO_S8 svdup_n_s8(0)
#define ZERO_S16 svdup_n_s16(0)
#define ZERO_S32 svdup_n_s32(0)
#define ZERO_S64 svdup_n_s64(0)
#define ZERO_U8 svdup_n_u8(0)
#define ZERO_U16 svdup_n_u16(0)
#define ZERO_U32 svdup_n_u32(0)
#define ZERO_U64 svdup_n_u64(0)
#define ZERO_F16 svdup_n_f16(0.f)
#define ZERO_F32 svdup_n_f32(0.f)
#define ZERO_F64 svdup_n_f64(0.0)
#define ONE_S8 svdup_n_s8(1)
#define ONE_S16 svdup_n_s16(1)
#define ONE_S32 svdup_n_s32(1)
#define ONE_S64 svdup_n_s64(1)
#define ONE_U8 svdup_n_u8(1)
#define ONE_U16 svdup_n_u16(1)
#define ONE_U32 svdup_n_u32(1)
#define ONE_U64 svdup_n_u64(1)
#define ONE_F16 svdup_n_f16(1.f)
#define ONE_F32 svdup_n_f32(1.f)
#define ONE_F64 svdup_n_f64(1.0)
// Data-register boolean masks: "true" lanes are all-ones, "false" lanes are
// all-zeros (the u8 mask uses 0x01).
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
#define ALL_S8_FALSE_MASK svdup_n_s8(0x0)
#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff)
#define ALL_S16_FALSE_MASK svdup_n_s16(0x0)
#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff)
#define ALL_S32_FALSE_MASK svdup_n_s32(0x0)
#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff)
#define ALL_S64_FALSE_MASK svdup_n_s64(0x0)
#define ALL_U8_TRUE_MASK svdup_n_u8(0x01)
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)

#endif // defined(CPU_CAPABILITY_SVE)
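A hedged sketch (not from the diff) of how the TRUE/FALSE mask macros are typically used: SVE comparisons produce an svbool_t predicate, while the Vectorized API represents booleans as ordinary data vectors, so the predicate is expanded with svsel:

// Hypothetical helper, assuming the macros above are in scope.
svfloat32_t eq_mask(svfloat32_t a, svfloat32_t b) {
  svbool_t pg = svcmpeq_f32(ptrue, a, b);  // lane-wise a == b -> predicate
  return svsel_f32(pg, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK);  // predicate -> data mask
}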
aten/src/ATen/cpu/vec/sve/vec_common_sve.h
@@ -0,0 +1,176 @@
#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with SVE]

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>

#if defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#endif

namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  return svreinterpret_f32_f64(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  return svreinterpret_f64_f32(src);
}

#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit)                \
template<>                                                                       \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) {  \
  return svreinterpret_s##int_bit##_f##float_bit(src);                           \
}                                                                                \
template<>                                                                       \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) {  \
  return svreinterpret_f##float_bit##_s##int_bit(src);                           \
}

DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
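Note that cast<> here is a pure bit-pattern reinterpret (svreinterpret), not a value conversion; a small hypothetical sketch assuming the specializations above:

// The returned vector holds the same bits, just viewed as float lanes;
// value conversion is handled by convert_to_int_of_same_size below.
at::vec::Vectorized<float> bitcast_demo(const at::vec::Vectorized<double>& d) {
  return at::vec::cast<float, double>(d);
}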
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) {
  // `scale` multiplies the indices into byte offsets; shifting right by 3
  // (dividing by sizeof(double)) turns them back into element indices.
  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
  return svld1_gather_s64index_f64(ptrue, base_addr, vindex);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) {
  // Same as above: byte offsets >> 2 become element indices for 4-byte floats.
  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
  return svld1_gather_s32index_f32(ptrue, base_addr, vindex);
}
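For reference, a hedged scalar equivalent of the gathers above: scale is a byte multiplier (mirroring the x86 gather intrinsics), so the byte offsets are shifted back down to element indices before the indexed SVE load:

// Illustrative only; names are hypothetical.
void gather_scalar(const double* base_addr, const int64_t* index,
                   double* out, int n, int64_t scale) {
  for (int i = 0; i < n; ++i) {
    out[i] = base_addr[(index[i] * scale) >> 3];  // byte offset -> double index
  }
}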
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) {
  // Turn the all-ones/all-zeros data mask into an SVE predicate.
  svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
                              ALL_S64_TRUE_MASK);
  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
  return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) {
  svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
                              ALL_S32_TRUE_MASK);
  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
  return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  // 0x0018000000000000 as a double value is 2^52 + 2^51; adding it pushes the
  // integer part of src into the low mantissa bits, and subtracting the
  // constant's bit pattern as int64 recovers the integer.
  svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000));
  return svsub_s64_x(ptrue,
                     svreinterpret_s64_f64(x),
                     svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000)));
}

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return svcvt_s32_f32_x(ptrue, src);
}
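The double path relies on the classic magic-number trick from the linked Stack Overflow answer: the integer literal 0x0018000000000000 equals 2^52 + 2^51 = 6755399441055744 when converted to a double value. A hedged scalar version:

#include <cstdint>
#include <cstring>

// Valid for |x| <= 2^51, matching the range note above.
int64_t double_to_int64(double x) {
  const double magic = 0x0018000000000000;  // 6755399441055744.0 == 2^52 + 2^51
  double shifted = x + magic;               // x now sits in the low mantissa bits
  int64_t bits, magic_bits;
  std::memcpy(&bits, &shifted, sizeof(bits));
  std::memcpy(&magic_bits, &magic, sizeof(magic_bits));
  return bits - magic_bits;                 // strip the magic bit pattern
}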
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(Vectorized<double>(svzip1_f64(a, b)),
                        Vectorized<double>(svzip2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  return std::make_pair(Vectorized<float>(svzip1_f32(a, b)),
                        Vectorized<float>(svzip2_f32(a, b)));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}
  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)),
                        Vectorized<double>(svuzp2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)),
                        Vectorized<float>(svuzp2_f32(a, b)));
}
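As a hedged usage note, deinterleave2 is the inverse of interleave2, converting between SoA ({a...}, {b...}) and AoS ({a0, b0, a1, b1, ...}) lane orders; a hypothetical round trip:

#include <ATen/cpu/vec/vec.h>

void round_trip(const at::vec::Vectorized<float>& a,
                const at::vec::Vectorized<float>& b) {
  auto [lo, hi] = at::vec::interleave2(a, b);      // {a0,b0,...}, {a4,b4,...}
  auto [a2, b2] = at::vec::deinterleave2(lo, hi);  // recovers a and b
}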
#endif // defined(CPU_CAPABILITY_SVE)

}}}