forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ndarray.cc
505 lines (453 loc) · 17.3 KB
/
ndarray.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
/**
* Copyright (c) 2017-2022 by Contributors
* @file ndarray.cc
* @brief NDArray container infrastructure.
*/
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/shared_mem.h>
#include <dgl/runtime/tensordispatch.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/logging.h>
#include <string.h>
#include "runtime_base.h"
namespace dgl {
// Out-of-line definitions for the static constexpr `dtype` members of the
// DGLDataTypeTraits specializations. Before C++17, such a member needs a
// namespace-scope definition whenever it is ODR-used (e.g. bound to a
// reference); from C++17 on these redeclarations are redundant (deprecated)
// but harmless.
constexpr DGLDataType DGLDataTypeTraits<int8_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint8_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int16_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype;
// The half-precision (and, with BF16_ENABLED, bfloat16) traits are only
// defined when building with CUDA.
#ifdef DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<__half>::dtype;
#if BF16_ENABLED
constexpr DGLDataType DGLDataTypeTraits<__nv_bfloat16>::dtype;
#endif // BF16_ENABLED
#endif // DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<float>::dtype;
constexpr DGLDataType DGLDataTypeTraits<double>::dtype;
namespace runtime {
/**
 * @brief Sanity-check a dtype before it is used to size an allocation.
 *
 * Requires at least one lane and a bit width that is both a multiple of 8 and
 * a power of two. The rule is identical for every type code, so the former
 * float / non-float branches (which performed exactly the same check) are
 * collapsed into one.
 *
 * NOTE(review): a bit width of 0 passes both checks (0 % 8 == 0 and
 * 0 & -1 == 0); confirm whether zero-bit dtypes should be rejected.
 *
 * @param dtype The data type to validate. Aborts via CHECK on failure.
 */
inline void VerifyDataType(DGLDataType dtype) {
  CHECK_GE(dtype.lanes, 1);
  CHECK_EQ(dtype.bits % 8, 0);
  CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}
inline size_t GetDataSize(const DGLArray& arr) {
size_t size = 1;
for (dgl_index_t i = 0; i < arr.ndim; ++i) {
size *= arr.shape[i];
}
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
return size;
}
/**
 * @brief Alignment, in bytes, to request when allocating storage for `arr`.
 *
 * Uses the whole-byte element size (bits / 8 * lanes), raised to
 * kAllocAlignment when it is smaller than that.
 *
 * @param arr The array descriptor.
 * @return The allocation alignment in bytes.
 */
inline size_t GetDataAlignment(const DGLArray& arr) {
  const size_t elem_bytes = (arr.dtype.bits / 8) * arr.dtype.lanes;
  return elem_bytes >= kAllocAlignment ? elem_bytes
                                       : static_cast<size_t>(kAllocAlignment);
}
// Deleter installed by Internal::Create. Releases whatever backs the
// container's data, then destroys the Container itself:
//  - views (manager_ctx set by CreateView) drop their reference on the parent
//    container instead of freeing any data;
//  - shared-memory arrays drop their SharedMemory handle;
//  - otherwise non-null data is unpinned if DGL pinned it, then returned
//    either to PyTorch's pinned-memory allocator (PinnedEmpty) or to the
//    DeviceAPI of the array's own context.
void NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) {
using dgl::runtime::NDArray;
if (ptr->manager_ctx != nullptr) {
// View: the parent container owns the data; release our reference to it.
static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef();
} else if (ptr->mem) {
// Shared memory: drop this container's handle on the SharedMemory object.
ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it.
if (ptr->pinned_by_dgl_) UnpinContainer(ptr);
if (ptr->pinned_by_pytorch_) {
// Allocated by PinnedEmpty: hand the buffer back through the deleter that
// PyTorch supplied; the CHECK verifies that the call cleared it.
DeviceAPI::Get(kDGLCUDA)->FreePinnedDataSpace(
&(ptr->pytorch_raw_deleter_));
CHECK(ptr->pytorch_raw_deleter_ == nullptr);
ptr->pinned_by_pytorch_ = false;
ptr->pytorch_ctx_ = nullptr;
} else {
// Ordinary allocation: free through the owning device's API.
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)
->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data);
}
}
delete ptr;
}
// Builds the Container for an array with the given shape, dtype and context,
// including row-major (C-contiguous) strides and DefaultDeleter as the
// deleter. It does NOT assign dl_tensor.data: callers such as Empty(),
// EmptyShared(), PinnedEmpty() and CreateFromRaw() attach storage afterwards.
NDArray NDArray::Internal::Create(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
VerifyDataType(dtype);
// critical zone
NDArray::Container* data = new NDArray::Container();
data->deleter = DefaultDeleter;
// Hand the raw Container to an NDArray right away so that it is released
// through the deleter if anything below throws.
NDArray ret(data);
// NOTE(review): looks redundant with the NDArray(Container*) constructor —
// confirm before removing.
ret.data_ = data;
// RAII now in effect
// setup shape
data->shape_ = std::move(shape);
data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup stride (this should be optional, but some framework
// does not support NULL stride and thus will crash the program).
// The innermost stride is 1; each outer stride is the product of all inner
// extents.
data->stride_.resize(data->dl_tensor.ndim, 1);
for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {
data->stride_[i] = data->shape_[i + 1] * data->stride_[i + 1];
}
data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);
// setup dtype
data->dl_tensor.dtype = dtype;
// setup ctx
data->dl_tensor.ctx = ctx;
return ret;
}
// Transfers ownership of `arr`'s container to a raw DGLArray handle for the C
// API. The container pointer itself is reinterpreted as the handle; the CHECK
// guards that assumption by comparing it with the DGLArray exposed through
// operator->. Clearing `arr.data_` means (presumably) that `arr`'s destructor
// no longer releases the reference, which now belongs to the returned handle;
// DGLArrayFree later drops it via Container::DecRef.
DGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) {
DGLArray* tensor = reinterpret_cast<DGLArray*>(arr.data_);
CHECK(tensor == const_cast<DGLArray*>(arr.operator->()));
arr.data_ = nullptr;
return tensor;
}
// Bytes spanned by this array's data: product of extents times element size
// (strides are not considered).
size_t NDArray::GetSize() const { return GetDataSize(data_->dl_tensor); }
int64_t NDArray::NumElements() const {
if (data_->dl_tensor.ndim == 0) return 0;
int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i];
}
return size;
}
bool NDArray::IsContiguous() const {
CHECK(data_ != nullptr);
if (data_->dl_tensor.strides == nullptr) return true;
// See https://github.com/dmlc/dgl/issues/2118 and PyTorch's
// compute_contiguous() implementation
int64_t z = 1;
for (int64_t i = data_->dl_tensor.ndim - 1; i >= 0; --i) {
if (data_->dl_tensor.shape[i] != 1) {
if (data_->dl_tensor.strides[i] == z)
z *= data_->dl_tensor.shape[i];
else
return false;
}
}
return true;
}
/**
 * @brief Create a view that aliases this array's memory.
 *
 * The view starts `offset` bytes past this array's data pointer (keeping the
 * same byte_offset) and is interpreted with the given shape and dtype. It
 * holds a reference on this array's container, which is released by
 * DefaultDeleter via manager_ctx.
 *
 * The whole view must lie inside this array's data. The previous check only
 * compared sizes, so a nonzero (or negative) offset could address memory
 * outside the parent buffer.
 *
 * @param shape  Shape of the view.
 * @param dtype  Element type of the view.
 * @param offset Byte offset of the view relative to this array's data.
 * @return An NDArray sharing storage with this one.
 */
NDArray NDArray::CreateView(
    std::vector<int64_t> shape, DGLDataType dtype, int64_t offset) {
  CHECK(data_ != nullptr);
  CHECK(IsContiguous()) << "Can only create view for compact tensor";
  NDArray ret =
      Internal::Create(std::move(shape), dtype, data_->dl_tensor.ctx);
  ret.data_->dl_tensor.byte_offset = this->data_->dl_tensor.byte_offset;
  size_t curr_size = GetDataSize(this->data_->dl_tensor);
  size_t view_size = GetDataSize(ret.data_->dl_tensor);
  CHECK_GE(offset, 0) << "Tries to create a view at a negative byte offset";
  CHECK_LE(view_size, curr_size)
      << "Tries to create a view that has bigger memory than current one";
  // view_size <= curr_size is established above, so the subtraction is safe.
  CHECK_LE(static_cast<size_t>(offset), curr_size - view_size)
      << "Tries to create a view that extends past the end of current one";
  // increase ref count
  this->data_->IncRef();
  ret.data_->manager_ctx = this->data_;
  ret.data_->dl_tensor.data =
      static_cast<char*>(this->data_->dl_tensor.data) + offset;
  return ret;
}
/**
 * @brief Allocate an array whose storage is a named shared-memory segment.
 *
 * @param name      Name of the shared-memory segment.
 * @param shape     Array shape.
 * @param dtype     Element type.
 * @param ctx       Context recorded on the array.
 * @param is_create Create a new segment when true; open an existing one
 *                  otherwise.
 * @return The array. It keeps the SharedMemory handle alive in its `mem`
 *         field.
 */
NDArray NDArray::EmptyShared(
    const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,
    DGLContext ctx, bool is_create) {
  NDArray ret = Internal::Create(shape, dtype, ctx);
  const size_t nbytes = GetDataSize(ret.data_->dl_tensor);
  auto mem = std::make_shared<SharedMemory>(name);
  void* addr = is_create ? mem->CreateNew(nbytes) : mem->Open(nbytes);
  ret.data_->dl_tensor.data = addr;
  ret.data_->mem = mem;
  return ret;
}
/**
 * @brief Allocate an uninitialized array on `ctx`.
 *
 * Storage comes from the context's DeviceAPI, aligned per
 * GetDataAlignment(). Zero-sized arrays get no backing allocation.
 */
NDArray NDArray::Empty(
    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
  NDArray ret = Internal::Create(shape, dtype, ctx);
  const size_t nbytes = GetDataSize(ret.data_->dl_tensor);
  if (nbytes == 0) return ret;
  const size_t align = GetDataAlignment(ret.data_->dl_tensor);
  ret.data_->dl_tensor.data =
      DeviceAPI::Get(ret->ctx)->AllocDataSpace(ret->ctx, nbytes, align, ret->dtype);
  return ret;
}
/**
 * @brief Copy the bytes of `from` into `to`.
 *
 * Both arrays must span the same number of bytes. They must either share a
 * device type, or one of them must be on the CPU. The copy is dispatched
 * through the DeviceAPI of the source's context, unless the source is on the
 * CPU, in which case the destination's context is used.
 */
void NDArray::CopyFromTo(DGLArray* from, DGLArray* to) {
  size_t from_size = GetDataSize(*from);
  size_t to_size = GetDataSize(*to);
  CHECK_EQ(from_size, to_size)
      << "DGLArrayCopyFromTo: The size must exactly match";
  CHECK(
      from->ctx.device_type == to->ctx.device_type ||
      from->ctx.device_type == kDGLCPU || to->ctx.device_type == kDGLCPU)
      << "Can not copy across different ctx types directly";
  DGLContext ctx = to->ctx;
  if (from->ctx.device_type != kDGLCPU) ctx = from->ctx;
  const size_t src_offset = static_cast<size_t>(from->byte_offset);
  const size_t dst_offset = static_cast<size_t>(to->byte_offset);
  // default: local current cuda stream
  DeviceAPI::Get(ctx)->CopyDataFromTo(
      from->data, src_offset, to->data, dst_offset, from_size, from->ctx,
      to->ctx, from->dtype);
}
/**
 * @brief Copy between host memory and a CUDA device, forwarding `pytorch_ctx`
 *        to the CUDA DeviceAPI's RecordedCopyDataFromTo.
 *
 * Both arrays must span the same number of bytes. They must be on different
 * device types, and one of them must be CUDA.
 *
 * NOTE(review): `pytorch_ctx` is presumably the PyTorch host-allocator context
 * (cf. PinnedEmpty) against which the copy's completion is recorded — confirm.
 *
 * Fixes the "Recoding" typo in the user-facing error message.
 */
void NDArray::RecordedCopyFromTo(
    DGLArray* from, DGLArray* to, void* pytorch_ctx) {
  size_t from_size = GetDataSize(*from);
  size_t to_size = GetDataSize(*to);
  CHECK_EQ(from_size, to_size)
      << "DGLArrayCopyFromTo: The size must exactly match.";
  CHECK(from->ctx.device_type != to->ctx.device_type)
      << "Recording event is only called for the copy between CPU and GPU.";
  CHECK(from->ctx.device_type == kDGLCUDA || to->ctx.device_type == kDGLCUDA)
      << "At least one CUDA ctx needs to be involved.";
  DeviceAPI::Get(kDGLCUDA)->RecordedCopyDataFromTo(
      from->data, static_cast<size_t>(from->byte_offset), to->data,
      static_cast<size_t>(to->byte_offset), from_size, from->ctx, to->ctx,
      from->dtype, pytorch_ctx);
}
/**
 * @brief Allocate an uninitialized CPU array in pinned (page-locked) memory.
 *
 * The buffer comes from PyTorch's caching host allocator via the CUDA
 * DeviceAPI. The returned context and deleter are stored on the container so
 * that DefaultDeleter can hand the buffer back. Zero-sized arrays get no
 * allocation and are not marked as pinned.
 */
NDArray NDArray::PinnedEmpty(
    std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
  CHECK_EQ(ctx.device_type, kDGLCPU) << "Only NDArray on CPU can be pinned";
  NDArray ret = Internal::Create(shape, dtype, ctx);
  const size_t size = GetDataSize(ret.data_->dl_tensor);
  if (size == 0) return ret;
  ret.data_->dl_tensor.data = DeviceAPI::Get(kDGLCUDA)->AllocPinnedDataSpace(
      size, &(ret.data_->pytorch_ctx_), &(ret.data_->pytorch_raw_deleter_));
  CHECK(
      ret.data_->pytorch_ctx_ != nullptr &&
      ret.data_->pytorch_raw_deleter_ != nullptr)
      << "The allocation failed in PyTorch's CachingHostAllocator. "
      << "The returned context pointer is " << ret.data_->pytorch_ctx_
      << " and the function deleter is " << ret.data_->pytorch_raw_deleter_;
  ret.data_->pinned_by_pytorch_ = true;
  return ret;
}
/**
 * @brief Page-lock a CPU container's buffer through the CUDA DeviceAPI.
 *
 * Does nothing if the buffer is already pinned, by DGL or otherwise. The
 * PinData result is recorded in pinned_by_dgl_, so that later unpinning is
 * limited to buffers DGL pinned itself.
 */
void NDArray::PinContainer(NDArray::Container* ptr) {
  if (IsContainerPinned(ptr)) return;
  auto* tensor = &(ptr->dl_tensor);
  CHECK_EQ(tensor->ctx.device_type, kDGLCPU)
      << "Only NDArray on CPU can be pinned";
  const size_t nbytes = GetDataSize(*tensor);
  const bool pinned = DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, nbytes);
  ptr->pinned_by_dgl_ = pinned;
}
/**
 * @brief Undo PinContainer.
 *
 * An unpinned buffer is left alone. A buffer pinned by something other than
 * DGL (e.g. a different CUDA API or PyTorch's pinned pool) is rejected,
 * because unregistering it here would be wrong.
 */
void NDArray::UnpinContainer(NDArray::Container* ptr) {
  const bool container_is_pinned = IsContainerPinned(ptr);
  CHECK(ptr->pinned_by_dgl_ || !container_is_pinned)
      << "Cannot unpin a tensor that is pinned outside of DGL.";
  if (container_is_pinned) {
    // Reaching here means DGL owns the pin (guaranteed by the CHECK above).
    DeviceAPI::Get(kDGLCUDA)->UnpinData(ptr->dl_tensor.data);
    ptr->pinned_by_dgl_ = false;
  }
}
/**
 * @brief Tell the tensor adapter that `tensor`'s CUDA memory is in use on
 *        `stream`.
 *
 * Requires the TensorDispatcher to be available and the tensor to live on a
 * CUDA device.
 */
void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
  auto* tensor_dispatcher = TensorDispatcher::Global();
  CHECK(tensor_dispatcher->IsAvailable())
      << "RecordStream only works when TensorAdapter is available.";
  CHECK_EQ(tensor->ctx.device_type, kDGLCUDA)
      << "RecordStream only works with GPU tensors.";
  const int device_id = tensor->ctx.device_id;
  tensor_dispatcher->RecordStream(tensor->data, stream, device_id);
}
/**
 * @brief Create a 1-D array on `ctx` holding a copy of `vec`.
 *
 * Empty() allocates nothing for a zero-sized array, so for an empty `vec`
 * there is no destination buffer. The device copy is skipped rather than
 * issued with a null destination.
 *
 * @param vec Host-side source values.
 * @param ctx Target context.
 * @return A 1-D array of length vec.size() with dtype
 *         DGLDataTypeTraits<T>::dtype.
 */
template <typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) {
  const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
  int64_t size = static_cast<int64_t>(vec.size());
  NDArray ret = NDArray::Empty({size}, dtype, ctx);
  if (size == 0) return ret;
  DeviceAPI::Get(ctx)->CopyDataFromTo(
      vec.data(), 0, static_cast<T*>(ret->data), 0, size * sizeof(T),
      DGLContext{kDGLCPU, 0}, ctx, dtype);
  return ret;
}
/**
 * @brief Wrap an existing buffer in an NDArray without copying it.
 *
 * @param shape     Array shape.
 * @param dtype     Element type.
 * @param ctx       Context that `raw` lives in.
 * @param raw       Pointer to the existing data.
 * @param auto_free If true, `raw` is released by DefaultDeleter like any
 *                  DGL-owned allocation. If false, the caller keeps ownership
 *                  of `raw`.
 *
 * The container's deleter is responsible for destroying the Container
 * itself (DefaultDeleter ends in `delete ptr`). The previous code set the
 * deleter to nullptr when `auto_free` was false, which leaked the Container
 * (and its shape/stride vectors) once the last reference went away. Now a
 * deleter is installed that frees only the Container and leaves `raw`
 * untouched.
 */
NDArray NDArray::CreateFromRaw(
    const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,
    void* raw, bool auto_free) {
  NDArray ret = Internal::Create(shape, dtype, ctx);
  ret.data_->dl_tensor.data = raw;
  if (!auto_free) {
    ret.data_->deleter = [](NDArray::Container* ptr) { delete ptr; };
  }
  return ret;
}
// Explicit instantiations, so that other translation units can link against
// FromVector<T> for these element types.
template NDArray NDArray::FromVector<int32_t>(
const std::vector<int32_t>&, DGLContext);
template NDArray NDArray::FromVector<int64_t>(
const std::vector<int64_t>&, DGLContext);
template NDArray NDArray::FromVector<uint32_t>(
const std::vector<uint32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint64_t>(
const std::vector<uint64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(
const std::vector<float>&, DGLContext);
template NDArray NDArray::FromVector<double>(
const std::vector<double>&, DGLContext);
/**
 * @brief Copy a 1-D array into a host std::vector<T>.
 *
 * The array's dtype must equal DGLDataTypeTraits<T>::dtype.
 *
 * The copy now starts at dl_tensor.byte_offset, as CopyFromTo and
 * DGLArrayCopyToBytes already do. Previously the offset was ignored, so arrays
 * with a nonzero byte_offset were read from the wrong position. Empty arrays
 * return immediately instead of issuing a copy from a possibly null pointer.
 *
 * @return The elements, in order.
 */
template <typename T>
std::vector<T> NDArray::ToVector() const {
  const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
  CHECK(data_->dl_tensor.ndim == 1)
      << "ToVector() only supported for 1D arrays";
  CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
  int64_t size = data_->dl_tensor.shape[0];
  std::vector<T> vec(size);
  if (size == 0) return vec;
  const DGLContext& ctx = data_->dl_tensor.ctx;
  DeviceAPI::Get(ctx)->CopyDataFromTo(
      static_cast<T*>(data_->dl_tensor.data),
      static_cast<size_t>(data_->dl_tensor.byte_offset), vec.data(), 0,
      size * sizeof(T), ctx, DGLContext{kDGLCPU, 0}, dtype);
  return vec;
}
// Explicit instantiations of ToVector<T>, for the same element types as
// FromVector<T>.
template std::vector<int32_t> NDArray::ToVector<int32_t>() const;
template std::vector<int64_t> NDArray::ToVector<int64_t>() const;
template std::vector<uint32_t> NDArray::ToVector<uint32_t>() const;
template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
template std::vector<float> NDArray::ToVector<float>() const;
template std::vector<double> NDArray::ToVector<double>() const;
// Returns the SharedMemory handle set by EmptyShared(); it is empty for
// arrays that are not backed by shared memory.
std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
return this->data_->mem;
}
/**
 * @brief Whether the container's buffer is in pinned (page-locked) memory.
 *
 * True if DGL pinned it or it came from PyTorch's pinned pool. Otherwise, for
 * CPU arrays, the CUDA DeviceAPI is asked, when one is available. Non-CPU
 * arrays are never considered pinned.
 */
bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
  if (ptr->pinned_by_dgl_ || ptr->pinned_by_pytorch_) return true;
  const auto& tensor = ptr->dl_tensor;
  if (tensor.ctx.device_type != kDGLCPU) return false;
  // The second argument presumably lets Get() return null instead of failing
  // when CUDA support is absent — hence the null check.
  auto device = DeviceAPI::Get(kDGLCUDA, true);
  if (!device) return false;
  return device->IsPinned(tensor.data);
}
/**
 * @brief Serialize the array to `strm`.
 *
 * Zero-copy streams (StreamWithBuffer) take the NDArray itself. Any other
 * stream receives the DGLArray serialization via SaveDGLArray.
 */
void NDArray::Save(dmlc::Stream* strm) const {
  if (auto* zc_strm = dynamic_cast<StreamWithBuffer*>(strm)) {
    zc_strm->PushNDArray(*this);
  } else {
    SaveDGLArray(strm, const_cast<DGLArray*>(operator->()));
  }
}
/**
 * @brief Deserialize an array from `strm` into *this.
 *
 * Zero-copy streams hand the NDArray over directly. Otherwise the DGLArray
 * file format is parsed: magic header, reserved word, context, ndim, dtype,
 * shape, payload size, then payload. Only CPU arrays are accepted.
 *
 * Hardening for untrusted or corrupted input:
 *  - a negative ndim or a negative extent is rejected before it is used to
 *    size a vector or an allocation;
 *  - the payload is read until all `data_byte_size` bytes have arrived.
 *    Previously only a nonzero return from Read() was required, so a
 *    truncated payload was accepted with an uninitialized tail.
 *
 * @return true on success; malformed input aborts via CHECK.
 */
bool NDArray::Load(dmlc::Stream* strm) {
  auto zc_strm = dynamic_cast<StreamWithBuffer*>(strm);
  if (zc_strm) {
    *this = zc_strm->PopNDArray();
    return true;
  }
  uint64_t header, reserved;
  CHECK(strm->Read(&header)) << "Invalid DGLArray file format";
  CHECK(strm->Read(&reserved)) << "Invalid DGLArray file format";
  CHECK(header == kDGLNDArrayMagic) << "Invalid DGLArray file format";
  DGLContext ctx;
  int ndim;
  DGLDataType dtype;
  CHECK(strm->Read(&ctx)) << "Invalid DGLArray file format";
  CHECK(strm->Read(&ndim)) << "Invalid DGLArray file format";
  CHECK(strm->Read(&dtype)) << "Invalid DGLArray file format";
  CHECK_EQ(ctx.device_type, kDGLCPU)
      << "Invalid DGLArray context: can only save as CPU tensor";
  // A negative ndim would otherwise become a huge std::vector size.
  CHECK_GE(ndim, 0) << "Invalid DGLArray file format";
  std::vector<int64_t> shape(ndim);
  if (ndim != 0) {
    CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DGLArray file format";
  }
  // A negative extent would otherwise become a huge allocation request.
  for (const int64_t extent : shape) {
    CHECK_GE(extent, 0) << "Invalid DGLArray file format";
  }
  NDArray ret = NDArray::Empty(shape, dtype, ctx);
  int64_t num_elems = 1;
  int elem_bytes = (ret->dtype.bits + 7) / 8;
  for (int i = 0; i < ret->ndim; ++i) {
    num_elems *= ret->shape[i];
  }
  int64_t data_byte_size;
  CHECK(strm->Read(&data_byte_size)) << "Invalid DGLArray file format";
  CHECK(data_byte_size == num_elems * elem_bytes)
      << "Invalid DGLArray file format";
  // Stream::Read(ptr, size) returns the number of bytes actually read (0 at
  // end of stream). Keep reading until the whole payload has arrived, and
  // fail if the stream ends early.
  char* dst = static_cast<char*>(ret->data);
  size_t remaining = static_cast<size_t>(data_byte_size);
  while (remaining > 0) {
    const size_t nread = strm->Read(dst, remaining);
    CHECK(nread != 0) << "Invalid DGLArray file format";
    dst += nread;
    remaining -= nread;
  }
  if (!DMLC_IO_NO_ENDIAN_SWAP) {
    dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
  }
  *this = ret;
  return true;
}
} // namespace runtime
} // namespace dgl
using namespace dgl::runtime;
int DGLArrayAlloc(
const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out) {
API_BEGIN();
DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes);
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id;
*out = NDArray::Internal::MoveAsDGLArray(
NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, ctx));
API_END();
}
int DGLArrayAllocSharedMem(
const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out) {
API_BEGIN();
DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes);
std::vector<int64_t> shape_vec(shape, shape + ndim);
NDArray arr = NDArray::EmptyShared(
mem_name, shape_vec, dtype, DGLContext{kDGLCPU, 0}, is_create);
*out = NDArray::Internal::MoveAsDGLArray(arr);
API_END();
}
/**
 * @brief C API: drop the caller's reference on an array handle.
 *
 * Handles are Container pointers, as produced by MoveAsDGLArray.
 */
int DGLArrayFree(DGLArrayHandle handle) {
  API_BEGIN();
  NDArray::Container* container =
      reinterpret_cast<NDArray::Container*>(handle);
  container->DecRef();
  API_END();
}
// C API: copies `from` into `to`. Size and context compatibility are
// validated by NDArray::CopyFromTo.
int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to) {
API_BEGIN();
NDArray::CopyFromTo(from, to);
API_END();
}
/**
 * @brief C API: fill an array from a host byte buffer.
 *
 * `nbytes` must equal the array's data size. The copy honors the array's
 * byte_offset.
 */
int DGLArrayCopyFromBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
  API_BEGIN();
  size_t arr_size = GetDataSize(*handle);
  CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyFromBytes: size mismatch";
  // The source buffer is treated as CPU memory.
  const DGLContext cpu_ctx{kDGLCPU, 0};
  const size_t dst_offset = static_cast<size_t>(handle->byte_offset);
  DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
      data, 0, handle->data, dst_offset, nbytes, cpu_ctx, handle->ctx,
      handle->dtype);
  API_END();
}
/**
 * @brief C API: copy an array's contents into a host byte buffer.
 *
 * `nbytes` must equal the array's data size. The copy honors the array's
 * byte_offset.
 */
int DGLArrayCopyToBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
  API_BEGIN();
  size_t arr_size = GetDataSize(*handle);
  CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyToBytes: size mismatch";
  // The destination buffer is treated as CPU memory.
  const DGLContext cpu_ctx{kDGLCPU, 0};
  const size_t src_offset = static_cast<size_t>(handle->byte_offset);
  DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
      handle->data, src_offset, data, 0, nbytes, handle->ctx, cpu_ctx,
      handle->dtype);
  API_END();
}
/**
 * @brief C API: pin a CPU array's buffer (see NDArray::PinContainer).
 *
 * `ctx` is currently unused; pinning always goes through the CUDA DeviceAPI.
 */
int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx) {
  API_BEGIN();
  NDArray::PinContainer(reinterpret_cast<NDArray::Container*>(handle));
  API_END();
}
/**
 * @brief C API: unpin a buffer previously pinned by DGL (see
 *        NDArray::UnpinContainer).
 *
 * `ctx` is currently unused.
 */
int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx) {
  API_BEGIN();
  NDArray::UnpinContainer(reinterpret_cast<NDArray::Container*>(handle));
  API_END();
}
// C API: forwards to NDArray::RecordStream, which requires the tensor adapter
// to be available and the array to be on a CUDA device.
int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream) {
API_BEGIN();
NDArray::RecordStream(handle, stream);
API_END();
}