Skip to content

Commit

Permalink
Connected components (NVIDIA#2640)
Browse files Browse the repository at this point in the history
* Add disjoint set primitives.
* Add ND connected components labeling.
* Optimization: remove degenerate dimensions.

Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient authored Jan 29, 2021
1 parent 0af127a commit ae73c69
Show file tree
Hide file tree
Showing 6 changed files with 797 additions and 1 deletion.
133 changes: 133 additions & 0 deletions dali/kernels/common/disjoint_set.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_COMMON_DISJOINT_SET_H_
#define DALI_KERNELS_COMMON_DISJOINT_SET_H_

#include <type_traits>
#include <utility>
#include "dali/core/span.h"

namespace dali {
namespace kernels {

/**
 * @brief Default accessor for the group index stored in a disjoint set element.
 *
 * This generic version treats the element itself as the group index: reading the group
 * converts the element to `GroupId`, writing the group assigns a `GroupId` to the element.
 *
 * @tparam T        the actual element of the disjoint set data structure
 * @tparam GroupId  the index type used by the disjoint set algorithm
 */
template <typename T, typename GroupId>
struct group_ops {
  static_assert(std::is_convertible<T, GroupId>::value &&
                std::is_convertible<GroupId, T>::value,
                "This implementation of group ops requires that the element and index are "
                "convertible to each other.");

  /**
   * @brief Reads the group index stored in element `x`.
   */
  static inline GroupId get_group(const T &x) {
    return x;
  }

  /**
   * @brief Stores `new_id` in element `x`.
   *
   * @return the group index that `x` held before the assignment
   */
  static inline GroupId set_group(T &x, GroupId new_id) {
    const GroupId previous = x;
    x = new_id;
    return previous;
  }
};

/**
 * @brief Implements union/find operations
 *
 * The structure itself is stateless - all the information is kept in the `items` passed
 * to each function. Every item stores a group index which is the index of another item
 * in the same container; an item whose group index equals its own position is the
 * representative (root) of its group.
 *
 * @tparam T the actual element of the disjoint set data structure
 * @tparam GroupId the index type used by the disjoint set algorithm
 * @tparam Ops provides a way to query and assign the group index to elements of type T
 */
template <typename T, typename GroupId = int, typename Ops = group_ops<T, GroupId>>
struct disjoint_set {
  /**
   * @brief Initializes the items by setting their group index to array indices.
   *
   * After this call every item is the root of its own single-element group.
   *
   * @param items any range usable in a range-based for loop, yielding mutable elements
   */
  template <typename Container>
  void init(Container &&items) {
    GroupId start_index = 0;
    for (auto &x : items) {
      Ops::set_group(x, start_index);
      ++start_index;
    }
  }

  /**
   * @brief Initializes `n` items stored in a contiguous array by setting their
   *        group index to array indices.
   *
   * @param items pointer to the first item
   * @param n     number of items
   */
  void init(T *items, GroupId n) {
    // make_span is not defined in this file - presumably provided by dali/core/span.h.
    init(make_span(items, n));
  }

  /**
   * @brief Finds the current group index of an element or group
   *
   * As a side effect, every item visited on the way from `x` to the root is re-pointed
   * directly at the root (path compression), so subsequent lookups are shorter.
   *
   * @param items container of items, indexable with `GroupId`
   * @param x index of an element or a group index of an element
   * @return index of the root item of the group that `x` belongs to
   */
  template <class Container>
  static inline GroupId find(Container &&items, GroupId x) {
    // Pass 1: follow the links until an item that points to itself is reached.
    GroupId root = x;
    for (;;) {
      GroupId parent = Ops::get_group(items[root]);
      if (parent == root)
        break;
      root = parent;
    }

    // Pass 2: walk the same path again and link every item directly to the root.
    // set_group returns the previous link, which is the next item on the path;
    // the root is the only item on that path that points to itself, so reaching
    // it ends the walk.
    while (x != root) {
      x = Ops::set_group(items[x], root);
    }

    return root;
  }

  /**
   * @brief Merges elements or groups `x` and `y`
   *
   * @param items container of items, indexable with `GroupId`
   * @param x index of an element (or group) to merge
   * @param y index of the other element (or group) to merge
   * @return Resulting index of the merged group.
   *
   * @remarks The function combines the groups by finding their group index and
   *          assigning the lower index to the group with a higher index.
   */
  template <typename Container>
  static inline GroupId merge(Container &&items, GroupId x, GroupId y) {
    // `items` is used three times here, so it is passed on as an lvalue rather than
    // forwarded - forwarding it would formally allow a callee to consume it before
    // the subsequent uses.
    y = find(items, y);
    x = find(items, x);
    if (x < y) {
      Ops::set_group(items[y], x);
      return x;
    } else if (y < x) {
      Ops::set_group(items[x], y);
      return y;
    } else {
      // already merged
      return x;
    }
  }
};

} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_COMMON_DISJOINT_SET_H_
154 changes: 154 additions & 0 deletions dali/kernels/common/disjoint_set_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <random>
#include <sstream>
#include <type_traits>
#include "dali/kernels/common/disjoint_set.h"
#include "dali/core/util.h"
#include "dali/core/format.h"

namespace dali {
namespace kernels {

/**
 * @brief Recursively merges all elements still marked in `mask`, in random order.
 *
 * Each recursion level draws one unused element, clears its bit and merges it with the
 * element drawn one level above (`prev_idx`). Whether that merge happens before or after
 * descending into the next level is decided by a coin toss, so the merges are issued in
 * a randomized order.
 *
 * @param ds        the disjoint set used for merging
 * @param data      the items operated on by `ds`
 * @param mask      one bit per element; a set bit means the element has not been used yet
 * @param rng       random number generator
 * @param idx_dist  distribution producing candidate element indices
 * @param op_order  distribution deciding whether to merge before (true) or after recursing
 * @param prev_idx  element drawn by the caller, or a negative value at the top level
 */
template <typename T, typename Index, typename Ops, typename RNG>
void random_merge(disjoint_set<T, Index, Ops> ds,
                  T *data,
                  unsigned &mask,
                  RNG &&rng,
                  std::uniform_int_distribution<Index> &idx_dist,
                  std::bernoulli_distribution &op_order,
                  Index prev_idx = -1) {
  // the mask is empty - all bits have been used, the recursion ends here
  if (mask == 0)
    return;

  // keep drawing until we hit an index whose bit is still set (i.e. it hasn't been used)
  Index idx;
  do {
    idx = idx_dist(rng);
  } while ((mask & (1u << idx)) == 0);
  mask &= ~(1u << idx);  // mark the index as used

  const bool merge_before_recursion = op_order(rng);
  const bool has_prev = prev_idx >= 0;

  if (merge_before_recursion && has_prev)
    ds.merge(data, idx, prev_idx);

  random_merge(ds, data, mask, rng, idx_dist, op_order, idx);

  if (!merge_before_recursion && has_prev)
    ds.merge(data, idx, prev_idx);
}

template <typename Sequence>
void CheckNoForwardLinks(Sequence &&seq) {
auto b = dali::begin(seq);
auto e = dali::end(seq);
auto i = b;
auto index = *b;
for (++i; i != e; ++i, ++index) {
if (*i < index) {
std::stringstream msg;
msg << "The group index cannot point to an element further on the list; got: ";
for (auto x : seq)
msg << " " << x;
EXPECT_GE(*i, *b) << msg.str();
}
}
}

// Walks through a fixed sequence of merges on 10 elements and checks both the result of
// `find` and the raw contents of `data` after each step. The expectations on `data` depend
// on the exact order of the calls, because `find` rewrites the entries it visits.
TEST(DisjointSet, BasicTest) {
const int N = 10;
int data[N]; // NOLINT
disjoint_set<int> ds;
ds.init(data);

// Right after init, every element is its own group.
for (int i = 0; i < N; i++) {
ASSERT_EQ(data[i], i);
ASSERT_EQ(ds.find(data, i), i);
}

// Groups: {0, 1} - the lower index (0) becomes the group index.
ds.merge(data, 0, 1);
CheckNoForwardLinks(data);
EXPECT_EQ(ds.find(data, 0), 0);
EXPECT_EQ(ds.find(data, 1), 0);

// Argument order doesn't matter - the lower index (2) still wins.
ds.merge(data, 3, 2);
CheckNoForwardLinks(data);
EXPECT_EQ(ds.find(data, 2), 2);
EXPECT_EQ(ds.find(data, 3), 2);

// Build a chain 6 -> 5 -> 4; merge only relinks the roots, so data[6] still points to 5.
ds.merge(data, 6, 5);
CheckNoForwardLinks(data);
EXPECT_EQ(data[6], 5);
ds.merge(data, 4, 5);
CheckNoForwardLinks(data);
EXPECT_EQ(data[5], 4);
EXPECT_EQ(data[6], 5);

// Attach the chain's root (4) to 0; find(6) then walks 6 -> 5 -> 4 -> 0
// and relinks the visited entries directly to 0.
ds.merge(data, 4, 0);
CheckNoForwardLinks(data);
EXPECT_EQ(data[4], 0);
EXPECT_EQ(ds.find(data, 6), 0);
EXPECT_EQ(data[6], 0) << "`find` should update the entry.";

// 9 -> 8, then the root of 9 (which is 8) gets attached to 7: chain 9 -> 8 -> 7.
ds.merge(data, 8, 9);
CheckNoForwardLinks(data);
ds.merge(data, 7, 9);
CheckNoForwardLinks(data);
EXPECT_EQ(ds.find(data, 8), 7);
EXPECT_EQ(data[8], 7) << "`find` should update the entry.";
// The root of 6 is 0, so 7 (and everything below it) joins group 0.
ds.merge(data, 6, 7);
CheckNoForwardLinks(data);
EXPECT_EQ(ds.find(data, 9), 0) << "`merge` didn't propagate";
// The root of 8 is 0 and the root of 3 is 2 - this joins the last separate group {2, 3}.
ds.merge(data, 8, 3);
CheckNoForwardLinks(data);
// Everything is now in group 0; calling find on each element also flattens its entry.
for (int i = 0; i < N; i++) {
EXPECT_EQ(ds.find(data, i), 0);
EXPECT_EQ(data[i], 0) << "`find` should update the entry.";
}
}

// Merges all 32 elements in a pseudo-random order (fixed seed) and checks that they all
// end up in group 0. The whole procedure is repeated 10 times, continuing the same
// random sequence, so each repetition uses a different merge order.
TEST(DisjointSet, RandomMergeAll) {
  constexpr int N = 32;
  int data[N];  // NOLINT
  disjoint_set<int> ds;
  std::mt19937_64 rng(12345);

  for (int round = 0; round < 10; round++) {
    // start from scratch - every element in its own group
    ds.init(data);
    for (int k = 0; k < N; k++) {
      ASSERT_EQ(data[k], k);
    }

    // 1 bit for each element of the array - initially, all set
    unsigned mask = 0xffffffffu;
    std::uniform_int_distribution<> idx_dist(0, N-1);
    std::bernoulli_distribution op_order(0.5);

    random_merge(ds, data, mask, rng, idx_dist, op_order);

    // all elements must resolve to group 0...
    for (int k = 0; k < N; k++) {
      EXPECT_EQ(ds.find(data, k), 0);
    }

    // ...and, after the lookups above, all entries must point directly to it
    for (int label : data) {
      EXPECT_EQ(label, 0);
    }
  }
}

} // namespace kernels
} // namespace dali
3 changes: 2 additions & 1 deletion dali/kernels/imgproc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@ add_subdirectory(color_manipulation)
add_subdirectory(convolution)
add_subdirectory(pointwise)
add_subdirectory(resample)
add_subdirectory(structure)

# Get all the source files and dump test files
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
Expand Down
18 changes: 18 additions & 0 deletions dali/kernels/imgproc/structure/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Get all the source files and dump test files.
# Each call passes a list variable name (headers, kernel sources, kernel test sources)
# together with PARENT_SCOPE.
# NOTE(review): collect_headers / collect_sources / collect_test_sources are project-defined
# helpers that are not visible here - presumably they append this directory's files to the
# named list in the parent scope; confirm against their definitions.
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_KERNEL_SRCS PARENT_SCOPE)
collect_test_sources(DALI_KERNEL_TEST_SRCS PARENT_SCOPE)
Loading

0 comments on commit ae73c69

Please sign in to comment.