Skip to content

Commit

Permalink
* added missing greater-comperator in lex. comparison
Browse files Browse the repository at this point in the history
* removed handling fixed-shape tensors (this should should be done in a separate pr)
* extended tests st. they do not rely on fixed-shape tensors
  • Loading branch information
DerThorsten committed Jun 2, 2021
1 parent 597a139 commit 5f133ae
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 38 deletions.
44 changes: 6 additions & 38 deletions include/xtensor/xassign.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <algorithm>
#include <type_traits>
#include <utility>
#include <functional>

#include <xtl/xcomplex.hpp>
#include <xtl/xsequence.hpp>
Expand Down Expand Up @@ -413,60 +414,27 @@ namespace xt
base_type::assign_data(e1, e2, trivial_broadcast);
}

namespace detail
{

template<class SHAPE>
struct select_shape{
using type = SHAPE;
};

template<std::size_t ... X>
struct select_shape<xt::fixed_shape< X ...>>{
using type = std::array<std::size_t, sizeof ... (X)>;
};

template<class SHAPE>
using select_shape_t = typename select_shape<SHAPE>::type;


template<class EXP, class T, class SHAPE>
struct select_tmp_type{
using type = typename EXP::temporary_type;
};

template<class EXP, class T, std::size_t ... X>
struct select_tmp_type<EXP, T,xt::fixed_shape< X ...>>{
using type = xt::xtensor<T, sizeof ... (X)>;
};

template<class EXP, class T, class SHAPE>
using select_tmp_type_t = typename select_tmp_type<EXP,T, SHAPE>::type;

};

template <class Tag>
template <class E1, class E2>
inline void xexpression_assigner<Tag>::computed_assign(xexpression<E1>& e1, const xexpression<E2>& e2)
{
using shape_type = detail::select_shape_t< std::decay_t<typename E1::shape_type>>;
using shape_type = typename E1::shape_type;
using comperator_type = std::greater<typename shape_type::value_type>;

using size_type = typename E1::size_type;

E1& de1 = e1.derived_cast();
const E2& de2 = e2.derived_cast();

size_type dim1 = de1.dimension();
size_type dim2 = de2.dimension();
shape_type shape = uninitialized_shape<shape_type>(dim2);

bool trivial_broadcast = de2.broadcast_shape(shape, true);

auto && de1_shape = de1.shape();
// we cannot simply call de1_shape.begin()/del1_shape.end() since this can be a pointer
if (dim2 > de1.dimension() || std::lexicographical_compare(shape.begin(), shape.end(), std::begin(de1_shape), std::begin(de1_shape) + dim1))
if (dim2 > de1.dimension() || std::lexicographical_compare(shape.begin(), shape.end(), de1_shape.begin(), de1_shape.end(), comperator_type()))
{
using temporary_type = detail::select_tmp_type_t<E1, typename E1::value_type, typename E1::shape_type>;
temporary_type tmp(shape);
typename E1::temporary_type tmp(shape);
base_type::assign_data(tmp, e2, trivial_broadcast);
de1.assign_temporary(std::move(tmp));
}
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ set(XTENSOR_TESTS
main.cpp
test_xaccumulator.cpp
test_xadapt.cpp
test_xassign.cpp
test_xaxis_iterator.cpp
test_xaxis_slice_iterator.cpp
test_xbuffer_adaptor.cpp
Expand Down
150 changes: 150 additions & 0 deletions test/test_xassign.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/***************************************************************************
* Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
* Copyright (c) QuantStack *
* *
* Distributed under the terms of the BSD 3-Clause License. *
* *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

#include "gtest/gtest.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"

#include "xtensor/xassign.hpp"
#include "xtensor/xnoalias.hpp"
#include "test_common.hpp"

#include <type_traits>
#include <vector>


// a dummy shape *not derived* from std::vector but compatible
template<class T>
class my_vector
{
private:
using vector_type = std::vector<T>;
public:
my_vector(){}
using value_type = T;
using size_type = typename vector_type::size_type;
template<class U>
my_vector(std::initializer_list<U> vals)
: m_data(vals.begin(), vals.end())
{

}
my_vector(const std::size_t size, const T & val = T())
: m_data(size, val)
{
}
auto resize(const std::size_t size)
{
return m_data.resize(size);
}
auto size()const
{
return m_data.size();
}
auto cend()const
{
return m_data.cend();
}
auto cbegin()const
{
return m_data.cbegin();
}
auto end()
{
return m_data.end();
}
auto end()const
{
return m_data.end();
}
auto begin()
{
return m_data.begin();
}
auto begin()const
{
return m_data.begin();
}
auto empty()const
{
return m_data.empty();
}
auto & back()
{
return m_data.back();
}
const auto & back()const
{
return m_data.back();
}
auto & front()
{
return m_data.front();
}
const auto & front()const
{
return m_data.front();
}
auto & operator[](const std::size_t i)
{
return m_data[i];
}
const auto & operator[](const std::size_t i)const
{
return m_data[i];
}
private:
std::vector<T> m_data;
};


namespace xt
{

template <class T, class C_T>
struct rebind_container<T, my_vector<C_T>>
{
using type = my_vector<T>;
};

TEST(xassign, mix_shape_types)
{
{
// xarray like with custom shape
using my_xarray = xt::xarray_container<
std::vector<int>,
xt::layout_type::row_major,
my_vector<std::size_t>
>;

auto a = my_xarray::from_shape({1,3});
auto b = xt::xtensor<int,2>::from_shape({2,3});
xt::noalias(a) += b;
EXPECT_EQ(a.dimension(), 2);
EXPECT_EQ(a.shape(0), 2);
EXPECT_EQ(a.shape(1), 3);
}
{
// xarray like with custom shape
using my_xarray = xt::xarray_container<
std::vector<int>,
xt::layout_type::row_major,
my_vector<std::size_t>
>;

auto a = my_xarray::from_shape({3});
auto b = xt::xtensor<int,2>::from_shape({2,3});
xt::noalias(a) += b;
EXPECT_EQ(a.dimension(), 2);
EXPECT_EQ(a.shape(0), 2);
EXPECT_EQ(a.shape(1), 3);
}

}
}

0 comments on commit 5f133ae

Please sign in to comment.