forked from NVIDIA/DALI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add disjoint set primitives. * Add ND connected components labeling. * Optimization: remove degenerate dimensions. Signed-off-by: Michał Zientkiewicz <[email protected]>
- Loading branch information
Showing
6 changed files
with
797 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_COMMON_DISJOINT_SET_H_ | ||
#define DALI_KERNELS_COMMON_DISJOINT_SET_H_ | ||
|
||
#include <type_traits> | ||
#include <utility> | ||
#include "dali/core/span.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
|
||
/** | ||
* @brief Provides set/get operations for group index | ||
* | ||
* @tparam T the actual element of the disjoint set data structure | ||
* @tparam GroupId the index type used by the disjoint set algorithm | ||
*/ | ||
template <typename T, typename GroupId> | ||
struct group_ops { | ||
static_assert(std::is_convertible<T, GroupId>::value && std::is_convertible<GroupId, T>::value, | ||
"This implementation of group ops requires that the element and index are " | ||
"convertible to each other."); | ||
|
||
static inline GroupId get_group(const T &x) { | ||
return x; | ||
} | ||
|
||
static inline GroupId set_group(T &x, GroupId new_id) { | ||
GroupId old = x; | ||
x = new_id; | ||
return old; | ||
} | ||
}; | ||
|
||
/** | ||
* @brief Implements union/find operations | ||
* | ||
* @tparam T the actual element of the disjoint set data structure | ||
* @tparam GroupId the index type used by the disjoint set algorithm | ||
* @tparam Ops provides a way to query and assign the group index to elements of type T | ||
*/ | ||
template <typename T, typename GroupId = int, typename Ops = group_ops<T, GroupId>> | ||
struct disjoint_set { | ||
/** | ||
* @brief Initializes the items by setting their group index to array indices. | ||
*/ | ||
template <typename Container> | ||
void init(Container &&items) { | ||
GroupId start_index = 0; | ||
for (auto &x : items) { | ||
Ops::set_group(x, start_index); | ||
++start_index; | ||
} | ||
} | ||
|
||
/** | ||
* @brief Initializes the items by setting their group index to array indices. | ||
*/ | ||
void init(T *items, GroupId n) { | ||
init(make_span(items, n)); | ||
} | ||
|
||
/** | ||
* @brief Finds the current group index of an element or group | ||
* | ||
* @param x index of an element or a group index of an element | ||
*/ | ||
template <class Container> | ||
static inline GroupId find(Container &&items, GroupId x) { | ||
GroupId x0 = x; | ||
|
||
// find the label | ||
for (;;) { | ||
GroupId g = Ops::get_group(items[x]); | ||
if (g == x) | ||
break; | ||
x = g; | ||
} | ||
|
||
GroupId r = x; | ||
|
||
// assign all intermediate labels to save time in subsequent calls | ||
x = x0; | ||
while (x != Ops::get_group(items[x])) { | ||
x0 = Ops::set_group(items[x], r); | ||
x = x0; | ||
} | ||
|
||
return r; | ||
} | ||
|
||
/** | ||
* @brief Merges elements or groups `x` and `y` | ||
* | ||
* @return Resulting index of the merged group. | ||
* | ||
* @remarks The function combines the groups by finding their group index and | ||
* assigning the lower index to the group with a higher index. | ||
*/ | ||
template <typename Container> | ||
static inline GroupId merge(Container &&items, GroupId x, GroupId y) { | ||
y = find(std::forward<Container>(items), y); | ||
x = find(std::forward<Container>(items), x); | ||
if (x < y) { | ||
Ops::set_group(items[y], x); | ||
return x; | ||
} else if (y < x) { | ||
Ops::set_group(items[x], y); | ||
return y; | ||
} else { | ||
// already merged | ||
return x; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_COMMON_DISJOINT_SET_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gtest/gtest.h> | ||
#include <algorithm> | ||
#include <random> | ||
#include "dali/kernels/common/disjoint_set.h" | ||
#include "dali/core/util.h" | ||
#include "dali/core/format.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
|
||
template <typename T, typename Index, typename Ops, typename RNG> | ||
void random_merge(disjoint_set<T, Index, Ops> ds, | ||
T *data, | ||
unsigned &mask, | ||
RNG &&rng, | ||
std::uniform_int_distribution<Index> &idx_dist, | ||
std::bernoulli_distribution &op_order, | ||
Index prev_idx = -1) { | ||
// the mask is empty - all bits have been used | ||
if (!mask) | ||
return; | ||
Index idx = idx_dist(rng); | ||
while ((mask & (1u << idx)) == 0) // if the index has been used (it's bit in the mask is 0)... | ||
idx = idx_dist(rng); // ...generate a new index and retry | ||
mask &= ~(1u << idx); | ||
|
||
bool merge_first = op_order(rng); | ||
|
||
if (merge_first) { | ||
if (prev_idx >= 0) { | ||
ds.merge(data, idx, prev_idx); | ||
} | ||
} | ||
random_merge(ds, data, mask, rng, idx_dist, op_order, idx); | ||
if (!merge_first) { | ||
if (prev_idx >= 0) { | ||
ds.merge(data, idx, prev_idx); | ||
} | ||
} | ||
} | ||
|
||
template <typename Sequence> | ||
void CheckNoForwardLinks(Sequence &&seq) { | ||
auto b = dali::begin(seq); | ||
auto e = dali::end(seq); | ||
auto i = b; | ||
auto index = *b; | ||
for (++i; i != e; ++i, ++index) { | ||
if (*i < index) { | ||
std::stringstream msg; | ||
msg << "The group index cannot point to an element further on the list; got: "; | ||
for (auto x : seq) | ||
msg << " " << x; | ||
EXPECT_GE(*i, *b) << msg.str(); | ||
} | ||
} | ||
} | ||
|
||
TEST(DisjointSet, BasicTest) { | ||
const int N = 10; | ||
int data[N]; // NOLINT | ||
disjoint_set<int> ds; | ||
ds.init(data); | ||
|
||
for (int i = 0; i < N; i++) { | ||
ASSERT_EQ(data[i], i); | ||
ASSERT_EQ(ds.find(data, i), i); | ||
} | ||
|
||
ds.merge(data, 0, 1); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(ds.find(data, 0), 0); | ||
EXPECT_EQ(ds.find(data, 1), 0); | ||
|
||
ds.merge(data, 3, 2); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(ds.find(data, 2), 2); | ||
EXPECT_EQ(ds.find(data, 3), 2); | ||
|
||
ds.merge(data, 6, 5); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(data[6], 5); | ||
ds.merge(data, 4, 5); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(data[5], 4); | ||
EXPECT_EQ(data[6], 5); | ||
|
||
ds.merge(data, 4, 0); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(data[4], 0); | ||
EXPECT_EQ(ds.find(data, 6), 0); | ||
EXPECT_EQ(data[6], 0) << "`find` should update the entry."; | ||
|
||
ds.merge(data, 8, 9); | ||
CheckNoForwardLinks(data); | ||
ds.merge(data, 7, 9); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(ds.find(data, 8), 7); | ||
EXPECT_EQ(data[8], 7) << "`find` should update the entry."; | ||
ds.merge(data, 6, 7); | ||
CheckNoForwardLinks(data); | ||
EXPECT_EQ(ds.find(data, 9), 0) << "`merge` didn't propagate"; | ||
ds.merge(data, 8, 3); | ||
CheckNoForwardLinks(data); | ||
for (int i = 0; i < N; i++) { | ||
EXPECT_EQ(ds.find(data, i), 0); | ||
EXPECT_EQ(data[i], 0) << "`find` should update the entry."; | ||
} | ||
} | ||
|
||
TEST(DisjointSet, RandomMergeAll) { | ||
const int N = 32; | ||
int data[N]; // NOLINT | ||
disjoint_set<int> ds; | ||
std::mt19937_64 rng(12345); | ||
|
||
for (int iter = 0; iter < 10; iter++) { | ||
ds.init(data); | ||
for (int i = 0; i < N; i++) { | ||
ASSERT_EQ(data[i], i); | ||
} | ||
|
||
std::bernoulli_distribution op_order(0.5); | ||
std::uniform_int_distribution<> idx_dist(0, N-1); | ||
unsigned mask = 0xffffffffu; // 1 bit for each element of the array - initially, all set | ||
|
||
random_merge(ds, data, mask, rng, idx_dist, op_order); | ||
|
||
for (int j = 0; j < N; j++) { | ||
EXPECT_EQ(ds.find(data, j), 0); | ||
} | ||
|
||
for (int j = 0; j < N; j++) { | ||
EXPECT_EQ(data[j], 0); | ||
} | ||
} | ||
} | ||
|
||
} // namespace kernels | ||
} // namespace dali |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Get all the source files and dump test files | ||
collect_headers(DALI_INST_HDRS PARENT_SCOPE) | ||
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE) | ||
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE) |
Oops, something went wrong.