-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathstatic_map_ref.cuh
324 lines (290 loc) · 13 KB
/
static_map_ref.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuco/detail/open_addressing/open_addressing_ref_impl.cuh>
#include <cuco/hash_functions.cuh>
#include <cuco/operator.hpp>
#include <cuco/probing_scheme.cuh>
#include <cuco/storage.cuh>
#include <cuco/types.cuh>
#include <cuco/utility/cuda_thread_scope.cuh>
#include <cuda/std/atomic>
namespace cuco {
/**
* @brief Device non-owning "ref" type that can be used in device code to perform arbitrary
* operations defined in `include/cuco/operator.hpp`
*
* @note Concurrent modify and lookup will be supported if both kinds of operators are specified
* during the ref construction.
* @note cuCollections data structures always place the slot keys on the left-hand
* side when invoking the key comparison predicate.
* @note Ref types are trivially-copyable and are intended to be passed by value.
* @note `ProbingScheme::cg_size` indicates how many threads are used to handle one independent
* device operation. `cg_size == 1` uses the scalar (or non-CG) code paths.
* @note Duplicate keys are not allowed: the underlying implementation is instantiated with
* `allows_duplicates == false`.
*
* @throw If the size of the given key type is larger than 8 bytes
* @throw If the size of the given payload type is neither 4 bytes nor 8 bytes
* @throw If the given key type doesn't have unique object representations, i.e.,
* `cuco::is_bitwise_comparable_v<Key> == false`
* @throw If the given payload type doesn't have unique object representations, i.e.,
* `cuco::is_bitwise_comparable_v<T> == false`
* @throw If the probing scheme type is not inherited from `cuco::detail::probing_scheme_base`
*
* NOTE(review): only the payload-size and key-comparability conditions above are enforced by
* `static_assert`s in this class; the remaining ones are presumably checked by
* `detail::open_addressing_ref_impl` or the owning container - confirm.
*
* @tparam Key Type used for keys. Requires `cuco::is_bitwise_comparable_v<Key>` returning true
* @tparam T Type used for mapped values. Requires `cuco::is_bitwise_comparable_v<T>` returning true
* @tparam Scope The scope in which operations will be performed by individual threads.
* @tparam KeyEqual Binary callable type used to compare two keys for equality
* @tparam ProbingScheme Probing scheme (see `include/cuco/probing_scheme.cuh` for options)
* @tparam StorageRef Storage ref type
* @tparam Operators Device operator options defined in `include/cuco/operator.hpp`
*/
template <typename Key,
typename T,
cuda::thread_scope Scope,
typename KeyEqual,
typename ProbingScheme,
typename StorageRef,
typename... Operators>
class static_map_ref
// One `operator_impl` mixin base per requested operator; each mixin is parameterized on this
// ref type itself
: public detail::operator_impl<
Operators,
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>>... {
/// Flag indicating whether duplicate keys are allowed or not
static constexpr auto allows_duplicates = false;
/// Implementation type
using impl_type = detail::
open_addressing_ref_impl<Key, Scope, KeyEqual, ProbingScheme, StorageRef, allows_duplicates>;
// Compile-time requirement on the mapped type: it must be exactly 4 or 8 bytes wide
static_assert(sizeof(T) == 4 or sizeof(T) == 8,
"sizeof(mapped_type) must be either 4 bytes or 8 bytes.");
// Compile-time requirement on the key type: it must be safe to compare bitwise
static_assert(
cuco::is_bitwise_comparable_v<Key>,
"Key type must have unique object representations or have been explicitly declared as safe for "
"bitwise comparison via specialization of cuco::is_bitwise_comparable_v<Key>.");
public:
using key_type = Key; ///< Key type
using mapped_type = T; ///< Mapped type
using probing_scheme_type = ProbingScheme; ///< Type of probing scheme
using hasher = typename probing_scheme_type::hasher; ///< Hash function type
using storage_ref_type = StorageRef; ///< Type of storage ref
using bucket_type = typename storage_ref_type::bucket_type; ///< Bucket type
using value_type = typename storage_ref_type::value_type; ///< Storage element type
using extent_type = typename storage_ref_type::extent_type; ///< Extent type
using size_type = typename storage_ref_type::size_type; ///< Size type of the storage ref
using key_equal = KeyEqual; ///< Type of key equality binary callable
using iterator = typename storage_ref_type::iterator; ///< Slot iterator type
using const_iterator = typename storage_ref_type::const_iterator; ///< Const slot iterator type
static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size
static constexpr auto bucket_size =
storage_ref_type::bucket_size; ///< Number of elements handled per bucket
static constexpr auto thread_scope = impl_type::thread_scope; ///< CUDA thread scope
/**
* @brief Constructs static_map_ref.
*
* @param empty_key_sentinel Sentinel indicating empty key
* @param empty_value_sentinel Sentinel indicating empty payload
* @param predicate Key equality binary callable
* @param probing_scheme Probing scheme
* @param scope The scope in which operations will be performed
* @param storage_ref Non-owning ref of slot storage
*/
__host__ __device__ explicit constexpr static_map_ref(cuco::empty_key<Key> empty_key_sentinel,
cuco::empty_value<T> empty_value_sentinel,
KeyEqual const& predicate,
ProbingScheme const& probing_scheme,
cuda_thread_scope<Scope> scope,
StorageRef storage_ref) noexcept;
/**
* @brief Constructs static_map_ref with an additional sentinel for erased keys.
*
* @param empty_key_sentinel Sentinel indicating empty key
* @param empty_value_sentinel Sentinel indicating empty payload
* @param erased_key_sentinel Sentinel indicating erased key
* @param predicate Key equality binary callable
* @param probing_scheme Probing scheme
* @param scope The scope in which operations will be performed
* @param storage_ref Non-owning ref of slot storage
*/
__host__ __device__ explicit constexpr static_map_ref(cuco::empty_key<Key> empty_key_sentinel,
cuco::empty_value<T> empty_value_sentinel,
cuco::erased_key<Key> erased_key_sentinel,
KeyEqual const& predicate,
ProbingScheme const& probing_scheme,
cuda_thread_scope<Scope> scope,
StorageRef storage_ref) noexcept;
/**
* @brief Operator-agnostic move constructor.
*
* Accepts a ref that differs from `*this` only in its operator set; all other template
* arguments must be identical.
*
* @tparam OtherOperators Operator set of the `other` object
*
* @param other Object to construct `*this` from
*/
template <typename... OtherOperators>
__host__ __device__ explicit constexpr static_map_ref(
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, OtherOperators...>&&
other) noexcept;
/**
* @brief Gets the maximum number of elements the container can hold.
*
* @return The maximum number of elements the container can hold
*/
[[nodiscard]] __host__ __device__ constexpr auto capacity() const noexcept;
/**
* @brief Gets the bucket extent of the current storage.
*
* @return The bucket extent.
*/
[[nodiscard]] __host__ __device__ constexpr extent_type bucket_extent() const noexcept;
/**
* @brief Gets the sentinel value used to represent an empty key slot.
*
* @return The sentinel value used to represent an empty key slot
*/
[[nodiscard]] __host__ __device__ constexpr key_type empty_key_sentinel() const noexcept;
/**
* @brief Gets the sentinel value used to represent an empty payload slot.
*
* @return The sentinel value used to represent an empty payload slot
*/
[[nodiscard]] __host__ __device__ constexpr mapped_type empty_value_sentinel() const noexcept;
/**
* @brief Gets the sentinel value used to represent an erased key slot.
*
* @return The sentinel value used to represent an erased key slot
*/
[[nodiscard]] __host__ __device__ constexpr key_type erased_key_sentinel() const noexcept;
/**
* @brief Gets the key comparator.
*
* @return The comparator used to compare keys
*/
[[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept;
/**
* @brief Gets the function(s) used to hash keys
*
* @return The function(s) used to hash keys
*/
[[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept;
/**
* @brief Returns a const_iterator to one past the last slot.
*
* @return A const_iterator to one past the last slot
*/
[[nodiscard]] __host__ __device__ constexpr const_iterator end() const noexcept;
/**
* @brief Returns an iterator to one past the last slot.
*
* @return An iterator to one past the last slot
*/
[[nodiscard]] __host__ __device__ constexpr iterator end() noexcept;
/**
* @brief Gets the non-owning storage ref.
*
* @return The non-owning storage ref of the container
*/
[[nodiscard]] __host__ __device__ constexpr auto storage_ref() const noexcept;
/**
* @brief Gets the probing scheme.
*
* @return The probing scheme used for the container
*/
[[nodiscard]] __host__ __device__ constexpr auto probing_scheme() const noexcept;
/**
* @brief Creates a copy of the current non-owning reference using the given operators
*
* @tparam NewOperators List of `cuco::op::*_tag` types
*
* @param ops List of operators, e.g., `cuco::op::insert`
*
* @return Copy of the current device ref
*/
template <typename... NewOperators>
[[nodiscard]] __host__ __device__ constexpr auto rebind_operators(
NewOperators... ops) const noexcept;
/**
* @brief Makes a copy of the current device reference with the given key comparator
*
* @tparam NewKeyEqual The new key equal type
*
* @param key_equal New key comparator
*
* @return Copy of the current device ref
*/
template <typename NewKeyEqual>
[[nodiscard]] __host__ __device__ constexpr auto rebind_key_eq(
NewKeyEqual const& key_equal) const noexcept;
/**
* @brief Makes a copy of the current device reference with the given hasher
*
* NOTE(review): unlike `rebind_operators` and `rebind_key_eq`, this member is not declared
* `noexcept` - confirm whether that is intentional.
*
* @tparam NewHash The new hasher type
*
* @param hash New hasher
*
* @return Copy of the current device ref
*/
template <typename NewHash>
[[nodiscard]] __host__ __device__ constexpr auto rebind_hash_function(NewHash const& hash) const;
/**
* @brief Makes a copy of the current device reference using non-owned memory
*
* This function is intended to be used to create shared memory copies of small static maps,
* although global memory can be used as well.
*
* @note This function synchronizes the group `tile`.
* @note By default the thread scope of the copy will be the same as the scope of the parent ref.
*
* @tparam CG The type of the cooperative thread group
* @tparam NewScope The thread scope of the newly created device ref
*
* @param tile The cooperative thread group used to copy the data structure
* @param memory_to_use Array large enough to support `capacity` elements. Object does not take
* the ownership of the memory
* @param scope The thread scope of the newly created device ref
*
* @return Copy of the current device ref
*/
template <typename CG, cuda::thread_scope NewScope = thread_scope>
[[nodiscard]] __device__ constexpr auto make_copy(
CG const& tile,
bucket_type* const memory_to_use,
cuda_thread_scope<NewScope> scope = {}) const noexcept;
/**
* @brief Initializes the map storage using the threads in the group `tile`.
*
* @note This function synchronizes the group `tile`.
*
* @tparam CG The type of the cooperative thread group
*
* @param tile The cooperative thread group used to initialize the map
*/
template <typename CG>
__device__ constexpr void initialize(CG const& tile) noexcept;
private:
impl_type impl_; ///< Static map ref implementation
// Mixins need to be friends with this class in order to access private members
template <typename Op, typename Ref>
friend class detail::operator_impl;
// Refs with other operator sets need to be friends too
template <typename Key_,
typename T_,
cuda::thread_scope Scope_,
typename KeyEqual_,
typename ProbingScheme_,
typename StorageRef_,
typename... Operators_>
friend class static_map_ref;
};
} // namespace cuco
#include <cuco/detail/static_map/static_map_ref.inl>