MAINT: Move AVX512 fp16 universal intrinsic to dispatch file
r-devulap committed Sep 26, 2022
1 parent 80f0015 commit a13006a
Showing 4 changed files with 19 additions and 39 deletions.
6 changes: 0 additions & 6 deletions numpy/core/src/common/simd/avx512/avx512.h
@@ -19,11 +19,6 @@ typedef __m512i npyv_u32;
typedef __m512i npyv_s32;
typedef __m512i npyv_u64;
typedef __m512i npyv_s64;
-
-#ifdef NPY_HAVE_AVX512_SKX
-typedef __m256i npyvh_f16;
-#endif
-
typedef __m512 npyv_f32;
typedef __m512d npyv_f64;

@@ -70,7 +65,6 @@ typedef struct { __m512d val[3]; } npyv_f64x3;
#define npyv_nlanes_s32 16
#define npyv_nlanes_u64 8
#define npyv_nlanes_s64 8
-#define npyv_nlanes_f16 32
#define npyv_nlanes_f32 16
#define npyv_nlanes_f64 8

4 changes: 0 additions & 4 deletions numpy/core/src/common/simd/avx512/conversion.h
@@ -25,10 +25,6 @@
#endif
#define npyv_cvt_s32_b32 npyv_cvt_u32_b32
#define npyv_cvt_s64_b64 npyv_cvt_u64_b64
-
-#define npyv_cvt_f16_f32 _mm512_cvtph_ps
-#define npyv_cvt_f32_f16 _mm512_cvtps_ph
-
#define npyv_cvt_f32_b32(BL) _mm512_castsi512_ps(npyv_cvt_u32_b32(BL))
#define npyv_cvt_f64_b64(BL) _mm512_castsi512_pd(npyv_cvt_u64_b64(BL))

29 changes: 0 additions & 29 deletions numpy/core/src/common/simd/avx512/memory.h
@@ -59,11 +59,6 @@ NPYV_IMPL_AVX512_MEM_INT(npy_uint64, u64)
NPYV_IMPL_AVX512_MEM_INT(npy_int64, s64)

// unaligned load
-
-#ifdef NPY_HAVE_AVX512_SKX
-#define npyvh_load_f16(PTR) _mm256_loadu_si256((const __m256i*)(PTR))
-#endif
-
#define npyv_load_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
#define npyv_load_f64(PTR) _mm512_loadu_pd((const __m512d*)(PTR))
// aligned load
@@ -81,11 +76,6 @@ NPYV_IMPL_AVX512_MEM_INT(npy_int64, s64)
#define npyv_loads_f32(PTR) _mm512_castsi512_ps(npyv__loads(PTR))
#define npyv_loads_f64(PTR) _mm512_castsi512_pd(npyv__loads(PTR))
// unaligned store
-
-#ifdef NPY_HAVE_AVX512_SKX
-#define npyvh_store_f16(PTR, data) _mm256_storeu_si256((__m256i*)PTR, data)
-#endif
-
#define npyv_store_f32 _mm512_storeu_ps
#define npyv_store_f64 _mm512_storeu_pd
// aligned store
@@ -164,17 +154,6 @@ NPY_FINLINE void npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a)
/*********************************
* Partial Load
*********************************/
-// 16
-#ifdef NPY_HAVE_AVX512_SKX
-NPY_FINLINE npyvh_f16 npyvh_load_till_f16(const npy_half *ptr, npy_uintp nlane, npy_half fill)
-{
-    assert(nlane > 0);
-    const __m256i vfill = _mm256_set1_epi16(fill);
-    const __mmask16 mask = (0x0001 << nlane) - 0x0001;
-    return _mm256_mask_loadu_epi16(vfill, mask, ptr);
-}
-#endif
-
//// 32
NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill)
{
@@ -246,14 +225,6 @@ npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane)
/*********************************
* Partial store
*********************************/
-#ifdef NPY_HAVE_AVX512_SKX
-NPY_FINLINE void npyvh_store_till_f16(npy_half *ptr, npy_uintp nlane, npyvh_f16 data)
-{
-    assert(nlane > 0);
-    const __mmask16 mask = (0x0001 << nlane) - 0x0001;
-    _mm256_mask_storeu_epi16(ptr, mask, data);
-}
-#endif
//// 32
NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a)
{
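For reference, the two *_till_f16 helpers removed here (and re-added in the dispatch file below) build an AVX-512 lane mask directly from the element count: (0x0001 << nlane) - 0x0001 sets the low nlane bits, so _mm256_mask_loadu_epi16 / _mm256_mask_storeu_epi16 touch only the first nlane half-precision lanes, with the unread lanes taken from the fill vector. A minimal standalone illustration of the mask arithmetic (not part of this commit):

#include <stdio.h>

int main(void)
{
    /* Same expression as in npyvh_load_till_f16 / npyvh_store_till_f16:
     * set the low `nlane` bits of a 16-bit lane mask. */
    for (unsigned nlane = 1; nlane <= 16; ++nlane) {
        unsigned mask = (0x0001u << nlane) - 0x0001u;
        printf("nlane = %2u  ->  mask = 0x%04x\n", nlane, mask);
    }
    return 0;
}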
19 changes: 19 additions & 0 deletions numpy/core/src/umath/loops_umath_fp.dispatch.c.src
@@ -118,6 +118,25 @@ simd_@func@_@sfx@(const npyv_lanetype_@sfx@ *src1, npy_intp ssrc1,
#endif

#if NPY_SIMD && defined(NPY_HAVE_AVX512_SKX) && defined(NPY_CAN_LINK_SVML)
+typedef __m256i npyvh_f16;
+#define npyv_cvt_f16_f32 _mm512_cvtph_ps
+#define npyv_cvt_f32_f16 _mm512_cvtps_ph
+#define npyvh_load_f16(PTR) _mm256_loadu_si256((const __m256i*)(PTR))
+#define npyvh_store_f16(PTR, data) _mm256_storeu_si256((__m256i*)PTR, data)
+NPY_FINLINE npyvh_f16 npyvh_load_till_f16(const npy_half *ptr, npy_uintp nlane, npy_half fill)
+{
+    assert(nlane > 0);
+    const __m256i vfill = _mm256_set1_epi16(fill);
+    const __mmask16 mask = (0x0001 << nlane) - 0x0001;
+    return _mm256_mask_loadu_epi16(vfill, mask, ptr);
+}
+NPY_FINLINE void npyvh_store_till_f16(npy_half *ptr, npy_uintp nlane, npyvh_f16 data)
+{
+    assert(nlane > 0);
+    const __mmask16 mask = (0x0001 << nlane) - 0x0001;
+    _mm256_mask_storeu_epi16(ptr, mask, data);
+}
+
/**begin repeat
* #func = sin, cos, tan, exp, exp2, expm1, log, log2, log10, log1p, cbrt, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh#
* #default_val = 0, 0, 0, 0, 0, 0x3c00, 0x3c00, 0x3c00, 0x3c00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3c00, 0#
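With this commit the fp16 helpers live next to their only consumer: the SVML-backed half-precision loops in this dispatch file. A rough sketch of how they are typically used there, shown for sin only; the names avx512_sin_f16_sketch and __svml_sinf16 (Intel SVML's 16-lane float32 sine) are assumptions for illustration, and the real loop is generated from the repeat block above:

/* Illustrative sketch only -- not the literal template expansion. */
extern __m512 __svml_sinf16(__m512);

static void
avx512_sin_f16_sketch(const npy_half *src, npy_half *dst, npy_intp len)
{
    const int num_lanes = npyv_nlanes_f32;  /* 16 float32 lanes per __m512 */
    for (; len > 0; len -= num_lanes, src += num_lanes, dst += num_lanes) {
        npyvh_f16 x;
        if (len >= num_lanes) {
            x = npyvh_load_f16(src);              /* 16 x f16 in a __m256i */
        }
        else {
            x = npyvh_load_till_f16(src, len, 0); /* masked tail, fill = 0 */
        }
        npyv_f32 x_ps = npyv_cvt_f16_f32(x);       /* widen to 16 x f32    */
        npyv_f32 r_ps = __svml_sinf16(x_ps);       /* vector sine via SVML */
        npyvh_f16 r   = npyv_cvt_f32_f16(r_ps, 0); /* narrow back to f16   */
        if (len >= num_lanes) {
            npyvh_store_f16(dst, r);
        }
        else {
            npyvh_store_till_f16(dst, len, r);     /* masked partial store */
        }
    }
    npyv_cleanup();
}

The #default_val# row in the repeat block above supplies the per-function tail fill value: 0x3c00 is 1.0 in IEEE half precision, presumably chosen for functions such as log and acosh where a zero in the unused lanes would otherwise raise spurious floating-point exceptions.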
