diff --git a/docs/source/numpy.rst b/docs/source/numpy.rst index 503c0ace4..5f5fc7699 100644 --- a/docs/source/numpy.rst +++ b/docs/source/numpy.rst @@ -136,8 +136,7 @@ See :any:`numpy indexing ` page. +====================================================================+====================================================================+ | ``a[3, 2]`` | ``a(3, 2)`` | +--------------------------------------------------------------------+--------------------------------------------------------------------+ -| :any:`a.flat[4] ` || ``a[4]`` | -| || ``a(4)`` | +| :any:`a.flat[4] ` | ``a.flat(4)`` | +--------------------------------------------------------------------+--------------------------------------------------------------------+ | ``a[3]`` || ``xt::view(a, 3, xt::all())`` | | || ``xt::row(a, 3)`` | diff --git a/include/xtensor/xaccessible.hpp b/include/xtensor/xaccessible.hpp index 600ddd0b2..2f6af2c72 100644 --- a/include/xtensor/xaccessible.hpp +++ b/include/xtensor/xaccessible.hpp @@ -125,7 +125,7 @@ namespace xt /************************************ * xconst_accessible implementation * ************************************/ - + /** * Returns the size of the expression. */ @@ -211,7 +211,7 @@ namespace xt normalize_periodic(derived_cast().shape(), args...); return derived_cast()(static_cast(args)...); } - + /** * Returns ``true`` only if the the specified position is a valid entry in the expression. * @param args a list of indices specifying the position in the expression. @@ -246,7 +246,7 @@ namespace xt template template inline auto xaccessible::at(Args... args) -> reference - { + { check_access(derived_cast().shape(), static_cast(args)...); return derived_cast().operator()(args...); } @@ -293,7 +293,6 @@ namespace xt return derived_cast()(static_cast(args)...); } - template inline auto xaccessible::derived_cast() noexcept -> derived_type& { diff --git a/include/xtensor/xcontainer.hpp b/include/xtensor/xcontainer.hpp index 7a116804c..df6ce6995 100644 --- a/include/xtensor/xcontainer.hpp +++ b/include/xtensor/xcontainer.hpp @@ -174,6 +174,9 @@ namespace xt reference data_element(size_type i); const_reference data_element(size_type i) const; + reference flat(size_type i); + const_reference flat(size_type i) const; + template using simd_return_type = xt_simd::simd_return_type; @@ -643,6 +646,18 @@ namespace xt return storage()[i]; } + template + inline auto xcontainer::flat(size_type i) -> reference + { + return storage()[i]; + } + + template + inline auto xcontainer::flat(size_type i) const -> const_reference + { + return storage()[i]; + } + /*************** * stepper api * ***************/ diff --git a/include/xtensor/xdynamic_view.hpp b/include/xtensor/xdynamic_view.hpp index 5c51d35b5..01a37a1d6 100644 --- a/include/xtensor/xdynamic_view.hpp +++ b/include/xtensor/xdynamic_view.hpp @@ -181,6 +181,9 @@ namespace xt template const_reference unchecked(Args... args) const; + reference flat(size_type index); + const_reference flat(size_type index) const; + using base_type::operator[]; using base_type::at; using base_type::periodic; @@ -450,6 +453,18 @@ namespace xt return base_type::storage()[static_cast(offset)]; } + template + inline auto xdynamic_view::flat(size_type i) -> reference + { + return base_type::storage()[data_offset() + i]; + } + + template + inline auto xdynamic_view::flat(size_type i) const -> const_reference + { + return base_type::storage()[data_offset() + i]; + } + template template inline auto xdynamic_view::element(It first, It last) -> reference diff --git a/include/xtensor/xfunction.hpp b/include/xtensor/xfunction.hpp index 5f1c4d7e0..49aa7a900 100644 --- a/include/xtensor/xfunction.hpp +++ b/include/xtensor/xfunction.hpp @@ -307,6 +307,8 @@ namespace xt const_reference data_element(size_type i) const; + const_reference flat(size_type i) const; + template ::type> operator value_type() const; @@ -583,6 +585,20 @@ namespace xt return access_impl(std::make_index_sequence(), static_cast(args)...); } + /** + * @name Data + */ + /** + * Returns a constant reference to the element at the specified position of the underlying + * contiguous storage of the function. + * @param index index to underlying flat storage. + */ + template + inline auto xfunction::flat(size_type index) const -> const_reference + { + return data_element_impl(std::make_index_sequence(), index); + } + /** * Returns a constant reference to the element at the specified position in the expression. * @param args a list of indices specifying the position in the expression. Indices diff --git a/include/xtensor/xfunctor_view.hpp b/include/xtensor/xfunctor_view.hpp index 9542c7065..b2b87921b 100644 --- a/include/xtensor/xfunctor_view.hpp +++ b/include/xtensor/xfunctor_view.hpp @@ -178,6 +178,7 @@ namespace xt using accessible_base::at; using accessible_base::operator[]; using accessible_base::periodic; + using accessible_base::in_bounds; xexpression_type& expression() noexcept; @@ -203,6 +204,20 @@ namespace xt return m_functor(m_e.data_element(i)); } + template + auto flat(size_type i) + -> decltype(std::declval()(std::declval().flat(i))) + { + return m_functor(m_e.flat(i)); + } + + template + auto flat(size_type i) const + -> decltype(std::declval()(std::declval().flat(i))) + { + return m_functor(m_e.flat(i)); + } + // The following functions are defined inline because otherwise signatures // don't match on GCC. template a = {{0, 1}, {2, 3}}; * xt::xarray b = {0, 1}; * auto fd = a + b; - * double res = fd.uncheked(0, 1); + * double res = fd.unchecked(0, 1); * \endcode */ template diff --git a/include/xtensor/xmasked_view.hpp b/include/xtensor/xmasked_view.hpp index acd51ecd5..f0bebcc5b 100644 --- a/include/xtensor/xmasked_view.hpp +++ b/include/xtensor/xmasked_view.hpp @@ -37,7 +37,7 @@ namespace xt template struct xcontainer_inner_types> - { + { using data_type = std::decay_t; using mask_type = std::decay_t; using base_value_type = typename data_type::value_type; diff --git a/include/xtensor/xoptional_assembly_base.hpp b/include/xtensor/xoptional_assembly_base.hpp index 6e7c7ec10..67d34b612 100644 --- a/include/xtensor/xoptional_assembly_base.hpp +++ b/include/xtensor/xoptional_assembly_base.hpp @@ -177,6 +177,9 @@ namespace xt template const_reference periodic(Args... args) const; + reference flat(size_type args); + const_reference flat(size_type args) const; + template reference element(It first, It last); template @@ -659,6 +662,28 @@ namespace xt return const_reference(value().periodic(args...), has_value().periodic(args...)); } + /** + * Returns a reference to the element at the specified position + * of the underlying storage in the optional assembly. + * @param index index to underlying flat storage. + */ + template + inline auto xoptional_assembly_base::flat(size_type i) -> reference + { + return reference(value().flat(i), has_value().flat(i)); + } + + /** + * Returns a constant reference to the element at the specified position + * of the underlying storage in the optional assembly. + * @param index index to underlying flat storage. + */ + template + inline auto xoptional_assembly_base::flat(size_type i) const -> const_reference + { + return const_reference(value().flat(i), has_value().flat(i)); + } + /** * Returns a reference to the element at the specified position in the optional assembly. * @param first iterator starting the sequence of indices diff --git a/include/xtensor/xscalar.hpp b/include/xtensor/xscalar.hpp index 83a1cb008..adb8591bf 100644 --- a/include/xtensor/xscalar.hpp +++ b/include/xtensor/xscalar.hpp @@ -281,6 +281,9 @@ namespace xt reference data_element(size_type i) noexcept; const_reference data_element(size_type i) const noexcept; + reference flat(size_type i) noexcept; + const_reference flat(size_type i) const noexcept; + template void store_simd(size_type i, const simd& e); template + inline auto xscalar::flat(size_type) noexcept -> reference + { + return m_value; + } + + template + inline auto xscalar::flat(size_type) const noexcept -> const_reference + { + return m_value; + } + template template inline void xscalar::store_simd(size_type, const simd& e) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index c6b05b6b9..7a6610595 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -266,6 +266,9 @@ namespace xt reference data_element(size_type i); const_reference data_element(size_type i) const; + reference flat(size_type i); + const_reference flat(size_type i) const; + using container_iterator = std::conditional_t; @@ -457,6 +460,17 @@ namespace xt return storage()[i]; } + template + inline auto xstrided_view::flat(size_type i) -> reference + { + return storage()[i]; + } + + template + inline auto xstrided_view::flat(size_type i) const -> const_reference + { + return storage()[i]; + } template inline auto xstrided_view::storage_begin() -> storage_iterator diff --git a/include/xtensor/xview.hpp b/include/xtensor/xview.hpp index 49e9c8ac3..eac650fc3 100644 --- a/include/xtensor/xview.hpp +++ b/include/xtensor/xview.hpp @@ -651,6 +651,12 @@ namespace xt template enable_simd_interface data_element(size_type i) const; + template + enable_simd_interface flat(size_type i); + + template + enable_simd_interface flat(size_type i) const; + private: // VS 2015 workaround (yes, really) @@ -1488,6 +1494,22 @@ namespace xt return m_e.data_element(data_offset() + i); } + template + template + inline auto xview::flat(size_type i) -> enable_simd_interface + { + XTENSOR_ASSERT(is_contiguous()); + return m_e.flat(data_offset() + i); + } + + template + template + inline auto xview::flat(size_type i) const -> enable_simd_interface + { + XTENSOR_ASSERT(is_contiguous()); + return m_e.flat(data_offset() + i); + } + template template inline auto xview::make_index_sequence(Args...) const noexcept diff --git a/test/test_xarray.cpp b/test/test_xarray.cpp index 2a23de052..d540069f2 100644 --- a/test/test_xarray.cpp +++ b/test/test_xarray.cpp @@ -324,6 +324,26 @@ namespace xt EXPECT_EQ(a, b); } + TEST(xarray, flat) + { + { + xt::xarray a = {{0,1,2}, {3,4,5}}; + xt::xarray b = {{0,1,2}, {30,40,50}}; + a.flat(3) = 30; + a.flat(4) = 40; + a.flat(5) = 50; + EXPECT_EQ(a, b); + } + { + xt::xarray a = {{0,1,2}, {3,4,5}}; + xt::xarray b = {{0,1,2}, {30,40,50}}; + a.flat(1) = 30; + a.flat(3) = 40; + a.flat(5) = 50; + EXPECT_EQ(a, b); + } + } + TEST(xarray, in_bounds) { xt::xarray a = {{0,1,2}, {3,4,5}}; diff --git a/test/test_xfunction.cpp b/test/test_xfunction.cpp index ddd9ce824..4545f2d7b 100644 --- a/test/test_xfunction.cpp +++ b/test/test_xfunction.cpp @@ -271,6 +271,14 @@ namespace xt } } + TEST(xfunction, flat) + { + xfunction_features f; + int a = (f.m_a + f.m_a).flat(0); + int b = f.m_a.flat(0) + f.m_a.flat(0); + EXPECT_EQ(a, b); + } + TEST(xfunction, in_bounds) { xfunction_features f; diff --git a/test/test_xtensor.cpp b/test/test_xtensor.cpp index 79e913374..1d8aedbdb 100644 --- a/test/test_xtensor.cpp +++ b/test/test_xtensor.cpp @@ -343,6 +343,26 @@ namespace xt EXPECT_EQ(a, b); } + TEST(xtensor, flat) + { + { + xt::xtensor a = {{0,1,2}, {3,4,5}}; + xt::xtensor b = {{0,1,2}, {30,40,50}}; + a.flat(3) = 30; + a.flat(4) = 40; + a.flat(5) = 50; + EXPECT_EQ(a, b); + } + { + xt::xtensor a = {{0,1,2}, {3,4,5}}; + xt::xtensor b = {{0,1,2}, {30,40,50}}; + a.flat(1) = 30; + a.flat(3) = 40; + a.flat(5) = 50; + EXPECT_EQ(a, b); + } + } + TEST(xtensor, in_bounds) { xt::xtensor a = {{0,1,2}, {3,4,5}}; diff --git a/test/test_xview.cpp b/test/test_xview.cpp index fba27e703..2047c86eb 100644 --- a/test/test_xview.cpp +++ b/test/test_xview.cpp @@ -1464,6 +1464,17 @@ namespace xt EXPECT_EQ(a, b); } + TEST(xview, flat) + { + xt::xtensor a = {{0,1,2}, {3,4,5}}; + xt::xtensor b = {{0,1,2}, {30,40,50}}; + auto view = xt::view(a, 1, xt::all()); + view.flat(0) = 30; + view.flat(1) = 40; + view.flat(2) = 50; + EXPECT_EQ(a, b); + } + TEST(xview, in_bounds) { xt::xtensor a = {{0,1,2}, {3,4,5}};