Skip to content

Commit

Permalink
Merge pull request xtensor-stack#2367 from davidbrochart/specialize_c…
Browse files Browse the repository at this point in the history
…hunked_view

Specialize operator= when RHS is chunked
  • Loading branch information
JohanMabille authored Apr 22, 2021
2 parents 644970a + 9936e0e commit 28c6b1f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 4 deletions.
9 changes: 9 additions & 0 deletions include/xtensor/xchunked_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ namespace xt
template<class E>
constexpr bool is_chunked(const xexpression<E>& e);

template<class E>
constexpr bool is_chunked();

/**
* Creates an in-memory chunked array.
* This function returns an uninitialized ``xchunked_array<xarray<T>>``.
Expand Down Expand Up @@ -286,6 +289,12 @@ namespace xt

template<class E>
constexpr bool is_chunked(const xexpression<E>&)
{
return is_chunked<E>();
}

template<class E>
constexpr bool is_chunked()
{
using return_type = typename detail::chunk_helper<E>::is_chunked;
return return_type::value;
Expand Down
75 changes: 71 additions & 4 deletions include/xtensor/xchunked_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
#include "xnoalias.hpp"
#include "xstorage.hpp"
#include "xstrided_view.hpp"
#include "xchunked_array.hpp"

namespace xt
{

template <class E>
struct is_chunked_t: detail::chunk_helper<E>::is_chunked
{
};

/*****************
* xchunked_view *
*****************/
Expand All @@ -30,7 +36,7 @@ namespace xt
class xchunked_view
{
public:

using self_type = xchunked_view<E>;
using expression_type = std::decay_t<E>;
using value_type = typename expression_type::value_type;
Expand All @@ -48,7 +54,15 @@ namespace xt
xchunked_view(OE&& e, S&& chunk_shape);

template <class OE>
xchunked_view<E>& operator=(const OE& e);
xchunked_view(OE&& e);

void init();

template <class OE>
typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);

template <class OE>
typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&> operator=(const OE& e);

size_type dimension() const noexcept;
const shape_type& shape() const noexcept;
Expand Down Expand Up @@ -92,6 +106,22 @@ namespace xt
m_shape.resize(e.dimension());
const auto& s = e.shape();
std::copy(s.cbegin(), s.cend(), m_shape.begin());
init();
}

template <class E>
template <class OE>
inline xchunked_view<E>::xchunked_view(OE&& e)
: m_expression(std::forward<OE>(e))
{
m_shape.resize(e.dimension());
const auto& s = e.shape();
std::copy(s.cbegin(), s.cend(), m_shape.begin());
}

template <class E>
void xchunked_view<E>::init()
{
// compute chunk number in each dimension
m_grid_shape.resize(m_shape.size());
std::transform
Expand All @@ -114,16 +144,47 @@ namespace xt

template <class E>
template <class OE>
xchunked_view<E>& xchunked_view<E>::operator=(const OE& e)
typename std::enable_if_t<!is_chunked_t<OE>::value, xchunked_view<E>&> xchunked_view<E>::operator=(const OE& e)
{
for (auto it = chunk_begin(); it != chunk_end(); it++)
auto end = chunk_end();
for (auto it = chunk_begin(); it != end; ++it)
{
auto el = *it;
noalias(el) = strided_view(e, it.get_slice_vector());
}
return *this;
}

template <class E>
template <class OE>
typename std::enable_if_t<is_chunked_t<OE>::value, xchunked_view<E>&> xchunked_view<E>::operator=(const OE& e)
{
m_chunk_shape.resize(e.dimension());
const auto& cs = e.chunk_shape();
std::copy(cs.cbegin(), cs.cend(), m_chunk_shape.begin());
init();
auto it2 = e.chunks().begin();
auto end1 = chunk_end();
for (auto it1 = chunk_begin(); it1 != end1; ++it1, ++it2)
{
auto el1 = *it1;
auto el2 = *it2;
auto lhs_shape = el1.shape();
if (lhs_shape != el2.shape())
{
xstrided_slice_vector esv(el2.dimension()); // element slice in edge chunk
std::transform(lhs_shape.begin(), lhs_shape.end(), esv.begin(),
[](auto size) { return range(0, size); });
noalias(el1) = strided_view(el2, esv);
}
else
{
noalias(el1) = el2;
}
}
return *this;
}

template <class E>
inline auto xchunked_view<E>::dimension() const noexcept -> size_type
{
Expand Down Expand Up @@ -209,6 +270,12 @@ namespace xt
{
return xchunked_view<E>(std::forward<E>(e), std::forward<S>(chunk_shape));
}

template <class E>
inline xchunked_view<E> as_chunked(E&& e)
{
return xchunked_view<E>(std::forward<E>(e));
}
}

#endif

0 comments on commit 28c6b1f

Please sign in to comment.