-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathndarray_partition.cc
266 lines (225 loc) · 8.37 KB
/
ndarray_partition.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
/**
* Copyright (c) 2021 by Contributors
* @file ndarray_partition.cc
* @brief DGL utilities for working with the partitioned NDArrays
*/
#include "ndarray_partition.h"
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <memory>
#include <utility>
#include "../c_api_common.h"
#include "partition_op.h"
using namespace dgl::runtime;
namespace dgl {
namespace partition {
/**
 * @brief Store the global array length and the number of parts shared by
 * every concrete partitioning scheme.
 *
 * @param array_size Total number of entries in the (unpartitioned) array.
 * @param num_parts Number of parts the array is divided into.
 */
NDArrayPartition::NDArrayPartition(
    const int64_t array_size,
    const int num_parts)
    : array_size_(array_size),
      num_parts_(num_parts) {
  // Nothing else to set up; subclasses hold any scheme-specific state.
}
int64_t NDArrayPartition::ArraySize() const { return array_size_; }
int NDArrayPartition::NumParts() const { return num_parts_; }
class RemainderPartition : public NDArrayPartition {
public:
RemainderPartition(const int64_t array_size, const int num_parts)
: NDArrayPartition(array_size, num_parts) {
// do nothing
}
std::pair<IdArray, NDArray> GeneratePermutation(
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::GeneratePermutationFromRemainder<kDGLCUDA, IdType>(
ArraySize(), NumParts(), in_idx);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToLocalFromRemainder<kDGLCUDA, IdType>(
NumParts(), in_idx);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
return impl::MapToGlobalFromRemainder<kDGLCUDA, IdType>(
NumParts(), in_idx, part_id);
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
<< ") for "
"partition of size "
<< NumParts() << ".";
return ArraySize() / NumParts() + (part_id < ArraySize() % NumParts());
}
};
class RangePartition : public NDArrayPartition {
public:
RangePartition(const int64_t array_size, const int num_parts, IdArray range)
: NDArrayPartition(array_size, num_parts),
range_(range),
// We also need a copy of the range on the CPU, to compute partition
// sizes. We require the input range on the GPU, as if we have multiple
// GPUs, we can't know which is the proper one to copy the array to, but
// we have only one CPU context, and can safely copy the array to that.
range_cpu_(range.CopyTo(DGLContext{kDGLCPU, 0})) {
auto ctx = range->ctx;
if (ctx.device_type != kDGLCUDA) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before "
"creating the partition.";
}
}
std::pair<IdArray, NDArray> GeneratePermutation(
IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input "
"array must be on the same device: "
<< ctx << " vs. " << range_->ctx;
}
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<
kDGLCUDA, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(IdArray in_idx) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToLocalFromRange<kDGLCUDA, IdType, RangeType>(
NumParts(), range_, in_idx);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
IdArray MapToGlobal(IdArray in_idx, const int part_id) const override {
#ifdef DGL_USE_CUDA
auto ctx = in_idx->ctx;
if (ctx.device_type == kDGLCUDA) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToGlobalFromRange<kDGLCUDA, IdType, RangeType>(
NumParts(), range_, in_idx, part_id);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id
<< ") for "
"partition of size "
<< NumParts() << ".";
int64_t part_size = -1;
ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {
const RangeType* const ptr =
static_cast<const RangeType*>(range_cpu_->data);
part_size = ptr[part_id + 1] - ptr[part_id];
});
return part_size;
}
private:
IdArray range_;
IdArray range_cpu_;
};
/**
 * @brief Build a remainder-based NDArrayPartition.
 *
 * @param array_size Total number of entries in the partitioned array.
 * @param num_parts Number of parts.
 * @return Reference wrapping a freshly allocated RemainderPartition.
 */
NDArrayPartitionRef CreatePartitionRemainderBased(
    const int64_t array_size, const int num_parts) {
  auto partition = std::make_shared<RemainderPartition>(array_size, num_parts);
  return NDArrayPartitionRef(std::move(partition));
}
/**
 * @brief Build a range-based NDArrayPartition.
 *
 * @param array_size Total number of entries in the partitioned array.
 * @param num_parts Number of parts.
 * @param range Part boundaries. RangePartition requires them on a CUDA device.
 * @return Reference wrapping a freshly allocated RangePartition.
 */
NDArrayPartitionRef CreatePartitionRangeBased(
    const int64_t array_size, const int num_parts, IdArray range) {
  auto partition =
      std::make_shared<RangePartition>(array_size, num_parts, range);
  return NDArrayPartitionRef(std::move(partition));
}
// C API: (array_size: int64, num_parts: int) -> remainder-based partition.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t total_size = args[0];
      const int part_count = args[1];
      *rv = CreatePartitionRemainderBased(total_size, part_count);
    });
// C API: (array_size: int64, num_parts: int, range: IdArray)
//        -> range-based partition.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const int64_t total_size = args[0];
      const int part_count = args[1];
      IdArray boundaries = args[2];
      *rv = CreatePartitionRangeBased(total_size, part_count, boundaries);
    });
// C API: (partition, part_id: int) -> number of entries in that part.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef partition = args[0];
      const int which_part = args[1];
      const int64_t size = partition->PartSize(which_part);
      *rv = size;
    });
// C API: (partition, global indices) -> indices local to their owning part.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToLocal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef partition = args[0];
      IdArray global_ids = args[1];
      *rv = partition->MapToLocal(global_ids);
    });
// C API: (partition, local indices, part_id: int) -> global indices.
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionMapToGlobal")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef partition = args[0];
      IdArray local_ids = args[1];
      const int which_part = args[2];
      *rv = partition->MapToGlobal(local_ids, which_part);
    });
// C API: (partition, indices) -> packed list of the two arrays returned by
// NDArrayPartition::GeneratePermutation, in order (first, second).
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGeneratePermutation")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArrayPartitionRef partition = args[0];
      IdArray ids = args[1];
      const std::pair<IdArray, NDArray> result =
          partition->GeneratePermutation(ids);
      *rv = ConvertNDArrayVectorToPackedFunc({result.first, result.second});
    });
} // namespace partition
} // namespace dgl