Skip to content

Commit

Permalink
Merge pull request xtensor-stack#2356 from tdegeus/flat
Browse files Browse the repository at this point in the history
Adding .flat(i)
  • Loading branch information
JohanMabille authored May 18, 2021
2 parents a8f1a6a + 50e3d42 commit b5c9286
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 9 deletions.
3 changes: 1 addition & 2 deletions docs/source/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ See :any:`numpy indexing <numpy:arrays.indexing>` page.
+====================================================================+====================================================================+
| ``a[3, 2]`` | ``a(3, 2)`` |
+--------------------------------------------------------------------+--------------------------------------------------------------------+
| :any:`a.flat[4] <numpy.ndarray.flat>` || ``a[4]`` |
| || ``a(4)`` |
| :any:`a.flat[4] <numpy.ndarray.flat>` | ``a.flat(4)`` |
+--------------------------------------------------------------------+--------------------------------------------------------------------+
| ``a[3]`` || ``xt::view(a, 3, xt::all())`` |
| || ``xt::row(a, 3)`` |
Expand Down
7 changes: 3 additions & 4 deletions include/xtensor/xaccessible.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace xt
/************************************
* xconst_accessible implementation *
************************************/

/**
* Returns the size of the expression.
*/
Expand Down Expand Up @@ -211,7 +211,7 @@ namespace xt
normalize_periodic(derived_cast().shape(), args...);
return derived_cast()(static_cast<size_type>(args)...);
}

/**
* Returns ``true`` only if the the specified position is a valid entry in the expression.
* @param args a list of indices specifying the position in the expression.
Expand Down Expand Up @@ -246,7 +246,7 @@ namespace xt
template <class D>
template <class... Args>
inline auto xaccessible<D>::at(Args... args) -> reference
{
{
check_access(derived_cast().shape(), static_cast<size_type>(args)...);
return derived_cast().operator()(args...);
}
Expand Down Expand Up @@ -293,7 +293,6 @@ namespace xt
return derived_cast()(static_cast<size_type>(args)...);
}


template <class D>
inline auto xaccessible<D>::derived_cast() noexcept -> derived_type&
{
Expand Down
15 changes: 15 additions & 0 deletions include/xtensor/xcontainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ namespace xt
reference data_element(size_type i);
const_reference data_element(size_type i) const;

reference flat(size_type i);
const_reference flat(size_type i) const;

template <class requested_type>
using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;

Expand Down Expand Up @@ -643,6 +646,18 @@ namespace xt
return storage()[i];
}

template <class D>
inline auto xcontainer<D>::flat(size_type i) -> reference
{
return storage()[i];
}

template <class D>
inline auto xcontainer<D>::flat(size_type i) const -> const_reference
{
return storage()[i];
}

/***************
* stepper api *
***************/
Expand Down
15 changes: 15 additions & 0 deletions include/xtensor/xdynamic_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ namespace xt
template <class... Args>
const_reference unchecked(Args... args) const;

reference flat(size_type index);
const_reference flat(size_type index) const;

using base_type::operator[];
using base_type::at;
using base_type::periodic;
Expand Down Expand Up @@ -450,6 +453,18 @@ namespace xt
return base_type::storage()[static_cast<size_type>(offset)];
}

template <class CT, class S, layout_type L, class FST>
inline auto xdynamic_view<CT, S, L, FST>::flat(size_type i) -> reference
{
return base_type::storage()[data_offset() + i];
}

template <class CT, class S, layout_type L, class FST>
inline auto xdynamic_view<CT, S, L, FST>::flat(size_type i) const -> const_reference
{
return base_type::storage()[data_offset() + i];
}

template <class CT, class S, layout_type L, class FST>
template <class It>
inline auto xdynamic_view<CT, S, L, FST>::element(It first, It last) -> reference
Expand Down
16 changes: 16 additions & 0 deletions include/xtensor/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ namespace xt

const_reference data_element(size_type i) const;

const_reference flat(size_type i) const;

template <class UT = self_type, class = typename std::enable_if<UT::only_scalar::value>::type>
operator value_type() const;

Expand Down Expand Up @@ -583,6 +585,20 @@ namespace xt
return access_impl(std::make_index_sequence<sizeof...(CT)>(), static_cast<size_type>(args)...);
}

/**
* @name Data
*/
/**
* Returns a constant reference to the element at the specified position of the underlying
* contiguous storage of the function.
* @param index index to underlying flat storage.
*/
template <class F, class... CT>
inline auto xfunction<F, CT...>::flat(size_type index) const -> const_reference
{
return data_element_impl(std::make_index_sequence<sizeof...(CT)>(), index);
}

/**
* Returns a constant reference to the element at the specified position in the expression.
* @param args a list of indices specifying the position in the expression. Indices
Expand Down
19 changes: 17 additions & 2 deletions include/xtensor/xfunctor_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ namespace xt
using accessible_base::at;
using accessible_base::operator[];
using accessible_base::periodic;

using accessible_base::in_bounds;

xexpression_type& expression() noexcept;
Expand All @@ -203,6 +204,20 @@ namespace xt
return m_functor(m_e.data_element(i));
}

template <class FCT = functor_type>
auto flat(size_type i)
-> decltype(std::declval<FCT>()(std::declval<undecay_expression>().flat(i)))
{
return m_functor(m_e.flat(i));
}

template <class FCT = functor_type>
auto flat(size_type i) const
-> decltype(std::declval<FCT>()(std::declval<const undecay_expression>().flat(i)))
{
return m_functor(m_e.flat(i));
}

// The following functions are defined inline because otherwise signatures
// don't match on GCC.
template <class align, class requested_type = typename xexpression_type::value_type,
Expand Down Expand Up @@ -732,15 +747,15 @@ namespace xt
*
* @warning This method is meant for performance, for expressions with a dynamic
* number of dimensions (i.e. not known at compile time). Since it may have
* undefined behavior (see parameters), operator() should be prefered whenever
* undefined behavior (see parameters), operator() should be preferred whenever
* it is possible.
* @warning This method is NOT compatible with broadcasting, meaning the following
* code has undefined behavior:
* \code{.cpp}
* xt::xarray<double> a = {{0, 1}, {2, 3}};
* xt::xarray<double> b = {0, 1};
* auto fd = a + b;
* double res = fd.uncheked(0, 1);
* double res = fd.unchecked(0, 1);
* \endcode
*/
template <class D>
Expand Down
2 changes: 1 addition & 1 deletion include/xtensor/xmasked_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace xt

template <class CTD, class CTM>
struct xcontainer_inner_types<xmasked_view<CTD, CTM>>
{
{
using data_type = std::decay_t<CTD>;
using mask_type = std::decay_t<CTM>;
using base_value_type = typename data_type::value_type;
Expand Down
25 changes: 25 additions & 0 deletions include/xtensor/xoptional_assembly_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ namespace xt
template <class... Args>
const_reference periodic(Args... args) const;

reference flat(size_type args);
const_reference flat(size_type args) const;

template <class It>
reference element(It first, It last);
template <class It>
Expand Down Expand Up @@ -659,6 +662,28 @@ namespace xt
return const_reference(value().periodic(args...), has_value().periodic(args...));
}

/**
* Returns a reference to the element at the specified position
* of the underlying storage in the optional assembly.
* @param index index to underlying flat storage.
*/
template <class D>
inline auto xoptional_assembly_base<D>::flat(size_type i) -> reference
{
return reference(value().flat(i), has_value().flat(i));
}

/**
* Returns a constant reference to the element at the specified position
* of the underlying storage in the optional assembly.
* @param index index to underlying flat storage.
*/
template <class D>
inline auto xoptional_assembly_base<D>::flat(size_type i) const -> const_reference
{
return const_reference(value().flat(i), has_value().flat(i));
}

/**
* Returns a reference to the element at the specified position in the optional assembly.
* @param first iterator starting the sequence of indices
Expand Down
15 changes: 15 additions & 0 deletions include/xtensor/xscalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ namespace xt
reference data_element(size_type i) noexcept;
const_reference data_element(size_type i) const noexcept;

reference flat(size_type i) noexcept;
const_reference flat(size_type i) const noexcept;

template <class align, class simd = simd_value_type>
void store_simd(size_type i, const simd& e);
template <class align, class requested_type = value_type,
Expand Down Expand Up @@ -932,6 +935,18 @@ namespace xt
return m_value;
}

template <class CT>
inline auto xscalar<CT>::flat(size_type) noexcept -> reference
{
return m_value;
}

template <class CT>
inline auto xscalar<CT>::flat(size_type) const noexcept -> const_reference
{
return m_value;
}

template <class CT>
template <class align, class simd>
inline void xscalar<CT>::store_simd(size_type, const simd& e)
Expand Down
14 changes: 14 additions & 0 deletions include/xtensor/xstrided_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ namespace xt
reference data_element(size_type i);
const_reference data_element(size_type i) const;

reference flat(size_type i);
const_reference flat(size_type i) const;

using container_iterator = std::conditional_t<is_const,
typename storage_type::const_iterator,
typename storage_type::iterator>;
Expand Down Expand Up @@ -457,6 +460,17 @@ namespace xt
return storage()[i];
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::flat(size_type i) -> reference
{
return storage()[i];
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::flat(size_type i) const -> const_reference
{
return storage()[i];
}

template <class CT, class S, layout_type L, class FST>
inline auto xstrided_view<CT, S, L, FST>::storage_begin() -> storage_iterator
Expand Down
22 changes: 22 additions & 0 deletions include/xtensor/xview.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,12 @@ namespace xt
template <class T = xexpression_type>
enable_simd_interface<T, const_reference> data_element(size_type i) const;

template <class T = xexpression_type>
enable_simd_interface<T, reference> flat(size_type i);

template <class T = xexpression_type>
enable_simd_interface<T, const_reference> flat(size_type i) const;

private:

// VS 2015 workaround (yes, really)
Expand Down Expand Up @@ -1488,6 +1494,22 @@ namespace xt
return m_e.data_element(data_offset() + i);
}

template <class CT, class... S>
template <class T>
inline auto xview<CT, S...>::flat(size_type i) -> enable_simd_interface<T, reference>
{
XTENSOR_ASSERT(is_contiguous());
return m_e.flat(data_offset() + i);
}

template <class CT, class... S>
template <class T>
inline auto xview<CT, S...>::flat(size_type i) const -> enable_simd_interface<T, const_reference>
{
XTENSOR_ASSERT(is_contiguous());
return m_e.flat(data_offset() + i);
}

template <class CT, class... S>
template <class... Args>
inline auto xview<CT, S...>::make_index_sequence(Args...) const noexcept
Expand Down
20 changes: 20 additions & 0 deletions test/test_xarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,26 @@ namespace xt
EXPECT_EQ(a, b);
}

TEST(xarray, flat)
{
{
xt::xarray<size_t, xt::layout_type::row_major> a = {{0,1,2}, {3,4,5}};
xt::xarray<size_t, xt::layout_type::row_major> b = {{0,1,2}, {30,40,50}};
a.flat(3) = 30;
a.flat(4) = 40;
a.flat(5) = 50;
EXPECT_EQ(a, b);
}
{
xt::xarray<size_t, xt::layout_type::column_major> a = {{0,1,2}, {3,4,5}};
xt::xarray<size_t, xt::layout_type::column_major> b = {{0,1,2}, {30,40,50}};
a.flat(1) = 30;
a.flat(3) = 40;
a.flat(5) = 50;
EXPECT_EQ(a, b);
}
}

TEST(xarray, in_bounds)
{
xt::xarray<size_t> a = {{0,1,2}, {3,4,5}};
Expand Down
8 changes: 8 additions & 0 deletions test/test_xfunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ namespace xt
}
}

TEST(xfunction, flat)
{
xfunction_features f;
int a = (f.m_a + f.m_a).flat(0);
int b = f.m_a.flat(0) + f.m_a.flat(0);
EXPECT_EQ(a, b);
}

TEST(xfunction, in_bounds)
{
xfunction_features f;
Expand Down
20 changes: 20 additions & 0 deletions test/test_xtensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,26 @@ namespace xt
EXPECT_EQ(a, b);
}

TEST(xtensor, flat)
{
{
xt::xtensor<size_t, 2, xt::layout_type::row_major> a = {{0,1,2}, {3,4,5}};
xt::xtensor<size_t, 2, xt::layout_type::row_major> b = {{0,1,2}, {30,40,50}};
a.flat(3) = 30;
a.flat(4) = 40;
a.flat(5) = 50;
EXPECT_EQ(a, b);
}
{
xt::xtensor<size_t, 2, xt::layout_type::column_major> a = {{0,1,2}, {3,4,5}};
xt::xtensor<size_t, 2, xt::layout_type::column_major> b = {{0,1,2}, {30,40,50}};
a.flat(1) = 30;
a.flat(3) = 40;
a.flat(5) = 50;
EXPECT_EQ(a, b);
}
}

TEST(xtensor, in_bounds)
{
xt::xtensor<size_t,2> a = {{0,1,2}, {3,4,5}};
Expand Down
Loading

0 comments on commit b5c9286

Please sign in to comment.