Skip to content

Commit

Permalink
[test] Speed up the reduce_by_segment.pass test (uxlfoundation#1546)
Browse files Browse the repository at this point in the history
We save time in reduce_by_segment.pass by making the following changes:
- For each test, check if we are using the default predicate and / or operator and test that API if we do. Previously, we would test with and without providing a customer operator and predicate for each test run. This led to duplicate testing.
- Some test scenarios deemed unnecessary have been removed.

---------

Signed-off-by: Matthew Michel <[email protected]>
  • Loading branch information
mmichel11 authored Apr 30, 2024
1 parent 29f20db commit f6a8c40
Showing 1 changed file with 44 additions and 61 deletions.
105 changes: 44 additions & 61 deletions test/parallel_api/numeric/numeric.ops/reduce_by_segment.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,51 +148,34 @@ DEFINE_TEST_2(test_reduce_by_segment, BinaryPredicate, BinaryOperation)
TestDataTransfer<UDTKind::eRes2, Size> host_res(*this, n);

typedef typename ::std::iterator_traits<Iterator1>::value_type KeyT;
typedef typename ::std::iterator_traits<Iterator2>::value_type ValT;

// call algorithm with no optional arguments
initialize_data(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n);
update_data(host_keys, host_vals, host_res_keys, host_res);

auto new_policy = make_new_policy<new_kernel_name<Policy, 0>>(exec);
auto res1 =
oneapi::dpl::reduce_by_segment(new_policy, keys_first, keys_last, vals_first, key_res_first, val_res_first);
exec.queue().wait_and_throw();

retrieve_data(host_keys, host_vals, host_res_keys, host_res);
size_t segments_key_ret1 = ::std::distance(key_res_first, res1.first);
size_t segments_val_ret1 = ::std::distance(val_res_first, res1.second);
check_values(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n, segments_key_ret1,
segments_val_ret1);

// call algorithm with predicate
initialize_data(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n);
update_data(host_keys, host_vals, host_res_keys, host_res);

auto new_policy2 = make_new_policy<new_kernel_name<Policy, 1>>(exec);
auto res2 = oneapi::dpl::reduce_by_segment(new_policy2, keys_first, keys_last, vals_first, key_res_first,
val_res_first, BinaryPredicate());
exec.queue().wait_and_throw();

retrieve_data(host_keys, host_vals, host_res_keys, host_res);
size_t segments_key_ret2 = ::std::distance(key_res_first, res2.first);
size_t segments_val_ret2 = ::std::distance(val_res_first, res2.second);
check_values(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n, segments_key_ret2,
segments_val_ret2, BinaryPredicate());

// call algorithm with predicate and operator
initialize_data(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n);
update_data(host_keys, host_vals, host_res_keys, host_res);

auto new_policy3 = make_new_policy<new_kernel_name<Policy, 2>>(exec);
auto res3 = oneapi::dpl::reduce_by_segment(new_policy3, keys_first, keys_last, vals_first, key_res_first,
val_res_first, BinaryPredicate(), BinaryOperation());
std::pair<Iterator3, Iterator4> res;
if constexpr (std::is_same_v<std::equal_to<KeyT>, std::decay_t<BinaryPredicate>> &&
std::is_same_v<std::plus<ValT>, std::decay_t<BinaryOperation>>)
{
res = oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first, val_res_first);
}
else if constexpr (std::is_same_v<std::plus<ValT>, std::decay_t<BinaryOperation>>)
{
res = oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first, val_res_first,
BinaryPredicate());
}
else
{
res = oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first, val_res_first,
BinaryPredicate(), BinaryOperation());
}
exec.queue().wait_and_throw();

retrieve_data(host_keys, host_vals, host_res_keys, host_res);
size_t segments_key_ret3 = ::std::distance(key_res_first, res3.first);
size_t segments_val_ret3 = ::std::distance(val_res_first, res3.second);
check_values(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n, segments_key_ret3,
segments_val_ret3, BinaryPredicate(), BinaryOperation());
size_t segments_key_ret = ::std::distance(key_res_first, res.first);
size_t segments_val_ret = ::std::distance(val_res_first, res.second);
check_values(host_keys.get(), host_vals.get(), host_res_keys.get(), host_res.get(), n, segments_key_ret,
segments_val_ret, BinaryPredicate(), BinaryOperation());
}
#endif

Expand All @@ -208,30 +191,31 @@ DEFINE_TEST_2(test_reduce_by_segment, BinaryPredicate, BinaryOperation)
operator()(Policy&& exec, Iterator1 keys_first, Iterator1 keys_last, Iterator2 vals_first, Iterator2 vals_last,
Iterator3 key_res_first, Iterator3 key_res_last, Iterator4 val_res_first, Iterator4 val_res_last, Size n)
{
// call algorithm with no optional arguments
initialize_data(keys_first, vals_first, key_res_first, val_res_first, n);
auto res1 =
oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first, val_res_first);
size_t segments_key_ret1 = ::std::distance(key_res_first, res1.first);
size_t segments_val_ret1 = ::std::distance(val_res_first, res1.second);
check_values(keys_first, vals_first, key_res_first, val_res_first, n, segments_key_ret1, segments_val_ret1);
typedef typename ::std::iterator_traits<Iterator1>::value_type KeyT;
typedef typename ::std::iterator_traits<Iterator2>::value_type ValT;

// call algorithm with predicate
initialize_data(keys_first, vals_first, key_res_first, val_res_first, n);
auto res2 = oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first,
val_res_first, BinaryPredicate());
size_t segments_key_ret2 = ::std::distance(key_res_first, res2.first);
size_t segments_val_ret2 = ::std::distance(val_res_first, res2.second);
check_values(keys_first, vals_first, key_res_first, val_res_first, n, segments_key_ret2, segments_val_ret2,
BinaryPredicate());

// call algorithm with predicate and operator
initialize_data(keys_first, vals_first, key_res_first, val_res_first, n);
auto res3 = oneapi::dpl::reduce_by_segment(exec, keys_first, keys_last, vals_first, key_res_first,
val_res_first, BinaryPredicate(), BinaryOperation());
size_t segments_key_ret3 = ::std::distance(key_res_first, res3.first);
size_t segments_val_ret3 = ::std::distance(val_res_first, res3.second);
check_values(keys_first, vals_first, key_res_first, val_res_first, n, segments_key_ret3, segments_val_ret3,

std::pair<Iterator3, Iterator4> res;
if constexpr (std::is_same_v<std::equal_to<KeyT>, std::decay_t<BinaryPredicate>> &&
std::is_same_v<std::plus<ValT>, std::decay_t<BinaryOperation>>)
{
res = oneapi::dpl::reduce_by_segment(std::forward<Policy>(exec), keys_first, keys_last, vals_first,
key_res_first, val_res_first);
}
else if constexpr (std::is_same_v<std::plus<ValT>, std::decay_t<BinaryOperation>>)
{
res = oneapi::dpl::reduce_by_segment(std::forward<Policy>(exec), keys_first, keys_last, vals_first,
key_res_first, val_res_first, BinaryPredicate());
}
else
{
res = oneapi::dpl::reduce_by_segment(std::forward<Policy>(exec), keys_first, keys_last, vals_first,
key_res_first, val_res_first, BinaryPredicate(), BinaryOperation());
}
size_t segments_key_ret = ::std::distance(key_res_first, res.first);
size_t segments_val_ret = ::std::distance(val_res_first, res.second);
check_values(keys_first, vals_first, key_res_first, val_res_first, n, segments_key_ret, segments_val_ret,
BinaryPredicate(), BinaryOperation());
}

Expand Down Expand Up @@ -368,7 +352,6 @@ main()
test_flag_pred<sycl::usm::alloc::device, class KernelName2, dpl::complex<float>>();
#endif // TEST_DPCPP_BACKEND_PRESENT

run_test<::std::uint64_t, UserBinaryPredicate<::std::uint64_t>, MaxFunctor<::std::uint64_t>>();
run_test<::std::complex<float>, UserBinaryPredicate<::std::complex<float>>, MaxFunctor<::std::complex<float>>>();

run_test<int, ::std::equal_to<int>, ::std::plus<int>>();
Expand Down

0 comments on commit f6a8c40

Please sign in to comment.