Skip to content

Commit

Permalink
Merge pull request xtensor-stack#2374 from tdegeus/average2
Browse files Browse the repository at this point in the history
average: fixing overload issue for axis argument
  • Loading branch information
JohanMabille authored May 20, 2021
2 parents 6504ecd + 4774c4d commit 62d3f3a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/xmath.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ Mathematical functions
+-----------------------------------------------+---------------------------------------------------------------------+
| :ref:`mean <mean-function-reference>` | mean of elements over given axes |
+-----------------------------------------------+---------------------------------------------------------------------+
| :ref:`average <average-function-reference>` | weighted average along the specified axis |
+-----------------------------------------------+---------------------------------------------------------------------+
| :ref:`variance <variance-function-reference>` | variance of elements over given axes |
+-----------------------------------------------+---------------------------------------------------------------------+
| :ref:`stddev <stddev-function-reference>` | standard deviation of elements over given axes |
Expand Down
7 changes: 7 additions & 0 deletions include/xtensor/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,13 @@ namespace detail {
return sum<T>(std::forward<E>(e) * std::move(weights_view), std::move(ax), ev) / std::move(scl);
}

template <class T = void, class E, class W, class X, class EVS = DEFAULT_STRATEGY_REDUCERS,
XTL_REQUIRES(is_reducer_options<EVS>, xtl::is_integral<X>)>
inline auto average(E&& e, W&& weights, X axis, EVS ev = EVS())
{
return average(std::forward<E>(e), std::forward<W>(weights), {axis}, std::forward<EVS>(ev));
}

template <class T = void, class E, class W, class X, std::size_t N, class EVS = DEFAULT_STRATEGY_REDUCERS>
inline auto average(E&& e, W&& weights, const X(&axes)[N], EVS ev = EVS())
{
Expand Down
35 changes: 35 additions & 0 deletions test/test_xmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,41 @@ namespace xt
EXPECT_EQ(res5[0], 8.0);
}

/********************
* Mean and average *
********************/

TEST(xmath, mean)
{
xt::xtensor<double,2> v = {{1.0, 1.0, 1.0}, {2.0, 2.0, 2.0}};
xt::xtensor<double,1> m0 = {1.5, 1.5, 1.5};
xt::xtensor<double,1> m1 = {1.0, 2.0};
double m = 9.0 / 6.0;

EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, 0), m0)));
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, {0}), m0)));
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, 1), m1)));
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, {1}), m1)));
EXPECT_EQ(xt::mean(v)(), m);
EXPECT_EQ(xt::mean(v, {0, 1})(), m);
}

TEST(xmath, average)
{
xt::xtensor<double,2> v = {{1.0, 1.0, 1.0}, {2.0, 2.0, 2.0}};
xt::xtensor<double,2> w = {{2.0, 2.0, 2.0}, {2.0, 2.0, 2.0}};
xt::xtensor<double,1> m0 = {1.5, 1.5, 1.5};
xt::xtensor<double,1> m1 = {1.0, 2.0};
double m = 9.0 / 6.0;

EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, 0), m0)));
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, {0}), m0)));
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, 1), m1)));
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, {1}), m1)));
EXPECT_EQ(xt::average(v, w)(), m);
EXPECT_EQ(xt::average(v, w, {0, 1})(), m);
}

/************************
* Linear interpolation *
************************/
Expand Down

0 comments on commit 62d3f3a

Please sign in to comment.