Skip to content

Commit

Permalink
* Matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubpa committed May 26, 2020
1 parent 88f5676 commit a7fb8e7
Show file tree
Hide file tree
Showing 9 changed files with 280 additions and 128 deletions.
126 changes: 67 additions & 59 deletions include/UGM/Interfaces/IArray/IArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,70 @@
#include <cassert>

namespace Ubpa {
template<bool SIMD, typename Base, typename Impl>
struct IArrayImpl;
template<bool SIMD, typename IArray_Base, typename Impl>
struct IArray_Impl;

template<typename Base, typename Impl>
struct IArray : IArrayImpl<SupportSIMD_v<Impl>, Base, Impl> {
using IArrayImpl<SupportSIMD_v<Impl>, Base, Impl>::IArrayImpl;
struct IArray : IArray_Impl<SupportSIMD_v<Impl>, Base, Impl> {
using IArray_Impl<SupportSIMD_v<Impl>, Base, Impl>::IArray_Impl;
};

template<typename Base, typename Impl>
struct IArrayImpl<false, Base, Impl> : Base, std::array<ImplTraits_T<Impl>, ImplTraits_N<Impl>> {
template<typename IArray_Base, typename Impl>
struct IArray_Impl<false, IArray_Base, Impl> : IArray_Base, std::array<ImplTraits_T<Impl>, ImplTraits_N<Impl>> {
using T = ImplTraits_T<Impl>;
using F = ImplTraits_F<Impl>;
static constexpr size_t N = ImplTraits_N<Impl>;

private:
using Base::operator[];
using IArray_Base::operator[];
public:
using std::array<T, N>::operator[];

static_assert(N > 0);

using Base::Base;
using IArray_Base::IArray_Base;
using std::array<T, N>::array;

IArrayImpl() noexcept {};
IArray_Impl() noexcept {};

template<size_t... Ns>
IArrayImpl(const IArrayImpl& arr, std::index_sequence<Ns...>) noexcept {
IArray_Impl(const T* arr, std::index_sequence<Ns...>) noexcept {
(((*this)[Ns] = arr[Ns]), ...);
}
IArray_Impl(const T* arr) noexcept : IArray_Impl{ arr, std::make_index_sequence<N>{} } {}

template<size_t... Ns>
IArray_Impl(const IArray_Impl& arr, std::index_sequence<Ns...>) noexcept {
(((*this)[Ns] = arr[Ns]), ...);
};
IArrayImpl(const IArrayImpl& arr) noexcept : IArrayImpl{ arr, std::make_index_sequence<N>{} } {}
IArray_Impl(const IArray_Impl& arr) noexcept : IArray_Impl{ arr, std::make_index_sequence<N>{} } {}

template<size_t... Ns>
IArrayImpl(T t, std::index_sequence<Ns...>) noexcept {
IArray_Impl(T t, std::index_sequence<Ns...>) noexcept {
(((*this)[Ns] = t), ...);
};
constexpr IArrayImpl(T t) noexcept : IArrayImpl{ t, std::make_index_sequence<N>{} } {}
constexpr IArray_Impl(T t) noexcept : IArray_Impl{ t, std::make_index_sequence<N>{} } {}

template<typename... Us, std::enable_if_t<sizeof...(Us) == N>* = nullptr>
constexpr IArray_Impl(Us... vals) noexcept : std::array<T, N>{static_cast<T>(vals)...} {}

template<size_t i>
T get() const noexcept {
static_assert(i < N);
return (*this)[i];
}

template<typename... U, typename = std::enable_if_t<(std::is_convertible_v<U, T>&&...)>>
constexpr IArrayImpl(U... data) noexcept : std::array<T, N>{static_cast<T>(data)...} {
static_assert(sizeof...(U) == N, "number of parameters is not correct");
template<size_t i>
void set(T v) const noexcept {
static_assert(i < N);
(*this)[i] = v;
}
};

#ifdef UBPA_USE_SIMD
// alignas(16)
template<typename Base, typename Impl>
struct alignas(16) IArrayImpl<true, Base, Impl> : Base
template<typename IArray_Base, typename Impl>
struct alignas(16) IArray_Impl<true, IArray_Base, Impl> : IArray_Base
{
public:
__m128 m;
Expand All @@ -69,7 +86,7 @@ namespace Ubpa {

// array interface
std::array<float, 4>& to_array() noexcept { return reinterpret_cast<std::array<float, 4>&>(*this); }
const std::array<float, 4>& to_array() const noexcept { return const_cast<IArrayImpl*>(this)->to_array(); }
const std::array<float, 4>& to_array() const noexcept { return const_cast<IArray_Impl*>(this)->to_array(); }

using value_type = float;
using size_type = size_t;
Expand Down Expand Up @@ -123,61 +140,52 @@ namespace Ubpa {
reference back() noexcept { return to_array().back(); }
const_reference back() const noexcept { return to_array().back(); }

float x() const noexcept { return _mm_cvtss_f32(m); }
float y() const noexcept { return _mm_cvtss_f32(_mm_shuffle_ps(m, m, _MM_SHUFFLE(1, 1, 1, 1))); }
float z() const noexcept { return _mm_cvtss_f32(_mm_shuffle_ps(m, m, _MM_SHUFFLE(2, 2, 2, 2))); }
float w() const noexcept { return _mm_cvtss_f32(_mm_shuffle_ps(m, m, _MM_SHUFFLE(3, 3, 3, 3))); }
void set_x(float v) noexcept { m = _mm_move_ss(m, _mm_set_ss(v)); }
void set_y(float v) noexcept {
__m128 t = _mm_move_ss(m, _mm_set_ss(v));
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(3, 2, 0, 0));
m = _mm_move_ss(t, m);
}
void set_z(float v) noexcept {
__m128 t = _mm_move_ss(m, _mm_set_ss(v));
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(3, 0, 1, 0));
m = _mm_move_ss(t, m);
}
void set_w(float v) noexcept {
__m128 t = _mm_move_ss(m, _mm_set_ss(v));
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(0, 2, 1, 0));
m = _mm_move_ss(t, m);
}
template<size_t i>
float at() const noexcept {
float get() const noexcept {
static_assert(i < 4);
if constexpr (i == 0)
return x();
return _mm_cvtss_f32(m);
else
return _mm_cvtss_f32(_mm_shuffle_ps(m, m, _MM_SHUFFLE(i, i, i, i)));
}

template<size_t i>
void set(float v) const noexcept {
static_assert(i < 4);
if constexpr (i == 0)
return set_x(v);
else if constexpr (i == 1)
return set_y(v);
else if constexpr (i == 2)
return set_z(v);
else if constexpr (i == 3)
return set_w(v);
if constexpr (i == 0) {
m = _mm_move_ss(m, _mm_set_ss(v));
}
else {
__m128 t = _mm_move_ss(m, _mm_set_ss(v));

if constexpr (i == 1)
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(3, 2, 0, 0));
else if constexpr (i == 2)
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(3, 0, 1, 0));
else // if constexpr (i == 3)
t = _mm_shuffle_ps(t, t, _MM_SHUFFLE(0, 2, 1, 0));

m = _mm_move_ss(t, m);
}
}

float* data() noexcept { return reinterpret_cast<float*>(this); }
const float* data() const noexcept { return const_cast<IArrayImpl*>(this)->data(); }
const float* data() const noexcept { return const_cast<IArray_Impl*>(this)->data(); }

// ==================

IArrayImpl() noexcept {}
IArrayImpl(__m128 f4) noexcept : m{ f4 } {}
explicit IArrayImpl(float v) noexcept : m{ _mm_set1_ps(v) } {}
IArrayImpl(const IArrayImpl& f4) noexcept : m{ f4.m } {}
IArrayImpl& operator=(const IArrayImpl& f4) noexcept { m = f4.m; return *this; }
IArray_Impl() noexcept {}
IArray_Impl(const float* f4) noexcept : m{ _mm_loadu_ps(f4) } {}
IArray_Impl(__m128 f4) noexcept : m{ f4 } {} // align
explicit IArray_Impl(float v) noexcept : m{ _mm_set1_ps(v) } {}
IArray_Impl(const IArray_Impl& f4) noexcept : m{ f4.m } {}
IArray_Impl& operator=(const IArray_Impl& f4) noexcept { m = f4.m; return *this; }
operator __m128& () noexcept { return m; }
operator const __m128&() const noexcept { return const_cast<IArrayImpl*>(this)->operator __m128 &(); }
IArrayImpl(float x, float y, float z, float w) noexcept : m{ _mm_set_ps(w, z, y, x) } {} // little-endianness
operator const __m128&() const noexcept { return const_cast<IArray_Impl*>(this)->operator __m128 &(); }
IArray_Impl(float x, float y, float z, float w) noexcept : m{ _mm_set_ps(w, z, y, x) } {} // little-endianness
template<typename Ux, typename Uy, typename Uz, typename Uw>
IArrayImpl(Ux x, Uy y, Uz z, Uw w) noexcept
: IArrayImpl{ static_cast<float>(x),static_cast<float>(y),static_cast<float>(z),static_cast<float>(w) } {}
IArray_Impl(Ux x, Uy y, Uz z, Uw w) noexcept
: IArray_Impl{ static_cast<float>(x),static_cast<float>(y),static_cast<float>(z),static_cast<float>(w) } {}
};
#endif
}
27 changes: 19 additions & 8 deletions include/UGM/Interfaces/IArray/IArrayUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace Ubpa {
F one_minus_t = static_cast<F>(1) - t;
#ifdef UBPA_USE_SIMD
if constexpr (SupportSIMD_v<Impl>)
return one_minus_t * x + t * y;
return _mm_add_ps(_mm_mul_ps(x, _mm_set1_ps(one_minus_t)), _mm_mul_ps(y, _mm_set1_ps(t)));
else
#endif // UBPA_USE_SIMD
{
Expand All @@ -59,24 +59,35 @@ namespace Ubpa {
return lerp(x, y, static_cast<F>(0.5));
}

static const Impl mix(const std::vector<Impl>& vals, const std::vector<float>& weights) noexcept {
template<typename ValContainer, typename WeightContainer>
static const Impl mix(const ValContainer& vals, const WeightContainer& weights) noexcept {
assert(vals.size() > 0 && vals.size() == weights.size());
auto val_iter = vals.begin();
auto weight_iter = weights.begin();

#ifdef UBPA_USE_SIMD
if constexpr (SupportSIMD_v<Impl>) {
auto rst = vals[0].get() * weights[0];
for (size_t i = 1; i < vals.size(); i++)
rst += vals[i].get() * weights[i];
__m128 rst = _mm_mul_ps(*val_iter, *weight_iter);
++val_iter;
++weight_iter;
while (val_iter != vals.end()) {
rst = _mm_add_ps(rst, _mm_mul_ps(*val_iter, *weight_iter));
++val_iter;
++weight_iter;
}
return rst;
}
else
#endif // UBPA_USE_SIMD
{
Impl rst;
for (size_t j = 0; j < N; j++)
rst[j] = vals[0][j] * weights[0];
for (size_t i = 1; i < vals.size(); i++) {
rst[j] = (*val_iter)[j] * (*weight_iter);
while (val_iter != vals.end()) {
for (size_t j = 0; j < N; j++)
rst[j] += vals[i][j] * weights[i];
rst[j] += (*val_iter)[j] * (*weight_iter);
++val_iter;
++weight_iter;
}
return rst;
}
Expand Down
60 changes: 43 additions & 17 deletions include/UGM/Interfaces/IMatrix/IMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,51 @@ namespace Ubpa {
static_assert(N == 3 || N == 4);
static_assert(Vector::N == N);

inline IMatrix(const std::array<F,N*N>& data) noexcept { init(data); }
// column first
inline IMatrix(const std::array<F, N * N>& data) noexcept { init(data); }

inline void init(const std::array<F, N * N>& data) noexcept {
// unloop in /O2
for (size_t i = 0; i < N * N; i++)
(*this)(i) = data[i];
// column first
inline void init(const std::array<F, N* N>& data) noexcept {
auto& m = static_cast<Impl&>(*this);
detail::IMatrix_::init<N>::run(m, data);
}

// row first
// -o-o-o-o-o-o-
// U0 is used to avoid MSVC's compiler bug
// because IArrayImpl has same constructor:
// - template<typename... Us, nullptr> IArrayImpl(Us... vals);
template<typename U0, typename... Us, std::enable_if_t<1 + sizeof...(Us) == N * N>* = nullptr>
constexpr IMatrix(U0 val0, Us... vals) noexcept { init(val0, vals...); }

// row first
template<typename... Us, std::enable_if_t<sizeof...(Us) == N * N>* = nullptr>
inline void init(Us... vals) noexcept {
auto t = std::make_tuple(static_cast<F>(vals)...);
if constexpr (N == 3) {
init(std::array<F, 3 * 3>{
std::get<0>(t), std::get<3>(t), std::get<6>(t),
std::get<1>(t), std::get<4>(t), std::get<7>(t),
std::get<2>(t), std::get<5>(t), std::get<8>(t),
});
}
else // if constexpr (N == 4)
{
init(std::array<F, 4 * 4>{
std::get< 0>(t), std::get< 4>(t), std::get< 8>(t), std::get<12>(t),
std::get< 1>(t), std::get< 5>(t), std::get< 9>(t), std::get<13>(t),
std::get< 2>(t), std::get< 6>(t), std::get<10>(t), std::get<14>(t),
std::get< 3>(t), std::get< 7>(t), std::get<11>(t), std::get<15>(t),
});
}
}

inline static const Impl eye() noexcept {
return detail::IMatrix::eye<Impl, N>::run();
return detail::IMatrix_::eye<Impl, N>::run();
}

inline static const Impl zero() noexcept {
Impl rst;
for (size_t i = 0; i < N * N; i++)
(*this)(i) = static_cast<F>(0);
return rst;
return detail::IMatrix_::zero<N>::template run<Impl>();
}

inline F& operator()(size_t r, size_t c) noexcept {
Expand Down Expand Up @@ -79,23 +107,21 @@ namespace Ubpa {
}

inline F trace() const noexcept {
F rst = (*this)[0][0];
for (size_t i = 1; i < N; i++)
rst += (*this)[i][i];
return rst;
const auto& m = static_cast<const Impl&>(*this);
return detail::IMatrix_::trace<N>::run(m);
}

inline const Impl transpose() const noexcept {
const auto& m = static_cast<const Impl&>(*this);
return detail::IMatrix::transpose<N>::run(m);
return detail::IMatrix_::transpose<N>::run(m);
}

F* data() noexcept {
return &(*this)[0][0];
return reinterpret_cast<F*>(this);
}

const F* data() const noexcept {
return &(*this)[0][0];
return const_cast<IMatrix*>(this)->data();
}
};

Expand Down
2 changes: 1 addition & 1 deletion include/UGM/Interfaces/IMatrix/IMatrixInOut.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ namespace Ubpa {
};

InterfaceTraits_Regist(IMatrixInOut,
IArrayInOut, IMatrix);
IMatrix);
}
10 changes: 5 additions & 5 deletions include/UGM/Interfaces/IMatrix/IMatrixMul_detail.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ namespace Ubpa::detail::IMatrixMul {
F f21 = x(2, 0) * y(0, 1) + x(2, 1) * y(1, 1) + x(2, 2) * y(2, 1);
F f22 = x(2, 0) * y(0, 2) + x(2, 1) * y(1, 2) + x(2, 2) * y(2, 2);

return std::array<F, 3 * 3>{
return {
f00, f01, f02,
f10, f11, f12,
f20, f21, f22
Expand Down Expand Up @@ -473,11 +473,11 @@ namespace Ubpa::detail::IMatrixMul {
F f32 = x(3, 0) * y(0, 2) + x(3, 1) * y(1, 2) + x(3, 2) * y(2, 2) + x(3, 3) * y(3, 2);
F f33 = x(3, 0) * y(0, 3) + x(3, 1) * y(1, 3) + x(3, 2) * y(2, 3) + x(3, 3) * y(3, 3);

return std::array<F, 4 * 4>{
return {
f00, f01, f02, f03,
f10, f11, f12, f13,
f20, f21, f22, f23,
f30, f31, f32, f33
f10, f11, f12, f13,
f20, f21, f22, f23,
f30, f31, f32, f33
};
}
}
Expand Down
Loading

0 comments on commit a7fb8e7

Please sign in to comment.