diff --git a/docs/source/api/xstrides.rst b/docs/source/api/xstrides.rst new file mode 100644 index 000000000..901ff63be --- /dev/null +++ b/docs/source/api/xstrides.rst @@ -0,0 +1,16 @@ +.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht + + Distributed under the terms of the BSD 3-Clause License. + + The full license is in the file LICENSE, distributed with this software. + +xshape +====== + +Defined in ``xtensor/xstride.hpp`` + +.. doxygenfunction:: auto strides(const E& e, xt::stride_type type) + :project: xtensor + +.. doxygenfunction:: auto strides(const E& e, S axis, xt::stride_type type) + :project: xtensor diff --git a/docs/source/numpy-differences.rst b/docs/source/numpy-differences.rst index 10b3082fa..f51748ef2 100644 --- a/docs/source/numpy-differences.rst +++ b/docs/source/numpy-differences.rst @@ -98,6 +98,20 @@ Strides Strided containers of xtensor and numpy having the same exact memory layout may have different strides when accessing them through the ``strides`` attribute. The reason is an optimization in xtensor, which is to set the strides to ``0`` in dimensions of length ``1``, which simplifies the implementation of broadcasting of universal functions. +.. tip:: + + Use the free function ``xt::strides`` to switch between representations. + + .. code-block:: cpp + + xt::strides(a); // strides of ``a`` corresponding to storage + xt::strides(a, xt::stride_type::normal); // same + + xt::strides(a, xt::stride_type::internal); // ``== a.strides()`` + + xt::strides(a, xt::stride_type::bytes) // strides in bytes, as in numpy + + Array indices ------------- diff --git a/include/xtensor/xstrides.hpp b/include/xtensor/xstrides.hpp index 65eb956e0..ff8a5623a 100644 --- a/include/xtensor/xstrides.hpp +++ b/include/xtensor/xstrides.hpp @@ -142,6 +142,127 @@ namespace xt return begin; } + + /*********** + * strides * + ***********/ + + namespace detail + { + template <class return_type, class S, class T, class D> + inline return_type compute_stride_impl(layout_type layout, const S& shape, T axis, D default_stride) + { + if (layout == layout_type::row_major) + { + return std::accumulate( + shape.cbegin() + axis + 1, + shape.cend(), + static_cast<return_type>(1), + std::multiplies<return_type>() + ); + } + if (layout == layout_type::column_major) + { + return std::accumulate( + shape.cbegin(), + shape.cbegin() + axis, + static_cast<return_type>(1), + std::multiplies<return_type>() + ); + } + return default_stride; + } + } + + /** + * @ingroup strides + * @brief strides_type + * + * Choose stride type + */ + enum class stride_type + { + internal = 0, ///< As used internally (with `stride(axis) == 0` if `shape(axis) == 1`) + normal = 1, ///< Normal stride corresponding to storage. + bytes = 2, ///< Normal stride in bytes. + }; + + /** + * @ingroup strides + * @brief strides + * + * Get strides of an object. + * @param a an array + * @return array + */ + template <class E> + inline auto strides(const E& e, stride_type type = stride_type::normal) noexcept + { + using strides_type = typename E::strides_type; + using return_type = typename strides_type::value_type; + strides_type ret = e.strides(); + auto shape = e.shape(); + + if (type == stride_type::internal) + { + return ret; + } + + for (std::size_t i = 0; i < ret.size(); ++i) + { + if (shape[i] == 1) + { + ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]); + } + } + + if (type == stride_type::bytes) + { + return_type f = static_cast<return_type>(sizeof(typename E::value_type)); + std::for_each(ret.begin(), ret.end(), [f](auto& c){ c *= f; }); + } + + return ret; + } + + /** + * @ingroup strides + * @brief strides + * + * Get stride of an object along an axis. + * @param a an array + * @return integer + */ + template <class E> + inline auto strides(const E& e, std::size_t axis, stride_type type = stride_type::normal) noexcept + { + using strides_type = typename E::strides_type; + using return_type = typename strides_type::value_type; + + return_type ret = e.strides()[axis]; + + if (type == stride_type::internal) + { + return ret; + } + + if (ret == 0) + { + if (e.shape(axis) == 1) + { + ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret); + } + } + + if (type == stride_type::bytes) + { + return_type f = static_cast<return_type>(sizeof(typename E::value_type)); + ret *= f; + } + + return ret; + } + /****************** * Implementation * ******************/ diff --git a/test/test_xstrides.cpp b/test/test_xstrides.cpp index 80f995099..faa61182d 100644 --- a/test/test_xstrides.cpp +++ b/test/test_xstrides.cpp @@ -50,6 +50,92 @@ namespace xt EXPECT_TRUE(t5); } + TEST(xstrides, free_function_2d_row_major) + { + xt::xarray<int, xt::layout_type::row_major> a = xt::ones<int>({1, 3}); + using stype = std::vector<std::ptrdiff_t>; + std::ptrdiff_t sof = sizeof(int); + + EXPECT_EQ(xt::strides(a), stype({3, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({3, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({0, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({3 * sof, sof})); + + EXPECT_TRUE(xt::strides(a, 0) == 3); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 3); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 0); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == 3 * sof); + + EXPECT_TRUE(xt::strides(a, 1) == 1); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 1); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 1); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == sof); + } + + TEST(xstrides, free_function_4d_row_major) + { + xt::xarray<int, xt::layout_type::row_major> a = xt::ones<int>({5, 4, 1, 4}); + using stype = std::vector<std::ptrdiff_t>; + std::ptrdiff_t sof = sizeof(int); + + EXPECT_EQ(xt::strides(a), stype({16, 4, 4, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({16, 4, 4, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({16, 4, 0, 1})); + EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({16 * sof, 4 * sof, 4 * sof, 1 * sof})); + + EXPECT_TRUE(xt::strides(a, 0) == 16); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 16); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 16); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == 16 * sof); + + EXPECT_TRUE(xt::strides(a, 1) == 4); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 4); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 4); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == 4 * sof); + + EXPECT_TRUE(xt::strides(a, 2) == 4); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::normal) == 4); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::internal) == 0); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::bytes) == 4 * sof); + + EXPECT_TRUE(xt::strides(a, 3) == 1); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::normal) == 1); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::internal) == 1); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::bytes) == sof); + } + + TEST(xstrides, free_function_4d_column_major) + { + xt::xarray<int, xt::layout_type::column_major> a = xt::ones<int>({5, 4, 1, 4}); + using stype = std::vector<std::ptrdiff_t>; + std::ptrdiff_t sof = sizeof(int); + + EXPECT_EQ(xt::strides(a), stype({1, 5, 20, 20})); + EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({1, 5, 20, 20})); + EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({1, 5, 0, 20})); + EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({sof, 5 * sof, 20 * sof, 20 * sof})); + + EXPECT_TRUE(xt::strides(a, 0) == 1); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 1); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 1); + EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == sof); + + EXPECT_TRUE(xt::strides(a, 1) == 5); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 5); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 5); + EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == 5 * sof); + + EXPECT_TRUE(xt::strides(a, 2) == 20); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::normal) == 20); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::internal) == 0); + EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::bytes) == 20 * sof); + + EXPECT_TRUE(xt::strides(a, 3) == 20); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::normal) == 20); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::internal) == 20); + EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::bytes) == 20 * sof); + } + TEST(xstrides, unravel_from_strides) { SUBCASE("row_major strides")