Skip to content

Commit

Permalink
add layout as template parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfv committed Apr 7, 2017
1 parent f9c9640 commit f0d75c3
Show file tree
Hide file tree
Showing 18 changed files with 663 additions and 573 deletions.
183 changes: 92 additions & 91 deletions include/xtensor/xarray.hpp

Large diffs are not rendered by default.

208 changes: 62 additions & 146 deletions include/xtensor/xcontainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <functional>
#include <numeric>
#include <stdexcept>

#include "xiterable.hpp"
#include "xiterator.hpp"
Expand Down Expand Up @@ -101,19 +102,6 @@ namespace xt
const inner_strides_type& strides() const noexcept;
const inner_backstrides_type& backstrides() const noexcept;

void transpose();

template <class S, class Tag = check_policy::none>
void transpose(S&& permutation, Tag check_policy = Tag());

#ifdef X_OLD_CLANG
template <class I, class Tag = check_policy::none>
void transpose(std::initializer_list<I> permutation, Tag check_policy = Tag());
#else
template <class I, std::size_t N, class Tag = check_policy::none>
void transpose(const I (&permutation)[N], Tag check_policy = Tag());
#endif

template <class... Args>
reference operator()(Args... args);

Expand Down Expand Up @@ -174,12 +162,6 @@ namespace xt

private:

template <class S>
void transpose_impl(S&& permutation, check_policy::none);

template <class S>
void transpose_impl(S&& permutation, check_policy::full);

inner_shape_type& mutable_shape();
inner_strides_type& mutable_strides();
inner_backstrides_type& mutable_backstrides();
Expand All @@ -200,7 +182,7 @@ namespace xt
* @tparam D The derived type, i.e. the inheriting class for which xstrided
* provides the partial imlpementation of xcontainer.
*/
template <class D>
template <class D, layout L>
class xstrided_container : public xcontainer<D>
{

Expand All @@ -220,10 +202,14 @@ namespace xt
using inner_strides_type = typename base_type::inner_strides_type;
using inner_backstrides_type = typename base_type::inner_backstrides_type;

void reshape(const shape_type& shape);
void reshape(const shape_type& shape, layout l);
static constexpr xt::layout layout_type = L;

void reshape(const shape_type& shape, bool force = false);
void reshape(const shape_type& shape, xt::layout l);
void reshape(const shape_type& shape, const strides_type& strides);

xt::layout layout();

protected:

xstrided_container() noexcept;
Expand All @@ -247,10 +233,10 @@ namespace xt
const inner_backstrides_type& backstrides_impl() const noexcept;

private:

inner_shape_type m_shape;
inner_strides_type m_strides;
inner_backstrides_type m_backstrides;
xt::layout m_layout = L;
};

/******************************
Expand Down Expand Up @@ -335,100 +321,6 @@ namespace xt
{
return derived_cast().backstrides_impl();
}

/**
* Transposes the container inplace by reversing the dimensions.
*/
template <class D>
inline void xcontainer<D>::transpose()
{
// reverse stride and shape
std::reverse(mutable_shape().begin(), mutable_shape().end());
std::reverse(mutable_strides().begin(), mutable_strides().end());
std::reverse(mutable_backstrides().begin(), mutable_backstrides().end());
}

/**
* Transposes the container inplace by permuting the shape with @p permutation.
* @param permutation the sequence containing permutation
* @param check_policy the check level (check_policy::full() or check_policy::none())
* @tparam Tag selects the level of error checking on permutation vector defaults to check_policy::none.
*/
template <class D>
template <class S, class Tag>
inline void xcontainer<D>::transpose(S&& permutation, Tag check_policy)
{
transpose_impl(std::forward<S>(permutation), check_policy);
}

#ifdef X_OLD_CLANG
template <class D>
template <class I, class Tag>
inline void xcontainer<D>::transpose(std::initializer_list<I> permutation, Tag check_policy)
{
std::vector<I> perm(permutation);
transpose_impl(std::move(perm), check_policy);
}
#else
template <class D>
template <class I, std::size_t N, class Tag>
inline void xcontainer<D>::transpose(const I (&permutation)[N], Tag check_policy)
{
transpose_impl(permutation, check_policy);
}
#endif

template <class D>
template <class S>
inline void xcontainer<D>::transpose_impl(S&& permutation, check_policy::full)
{
// check if axis appears twice in permutation
for (size_type i = 0; i < container_size(permutation); ++i)
{
for (size_type j = i + 1; j < container_size(permutation); ++j)
{
if (permutation[i] == permutation[j])
{
throw transpose_error("Permutation contains axis more than once");
}
}
}
transpose_impl(permutation, check_policy::none());
}

template <class D>
template <class S>
inline void xcontainer<D>::transpose_impl(S&& permutation, check_policy::none)
{
if (container_size(permutation) != dimension())
{
throw transpose_error("Permutation does not have the same size as shape");
}

// permute stride and shape
strides_type temp_strides;
resize_container(temp_strides, strides().size());

shape_type temp_shape;
resize_container(temp_shape, shape().size());

shape_type temp_backstrides;
resize_container(temp_backstrides, backstrides().size());

for (size_type i = 0; i < shape().size(); ++i)
{
if (size_type(permutation[i]) >= dimension())
{
throw transpose_error("Permutation contains wrong axis");
}
temp_shape[i] = shape()[permutation[i]];
temp_strides[i] = strides()[permutation[i]];
temp_backstrides[i] = backstrides()[permutation[i]];
}
mutable_shape() = std::move(temp_shape);
mutable_strides() = std::move(temp_strides);
mutable_backstrides() = std::move(temp_backstrides);
}
//@}


Expand Down Expand Up @@ -715,67 +607,86 @@ namespace xt
* xstrided_container implementation *
*************************************/

template <class D>
inline xstrided_container<D>::xstrided_container() noexcept
template <class D, layout L>
inline xstrided_container<D, L>::xstrided_container() noexcept
: base_type()
{
m_shape = make_sequence<inner_shape_type>(base_type::dimension(), 1);
}

template <class D>
inline xstrided_container<D>::xstrided_container(inner_shape_type&& shape, inner_strides_type&& strides) noexcept
template <class D, layout L>
inline xstrided_container<D, L>::xstrided_container(inner_shape_type&& shape, inner_strides_type&& strides) noexcept
: base_type(), m_shape(std::move(shape)), m_strides(std::move(strides))
{
m_backstrides = make_sequence<inner_backstrides_type>(m_shape.size(), 0);
adapt_strides(m_shape, m_strides, m_backstrides);
}

template <class D>
inline auto xstrided_container<D>::shape_impl() noexcept -> inner_shape_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::shape_impl() noexcept -> inner_shape_type&
{
return m_shape;
}

template <class D>
inline auto xstrided_container<D>::shape_impl() const noexcept -> const inner_shape_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::shape_impl() const noexcept -> const inner_shape_type&
{
return m_shape;
}

template <class D>
inline auto xstrided_container<D>::strides_impl() noexcept -> inner_strides_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::strides_impl() noexcept -> inner_strides_type&
{
return m_strides;
}

template <class D>
inline auto xstrided_container<D>::strides_impl() const noexcept -> const inner_strides_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::strides_impl() const noexcept -> const inner_strides_type&
{
return m_strides;
}

template <class D>
inline auto xstrided_container<D>::backstrides_impl() noexcept -> inner_backstrides_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::backstrides_impl() noexcept -> inner_backstrides_type&
{
return m_backstrides;
}

template <class D>
inline auto xstrided_container<D>::backstrides_impl() const noexcept -> const inner_backstrides_type&
template <class D, layout L>
inline auto xstrided_container<D, L>::backstrides_impl() const noexcept -> const inner_backstrides_type&
{
return m_backstrides;
}

/**
* Return the layout of the container
* @return layout of the container
*/
template <class D, layout L>
xt::layout xstrided_container<D, L>::layout()
{
return m_layout;
}

/**
* Reshapes the container.
* @param shape the new shape
* @param force force reshaping, even if the shape stays the same (default: false)
*/
template <class D>
inline void xstrided_container<D>::reshape(const shape_type& shape)
template <class D, layout L>
inline void xstrided_container<D, L>::reshape(const shape_type& shape, bool force)
{
if (shape != m_shape)
if (m_layout == xt::layout::dynamic)
{
reshape(shape, layout::row_major);
m_layout = xt::layout::row_major; // fall back to row major
}
if (shape != m_shape || force)
{
m_shape = shape;
resize_container(m_strides, m_shape.size());
resize_container(m_backstrides, m_shape.size());
size_type data_size = compute_strides(m_shape, m_layout, m_strides, m_backstrides);
this->data().resize(data_size);
}
}

Expand All @@ -784,24 +695,29 @@ namespace xt
* @param shape the new shape
* @param l the new layout
*/
template <class D>
inline void xstrided_container<D>::reshape(const shape_type& shape, layout l)
template <class D, layout L>
inline void xstrided_container<D, L>::reshape(const shape_type& shape, xt::layout l)
{
m_shape = shape;
resize_container(m_strides, m_shape.size());
resize_container(m_backstrides, m_shape.size());
size_type data_size = compute_strides(m_shape, l, m_strides, m_backstrides);
this->data().resize(data_size);
if (L != xt::layout::dynamic && l != L)
{
throw std::runtime_error("Cannot change layout if template parameter not layout::dynamic.");
}
m_layout = l;
reshape(shape, true);
}

/**
* Reshapes the container.
* @param shape the new shape
* @param strides the new strides
*/
template <class D>
inline void xstrided_container<D>::reshape(const shape_type& shape, const strides_type& strides)
template <class D, layout L>
inline void xstrided_container<D, L>::reshape(const shape_type& shape, const strides_type& strides)
{
if (L != xt::layout::dynamic)
{
throw std::runtime_error("Cannot reshape with custom strides when layout() is != layout::dynamic.");
}
m_shape = shape;
m_strides = strides;
resize_container(m_backstrides, m_strides.size());
Expand Down
6 changes: 1 addition & 5 deletions include/xtensor/xstrides.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,10 @@
#include <numeric>

#include "xexception.hpp"
#include "xtensor_forward.hpp"

namespace xt
{
enum class layout
{
row_major,
column_major
};

template <class shape_type>
typename shape_type::value_type compute_size(const shape_type& shape) noexcept;
Expand Down
Loading

0 comments on commit f0d75c3

Please sign in to comment.