Skip to content

Commit

Permalink
Specialize operator= when RHS is chunked
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 22, 2021
1 parent a1f6b16 commit e2ffd62
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions include/xtensor/xchunked_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@
namespace xt
{

// SFINAE test if chunked
template <typename T>
class has_chunks
{
private:
typedef char YesType[1];
typedef char NoType[2];

template <typename C> static YesType& test(decltype(&C::chunk_shape));
template <typename C> static NoType& test(...);

public:
enum { value = sizeof(test<T>(0)) == sizeof(YesType) };
};

/*****************
* xchunked_view *
*****************/
Expand Down Expand Up @@ -48,7 +63,10 @@ namespace xt
xchunked_view(OE&& e, S&& chunk_shape);

template <class OE>
xchunked_view<E>& operator=(const OE& e);
typename std::enable_if<!has_chunks<OE>::value, xchunked_view<E>&>::type operator=(const OE& e);

template <class OE>
typename std::enable_if<has_chunks<OE>::value, xchunked_view<E>&>::type operator=(const OE& e);

size_type dimension() const noexcept;
const shape_type& shape() const noexcept;
Expand Down Expand Up @@ -114,7 +132,7 @@ namespace xt

template <class E>
template <class OE>
xchunked_view<E>& xchunked_view<E>::operator=(const OE& e)
typename std::enable_if<!has_chunks<OE>::value, xchunked_view<E>&>::type xchunked_view<E>::operator=(const OE& e)
{
for (auto it = chunk_begin(); it != chunk_end(); it++)
{
Expand All @@ -124,6 +142,30 @@ namespace xt
return *this;
}

template <class E>
template <class OE>
typename std::enable_if<has_chunks<OE>::value, xchunked_view<E>&>::type xchunked_view<E>::operator=(const OE& e)
{
for (auto it1 = chunk_begin(), it2 = e.chunks().begin(); it1 != chunk_end(); it1++, it2++)
{
auto el1 = *it1;
auto el2 = *it2;
auto lhs_shape = el1.shape();
if (lhs_shape != el2.shape())
{
xstrided_slice_vector esv(el2.dimension()); // element slice in edge chunk
std::transform(lhs_shape.begin(), lhs_shape.end(), esv.begin(),
[](auto size) { return range(0, size); });
noalias(el1) = strided_view(el2, esv);
}
else
{
noalias(el1) = el2;
}
}
return *this;
}

template <class E>
inline auto xchunked_view<E>::dimension() const noexcept -> size_type
{
Expand Down

0 comments on commit e2ffd62

Please sign in to comment.