Skip to content

Commit

Permalink
Merge pull request google#2249 from johnplatts:hwy_averageround_061824
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644636107
  • Loading branch information
copybara-github committed Jun 19, 2024
2 parents 4b70caa + 6d851a7 commit 79b2846
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 27 deletions.
4 changes: 2 additions & 2 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ from left to right, of the arguments passed to `Create{2-4}`.
<code>V **SaturatedSub**(V a, V b)</code> returns `a[i] - b[i]` saturated to
the minimum/maximum representable value.

* `V`: `{u}{8,16}` \
<code>V **AverageRound**(V a, V b)</code> returns `(a[i] + b[i] + 1) / 2`.
* `V`: `{u,i}` \
<code>V **AverageRound**(V a, V b)</code> returns `(a[i] + b[i] + 1) >> 1`.

* <code>V **Clamp**(V a, V lo, V hi)</code>: returns `a[i]` clamped to
`[lo[i], hi[i]]`.
Expand Down
10 changes: 8 additions & 2 deletions hwy/ops/arm_neon-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2074,8 +2074,14 @@ HWY_NEON_DEF_FUNCTION_INTS_UINTS(SaturatedSub, vqsub, _, 2)
// ------------------------------ Average

// Returns (a + b + 1) / 2
HWY_NEON_DEF_FUNCTION_UINT_8(AverageRound, vrhadd, _, 2)
HWY_NEON_DEF_FUNCTION_UINT_16(AverageRound, vrhadd, _, 2)

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

HWY_NEON_DEF_FUNCTION_UI_8_16_32(AverageRound, vrhadd, _, 2)

// ------------------------------ Neg

Expand Down
22 changes: 17 additions & 5 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4702,13 +4702,25 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {

// ------------------------------ AverageRound (ShiftRight)

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_U08(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
HWY_SVE_FOREACH_U16(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, AverageRound, rhadd)
#else
template <class V>
V AverageRound(const V a, const V b) {
return ShiftRight<1>(detail::AddN(Add(a, b), 1));
template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V AverageRound(const V a, const V b) {
return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)),
detail::AndN(Or(a, b), 1));
}
#endif // HWY_SVE_HAVE_2

Expand Down
21 changes: 18 additions & 3 deletions hwy/ops/emu128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,26 @@ HWY_API Vec128<T, N> SaturatedSub(Vec128<T, N> a, Vec128<T, N> b) {
}

// ------------------------------ AverageRound
template <typename T, size_t N>

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

template <typename T, size_t N, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec128<T, N> AverageRound(Vec128<T, N> a, Vec128<T, N> b) {
static_assert(!IsSigned<T>(), "Only for unsigned");
for (size_t i = 0; i < N; ++i) {
a.raw[i] = static_cast<T>((a.raw[i] + b.raw[i] + 1) / 2);
const T a_val = a.raw[i];
const T b_val = b.raw[i];
a.raw[i] = static_cast<T>(ScalarShr(a_val, 1) + ScalarShr(b_val, 1) +
((a_val | b_val) & 1));
}
return a;
}
Expand Down
38 changes: 38 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4655,6 +4655,44 @@ HWY_API Vec512<T> operator%(Vec512<T> a, Vec512<T> b) {

#endif // HWY_NATIVE_INT_DIV

// ------------------------------ AverageRound

#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI32) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

template <class V, HWY_IF_UI32(TFromV<V>)>
HWY_API V AverageRound(V a, V b) {
using T = TFromV<V>;
const DFromV<decltype(a)> d;
return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)),
And(Or(a, b), Set(d, T{1})));
}

#endif // HWY_NATIVE_AVERAGE_ROUND_UI64

#if (defined(HWY_NATIVE_AVERAGE_ROUND_UI64) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

#if HWY_HAVE_INTEGER64
template <class V, HWY_IF_UI64(TFromV<V>)>
HWY_API V AverageRound(V a, V b) {
using T = TFromV<V>;
const DFromV<decltype(a)> d;
return Add(Add(ShiftRight<1>(a), ShiftRight<1>(b)),
And(Or(a, b), Set(d, T{1})));
}
#endif

#endif // HWY_NATIVE_AVERAGE_ROUND_UI64

// ------------------------------ MulEvenAdd (PromoteEvenTo)

// SVE with bf16 and NEON with bf16 override this.
Expand Down
25 changes: 23 additions & 2 deletions hwy/ops/ppc_vsx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1554,12 +1554,33 @@ HWY_API V SaturatedSub(V a, V b) {

// Returns (a + b + 1) / 2

template <typename T, size_t N, HWY_IF_UNSIGNED(T),
HWY_IF_T_SIZE_ONE_OF(T, 0x6)>
#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

#if HWY_S390X_HAVE_Z14
#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

#define HWY_PPC_IF_AVERAGE_ROUND_T(T) void* = nullptr
#else // !HWY_S390X_HAVE_Z14
#define HWY_PPC_IF_AVERAGE_ROUND_T(T) \
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))
#endif // HWY_S390X_HAVE_Z14

template <typename T, size_t N, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_PPC_IF_AVERAGE_ROUND_T(T)>
HWY_API Vec128<T, N> AverageRound(Vec128<T, N> a, Vec128<T, N> b) {
return Vec128<T, N>{vec_avg(a.raw, b.raw)};
}

#undef HWY_PPC_IF_AVERAGE_ROUND_T

// ------------------------------ Multiplication

// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*.
Expand Down
16 changes: 14 additions & 2 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,18 @@ HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL)

// ------------------------------ AverageRound

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

// Define this to opt-out of the default behavior, which is AVOID on certain
// compiler versions. You can define only this to use VXRM, or define both this
// and HWY_RVV_AVOID_VXRM to always avoid VXRM.
Expand Down Expand Up @@ -1153,8 +1165,8 @@ HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL)
a, b, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RNU, HWY_RVV_AVL(SEW, SHIFT))); \
}

HWY_RVV_FOREACH_U08(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL)
HWY_RVV_FOREACH_U16(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL)
HWY_RVV_FOREACH_I(HWY_RVV_RETV_AVERAGE, AverageRound, aadd, _ALL)
HWY_RVV_FOREACH_U(HWY_RVV_RETV_AVERAGE, AverageRound, aaddu, _ALL)

#undef HWY_RVV_RETV_AVERAGE

Expand Down
25 changes: 18 additions & 7 deletions hwy/ops/scalar-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,13 +607,24 @@ HWY_API Vec1<int16_t> SaturatedSub(const Vec1<int16_t> a,

// Returns (a + b + 1) / 2

HWY_API Vec1<uint8_t> AverageRound(const Vec1<uint8_t> a,
const Vec1<uint8_t> b) {
return Vec1<uint8_t>(static_cast<uint8_t>((a.raw + b.raw + 1) / 2));
}
HWY_API Vec1<uint16_t> AverageRound(const Vec1<uint16_t> a,
const Vec1<uint16_t> b) {
return Vec1<uint16_t>(static_cast<uint16_t>((a.raw + b.raw + 1) / 2));
#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
#undef HWY_NATIVE_AVERAGE_ROUND_UI32
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI32
#endif

#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
#undef HWY_NATIVE_AVERAGE_ROUND_UI64
#else
#define HWY_NATIVE_AVERAGE_ROUND_UI64
#endif

template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec1<T> AverageRound(const Vec1<T> a, const Vec1<T> b) {
const T a_val = a.raw;
const T b_val = b.raw;
return Vec1<T>(static_cast<T>(ScalarShr(a_val, 1) + ScalarShr(b_val, 1) +
((a_val | b_val) & 1)));
}

// ------------------------------ Absolute value
Expand Down
11 changes: 11 additions & 0 deletions hwy/ops/wasm_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,17 @@ HWY_API Vec128<uint16_t, N> AverageRound(const Vec128<uint16_t, N> a,
return Vec128<uint16_t, N>{wasm_u16x8_avgr(a.raw, b.raw)};
}

template <class V, HWY_IF_SIGNED_V(V),
HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))>
HWY_API V AverageRound(V a, V b) {
const DFromV<decltype(a)> d;
const RebindToUnsigned<decltype(d)> du;
const V sign_bit = SignBit(d);
return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)),
BitCast(du, Xor(b, sign_bit)))),
sign_bit);
}

// ------------------------------ Absolute value

// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
Expand Down
3 changes: 2 additions & 1 deletion hwy/ops/wasm_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ HWY_API Vec256<T> SaturatedSub(Vec256<T> a, const Vec256<T> b) {
return a;
}

template <typename T>
template <typename T, HWY_IF_UNSIGNED(T),
HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API Vec256<T> AverageRound(Vec256<T> a, const Vec256<T> b) {
a.v0 = AverageRound(a.v0, b.v0);
a.v1 = AverageRound(a.v1, b.v1);
Expand Down
12 changes: 12 additions & 0 deletions hwy/ops/x86_128-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4226,6 +4226,18 @@ HWY_API Vec128<uint16_t, N> AverageRound(const Vec128<uint16_t, N> a,
return Vec128<uint16_t, N>{_mm_avg_epu16(a.raw, b.raw)};
}

// I8/I16 AverageRound is generic for all vector lengths
template <class V, HWY_IF_SIGNED_V(V),
HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2))>
HWY_API V AverageRound(V a, V b) {
const DFromV<decltype(a)> d;
const RebindToUnsigned<decltype(d)> du;
const V sign_bit = SignBit(d);
return Xor(BitCast(d, AverageRound(BitCast(du, Xor(a, sign_bit)),
BitCast(du, Xor(b, sign_bit)))),
sign_bit);
}

// ------------------------------ Integer multiplication

template <size_t N>
Expand Down
40 changes: 37 additions & 3 deletions hwy/tests/arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,56 @@ HWY_NOINLINE void TestAllAddSub() {
struct TestAverage {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
using TI = MakeSigned<T>;

const RebindToSigned<decltype(d)> di;
const RebindToUnsigned<decltype(d)> du;

const Vec<D> v0 = Zero(d);
const Vec<D> v1 = Set(d, static_cast<T>(1));
const Vec<D> v2 = Set(d, static_cast<T>(2));

const Vec<D> vn1 = Set(d, static_cast<T>(-1));
const Vec<D> vn2 = Set(d, static_cast<T>(-2));
const Vec<D> vn3 = Set(d, static_cast<T>(-3));
const Vec<D> vn4 = Set(d, static_cast<T>(-4));

HWY_ASSERT_VEC_EQ(d, v0, AverageRound(v0, v0));
HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v0, v1));
HWY_ASSERT_VEC_EQ(d, v1, AverageRound(v1, v1));
HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v1, v2));
HWY_ASSERT_VEC_EQ(d, v2, AverageRound(v2, v2));

HWY_ASSERT_VEC_EQ(d, vn1, AverageRound(vn1, vn1));
HWY_ASSERT_VEC_EQ(d, vn1, AverageRound(vn1, vn2));
HWY_ASSERT_VEC_EQ(d, vn2, AverageRound(vn1, vn3));
HWY_ASSERT_VEC_EQ(d, vn2, AverageRound(vn1, vn4));
HWY_ASSERT_VEC_EQ(d, vn2, AverageRound(vn2, vn2));
HWY_ASSERT_VEC_EQ(d, vn3, AverageRound(vn2, vn4));

const T kSignedMax = static_cast<T>(LimitsMax<TI>());

const Vec<D> v_iota1 = Iota(d, static_cast<T>(1));
Vec<D> v_neg_even = BitCast(d, Neg(BitCast(di, Add(v_iota1, v_iota1))));
HWY_IF_CONSTEXPR(HWY_MAX_LANES_D(D) > static_cast<size_t>(kSignedMax)) {
v_neg_even = Or(v_neg_even, SignBit(d));
}

const Vec<D> v_pos_even = And(v_neg_even, Set(d, kSignedMax));
const Vec<D> v_pos_odd = Or(v_pos_even, v1);

const Vec<D> expected_even =
Add(ShiftRight<1>(v_neg_even),
BitCast(d, ShiftRight<1>(BitCast(du, v_pos_even))));

HWY_ASSERT_VEC_EQ(d, expected_even, AverageRound(v_neg_even, v_pos_even));
HWY_ASSERT_VEC_EQ(d, Add(expected_even, v1),
AverageRound(v_neg_even, v_pos_odd));
}
};

HWY_NOINLINE void TestAllAverage() {
const ForPartialVectors<TestAverage> test;
test(uint8_t());
test(uint16_t());
ForIntegerTypes(ForPartialVectors<TestAverage>());
}

struct TestAbs {
Expand Down

0 comments on commit 79b2846

Please sign in to comment.