Skip to content

Commit

Permalink
[Feature] Pin dgl.graph to the page-locked memory (dmlc#3616)
Browse files Browse the repository at this point in the history
* implement pin_memory/unpin_memory/is_pinned for dgl.graph

* update python docstring

* update c++ docstring

* add test

* fix the broken UnifiedTensor

* eliminate extra context parameter for pin/unpin

* fix linting

* fix typo

* disable new format materialization for pinned graphs

* update python doc for pin_memory_

* fix unit test

* update doc

* change unitgraph and heterograph's PinMemory to in-place

* update comments for NDArray's PinMemory_ and PinData

* update doc

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
yaox12 and VoVAllen authored Jan 21, 2022
1 parent 51651ec commit 40b44a4
Show file tree
Hide file tree
Showing 20 changed files with 513 additions and 25 deletions.
4 changes: 4 additions & 0 deletions docs/source/api/python/dgl.DGLGraph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ Methods for getting or changing the device on which the graph is hosted.

DGLGraph.to
DGLGraph.device
DGLGraph.cpu
DGLGraph.pin_memory_
DGLGraph.unpin_memory_
DGLGraph.is_pinned

Misc
----
Expand Down
31 changes: 31 additions & 0 deletions include/dgl/aten/coo.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,37 @@ struct COOMatrix {
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
row_sorted, col_sorted);
}

/*!
 * \brief Pin the row, col and data (if not Null) of the matrix.
 * \note This is an in-place method. Behavior depends on the current context,
 *       kDLCPU: will be pinned;
 *       kDLCPUPinned: directly return;
 *       kDLGPU: invalid, will throw an error.
 *       The context check is deferred to pinning the NDArray.
 */
inline void PinMemory_() {
  row.PinMemory_();
  col.PinMemory_();
  // A null data array carries no payload, so there is nothing to pin.
  if (aten::IsNullArray(data))
    return;
  data.PinMemory_();
}

/*!
 * \brief Unpin the row, col and data (if not Null) of the matrix.
 * \note This is an in-place method. Behavior depends on the current context,
 *       kDLCPUPinned: will be unpinned;
 *       others: directly return.
 *       The context check is deferred to unpinning the NDArray.
 */
inline void UnpinMemory_() {
  row.UnpinMemory_();
  col.UnpinMemory_();
  // A null data array was never pinned; skip it.
  if (aten::IsNullArray(data))
    return;
  data.UnpinMemory_();
}
};

///////////////////////// COO routines //////////////////////////
Expand Down
31 changes: 31 additions & 0 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,37 @@ struct CSRMatrix {
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
sorted);
}

/*!
 * \brief Pin the indptr, indices and data (if not Null) of the matrix.
 * \note This is an in-place method. Behavior depends on the current context,
 *       kDLCPU: will be pinned;
 *       kDLCPUPinned: directly return;
 *       kDLGPU: invalid, will throw an error.
 *       The context check is deferred to pinning the NDArray.
 */
inline void PinMemory_() {
  indptr.PinMemory_();
  indices.PinMemory_();
  // A null data array carries no payload, so there is nothing to pin.
  if (aten::IsNullArray(data))
    return;
  data.PinMemory_();
}

/*!
 * \brief Unpin the indptr, indices and data (if not Null) of the matrix.
 * \note This is an in-place method. Behavior depends on the current context,
 *       kDLCPUPinned: will be unpinned;
 *       others: directly return.
 *       The context check is deferred to unpinning the NDArray.
 */
inline void UnpinMemory_() {
  indptr.UnpinMemory_();
  indices.UnpinMemory_();
  // A null data array was never pinned; skip it.
  if (aten::IsNullArray(data))
    return;
  data.UnpinMemory_();
}
};

///////////////////////// CSR routines //////////////////////////
Expand Down
5 changes: 4 additions & 1 deletion include/dgl/aten/macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@
* // Now XPU is a placeholder for array->ctx.device_type
* DeviceSpecificImplementation<XPU>(...);
* });
*
* We treat pinned memory as normal host memory when we do not want
* to enable CUDA UVA access for this operator.
*/
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
if ((val) == kDLCPU || (val) == kDLCPUPinned) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
Expand Down
5 changes: 5 additions & 0 deletions include/dgl/base_heterograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ class BaseHeteroGraph : public runtime::Object {
*/
virtual DLContext Context() const = 0;

/*!
* \brief Check if this graph is pinned.
*/
virtual bool IsPinned() const = 0;

/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
*/
Expand Down
19 changes: 13 additions & 6 deletions include/dgl/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,17 @@ class DeviceAPI {
/*!
* \brief Pin host memory using cudaHostRegister().
*
* \param ctx The context of pinning and mapping.
* \param ptr The host memory pointer to be pinned.
* \param nbytes The size to be pinned.
*/
DGL_DLL virtual void PinData(DGLContext ctx, void* ptr, size_t nbytes);
DGL_DLL virtual void PinData(void* ptr, size_t nbytes);

/*!
* \brief Unpin host memory ussing cudaHostUnregister().
* \brief Unpin host memory using cudaHostUnregister().
*
* \param ctx The context to unmap and unpin.
* \param ptr The host memory pointer to be unpinned.
*/
DGL_DLL virtual void UnpinData(DGLContext ctx, void* ptr);
DGL_DLL virtual void UnpinData(void* ptr);

/*!
* \brief Allocate temporal workspace for backend execution.
Expand Down Expand Up @@ -190,12 +188,21 @@ class DeviceAPI {
DGL_DLL virtual void FreeWorkspace(DGLContext ctx, void* ptr);

/*!
* \brief Get device API base don context.
* \brief Get device API based on context.
* \param ctx The context
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);


/*!
* \brief Get device API based on context.
* \param dev_type The device type
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
DGL_DLL static DeviceAPI* Get(DLDeviceType dev_type, bool allow_missing = false);
};

/*! \brief The device type bigger than this is RPC device */
Expand Down
57 changes: 57 additions & 0 deletions include/dgl/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,27 @@ class NDArray {
* \brief Return a new array with a copy of the content.
*/
inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief In-place method to pin the current array by calling PinData
* on the underlying DLTensor.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDLCPUPinned: directly return;
* kDLGPU: invalid, will throw an error.
*/
inline void PinMemory_();
/*!
* \brief In-place method to unpin the current array by calling UnpinData
* on the underlying DLTensor.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPUPinned: will be unpinned;
* others: directly return.
*/
inline void UnpinMemory_();
/*!
* \brief Check if the array is pinned.
*/
inline bool IsPinned() const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
Expand Down Expand Up @@ -272,6 +293,27 @@ class NDArray {
DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, DGLStreamHandle stream = nullptr);

/*!
* \brief Function to pin the data of a DLTensor.
* \param tensor The array to be pinned.
* \note Data of the given array will be pinned inplace.
* Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDLCPUPinned: directly return;
* kDLGPU: invalid, will throw an error.
*/
DGL_DLL static void PinData(DLTensor* tensor);

/*!
* \brief Function to unpin the data of a DLTensor.
* \param tensor The array to be unpinned.
* \note Data of the given array will be unpinned inplace.
* Behavior depends on the current context,
* kDLCPUPinned: will be unpinned;
* others: directly return.
*/
DGL_DLL static void UnpinData(DLTensor* tensor);

// internal namespace
struct Internal;
private:
Expand Down Expand Up @@ -431,6 +473,21 @@ inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
return this->CopyTo(dptr->ctx, stream);
}

inline void NDArray::PinMemory_() {
CHECK(data_ != nullptr);
PinData(&(data_->dl_tensor));
}

inline void NDArray::UnpinMemory_() {
CHECK(data_ != nullptr);
UnpinData(&(data_->dl_tensor));
}

inline bool NDArray::IsPinned() const {
CHECK(data_ != nullptr);
return data_->dl_tensor.ctx.device_type == kDLCPUPinned;
}

inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
Expand Down
68 changes: 68 additions & 0 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5458,6 +5458,74 @@ def cpu(self):
"""
return self.to(F.cpu())

def pin_memory_(self):
    """Pin the graph structure to the page-locked memory.

    This is an **inplace** method. The graph structure must be on CPU to be pinned.
    If the graph structure is already pinned, the function directly returns it.

    Materialization of new sparse formats for pinned graphs is not allowed.
    To avoid implicit formats materialization during training,
    you should create all the needed formats before pinning.
    But cloning and materialization is fine. See the examples below.

    Returns
    -------
    DGLGraph
        The pinned graph.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph((torch.tensor([1, 0]), torch.tensor([1, 2])))
    >>> g.pin_memory_()

    Materialization of new sparse formats is not allowed for pinned graphs.

    >>> g.create_formats_() # This would raise an error! You should do this before pinning.

    Cloning and materializing new formats is allowed. The returned graph is **not** pinned.

    >>> g1 = g.formats(['csc'])
    >>> assert not g1.is_pinned()
    """
    # Pinning is idempotent: an already-pinned graph is returned unchanged.
    if self._graph.is_pinned():
        return self
    # Only a CPU-resident structure can be registered as page-locked memory.
    if F.device_type(self.device) != 'cpu':
        raise DGLError("The graph structure must be on CPU to be pinned.")
    self._graph.pin_memory_()
    return self

def unpin_memory_(self):
    """Unpin the graph structure from the page-locked memory.

    This is an **inplace** method. If the graph structure is not pinned,
    e.g., on CPU or GPU, the function directly returns it.

    Returns
    -------
    DGLGraph
        The unpinned graph.
    """
    # No-op for graphs whose structure is not pinned (ordinary CPU or GPU memory).
    if not self._graph.is_pinned():
        return self
    self._graph.unpin_memory_()
    return self

def is_pinned(self):
    """Check if the graph structure is pinned to the page-locked memory.

    Returns
    -------
    bool
        True if the graph structure is pinned.
    """
    # Delegates to the underlying HeteroGraphIndex.
    return self._graph.is_pinned()

def clone(self):
"""Return a heterograph object that is a clone of current graph.
Expand Down
38 changes: 38 additions & 0 deletions python/dgl/heterograph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,44 @@ def copy_to(self, ctx):
"""
return _CAPI_DGLHeteroCopyTo(self, ctx.device_type, ctx.device_id)

def pin_memory_(self):
    """Pin this graph to the page-locked memory.

    NOTE: This is an inplace method.
    The graph structure must be on CPU to be pinned.
    If the graph structure is already pinned, the function directly returns it.

    Returns
    -------
    HeteroGraphIndex
        The pinned graph index.
    """
    # The C API pins in place; presumably it returns the same index object —
    # TODO(review): confirm against the _CAPI_DGLHeteroPinMemory_ implementation.
    return _CAPI_DGLHeteroPinMemory_(self)

def unpin_memory_(self):
    """Unpin this graph from the page-locked memory.

    NOTE: This is an inplace method.
    If the graph structure is not pinned, e.g., on CPU or GPU,
    the function directly returns it.

    Returns
    -------
    HeteroGraphIndex
        The unpinned graph index.
    """
    # The C API unpins in place; presumably it returns the same index object —
    # TODO(review): confirm against the _CAPI_DGLHeteroUnpinMemory_ implementation.
    return _CAPI_DGLHeteroUnpinMemory_(self)

def is_pinned(self):
    """Check if this graph is pinned to the page-locked memory.

    Returns
    -------
    bool
        True if the graph is pinned.
    """
    # The C API returns an integer flag; coerce to a Python bool.
    return bool(_CAPI_DGLHeteroIsPinned(self))

def shared_memory(self, name, ntypes=None, etypes=None, formats=('coo', 'csr', 'csc')):
"""Return a copy of this graph in shared memory
Expand Down
2 changes: 1 addition & 1 deletion src/array/cuda/uvm/array_index_select_uvm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
int64_t num_feat = 1;
std::vector<int64_t> shape{len};

CHECK_EQ(array->ctx.device_type, kDLCPU);
CHECK_EQ(array->ctx.device_type, kDLCPUPinned);
CHECK_EQ(index->ctx.device_type, kDLGPU);

for (int d = 1; d < array->ndim; ++d) {
Expand Down
8 changes: 4 additions & 4 deletions src/array/uvm_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ namespace aten {

NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
CHECK_EQ(array->ctx.device_type, kDLCPU) << "Only the CPU device type input "
<< "array supported";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Only the GPU device type input "
<< "index supported";
CHECK_EQ(array->ctx.device_type, kDLCPUPinned)
<< "Only the CPUPinned device type input array is supported";
CHECK_EQ(index->ctx.device_type, kDLGPU)
<< "Only the GPU device type input index is supported";

CHECK_GE(array->ndim, 1) << "Only support array with at least 1 dimension";
CHECK_EQ(index->ndim, 1) << "Index array must be an 1D array.";
Expand Down
10 changes: 10 additions & 0 deletions src/graph/heterograph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
hgindex->num_verts_per_type_));
}

/*!
 * \brief Pin every relation graph of this heterograph in place.
 * \note Iterates by const reference: HeteroGraphPtr is a shared_ptr, so a
 *       by-value loop variable would incur an atomic refcount bump per
 *       iteration for no benefit (clang-tidy: performance-for-range-copy).
 */
void HeteroGraph::PinMemory_() {
  for (const auto& g : relation_graphs_)
    g->PinMemory_();
}

/*!
 * \brief Unpin every relation graph of this heterograph in place.
 * \note Iterates by const reference to avoid copying the shared_ptr
 *       (atomic refcount traffic) on each iteration.
 */
void HeteroGraph::UnpinMemory_() {
  for (const auto& g : relation_graphs_)
    g->UnpinMemory_();
}

std::string HeteroGraph::SharedMemName() const {
return shared_mem_ ? shared_mem_->GetName() : "";
}
Expand Down
Loading

0 comments on commit 40b44a4

Please sign in to comment.