-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature][Performance] Implement NCCL wrapper for communicating NodeE…
…mbeddings and sparse gradients. (dmlc#2825) * Split NCCL wrapper from sparse optimizer and sparse embedding * Add more unit tests for single node nccl * Fix unit test for tf * Switch to device histogram * Fix histogram issues * Finish migration to histogram * Handle cases with zero send/receive data * Start on partition object * Get compiling * Updates * Add unit tests * Switch to partition object * Fix linting issues * Rename partition file * Add python doc * Fix python assert and finish doxygen comments * Remove stubs for range based partition to satisfy pylint * Wrap unit test in GPU only * Wrap explicit cuda call in ifdef * Merge with partition.py * update docstrings * Cleanup partition_op * Add Workspace object * Switch to using workspace object * Move last remainder based function out of nccl_api * Add error messages * Update docs with examples * Fix linting errors Co-authored-by: xiang song(charlie.song) <[email protected]>
- Loading branch information
1 parent
0e9259b
commit ae8dbe6
Showing
21 changed files
with
2,070 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# Tries to find NCCL headers and libraries. | ||
# | ||
# Usage of this module as follows: | ||
# | ||
# find_package(NCCL) | ||
# | ||
# Variables used by this module, they can change the default behaviour and need | ||
# to be set before calling find_package: | ||
# | ||
# NCCL_ROOT - When set, this path is inspected instead of standard library | ||
# locations as the root of the NCCL installation. | ||
# The environment variable NCCL_ROOT overrides this variable. | ||
# | ||
# This module defines | ||
# Nccl_FOUND, whether nccl has been found | ||
# NCCL_INCLUDE_DIR, directory containing header | ||
# NCCL_LIBRARY, directory containing nccl library | ||
# NCCL_LIB_NAME, nccl library name | ||
# USE_NCCL_LIB_PATH, when set, NCCL_LIBRARY path is also inspected for the | ||
# location of the nccl library. This would disable | ||
# switching between static and shared. | ||
# | ||
# This module assumes that the user has already called find_package(CUDA) | ||
# | ||
# This file is from https://github.com/dmlc/xgboost, with modifications to | ||
# check the version. | ||
|
||
if(NCCL_LIBRARY)
  if(NOT USE_NCCL_LIB_PATH)
    # Don't cache NCCL_LIBRARY to enable switching between static and shared.
    unset(NCCL_LIBRARY CACHE)
  endif()
endif()

if(BUILD_WITH_SHARED_NCCL)
  # libnccl.so
  set(NCCL_LIB_NAME nccl)
else()
  # libnccl_static.a
  set(NCCL_LIB_NAME nccl_static)
endif()

find_path(NCCL_INCLUDE_DIR
  NAMES nccl.h
  PATHS $ENV{NCCL_ROOT}/include ${NCCL_ROOT}/include)

# Only parse nccl.h when the header was actually found. Without this guard,
# file(STRINGS) aborts with a confusing error on the literal path
# "NCCL_INCLUDE_DIR-NOTFOUND/nccl.h" instead of letting
# find_package_handle_standard_args() report a clean "not found".
if(NCCL_INCLUDE_DIR)
  # Make sure it has point-to-point support by extracting NCCL_VERSION_CODE
  # (a single integer, e.g. 2708) from the header.
  file(STRINGS "${NCCL_INCLUDE_DIR}/nccl.h" NCCL_VERSION_CODE
       REGEX "^#define[ \t]+NCCL_VERSION_CODE[ \t]+[0-9]+.*$" LIMIT_COUNT 1)
  string(REGEX REPLACE "^.*NCCL_VERSION_CODE[ \t]+([0-9]+).*$" "\\1"
         NCCL_VERSION "${NCCL_VERSION_CODE}")
endif()

find_library(NCCL_LIBRARY
  NAMES ${NCCL_LIB_NAME}
  PATHS $ENV{NCCL_ROOT}/lib ${NCCL_ROOT}/lib)

# Point-to-point primitives (ncclSend/ncclRecv) require NCCL >= 2.7
# (version code 2700).
if("${NCCL_VERSION}" LESS "2700")
  message(FATAL_ERROR "Require nccl >= 2700, but found ${NCCL_LIBRARY}==${NCCL_VERSION}")
else()
  message(STATUS "Using nccl library: ${NCCL_LIBRARY} ${NCCL_VERSION}")
endif()

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Nccl DEFAULT_MSG
                                  NCCL_INCLUDE_DIR NCCL_LIBRARY)

mark_as_advanced(
  NCCL_INCLUDE_DIR
  NCCL_LIBRARY
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
""" CUDA wrappers """ | ||
from . import nccl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
"""API creating NCCL communicators.""" | ||
|
||
from .. import backend as F | ||
from .._ffi.function import _init_api | ||
|
||
# Mapping from the user-facing partition-mode name to the integer code
# passed to the C API. Currently only remainder-based (modulo) partitioning
# is exposed here.
_COMM_MODES_MAP = {
    'remainder': 0
}
|
||
class UniqueId(object):
    """ Class for allowing python code to create and communicate NCCL Unique
    IDs, needed for creating communicators.
    """
    def __init__(self, id_str=None):
        """ Create an object referencing the current NCCL unique id.

        Parameters
        ----------
        id_str : str or bytes, optional
            When provided, the serialized form of an existing unique id
            (as produced by ``str(uid)``) to reconstruct. When omitted,
            a new unique id is generated.
        """
        if id_str:
            if isinstance(id_str, bytes):
                id_str = id_str.decode('utf-8')
            self._handle = _CAPI_DGLNCCLUniqueIdFromString(id_str)
        else:
            self._handle = _CAPI_DGLNCCLGetUniqueId()

    def get(self):
        """ Get the C-handle for this object.
        """
        return self._handle

    def __str__(self):
        return _CAPI_DGLNCCLUniqueIdToString(self._handle)

    def __repr__(self):
        return "UniqueId[{}]".format(str(self))

    def __eq__(self, other):
        return str(self) == str(other)

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable in
        # Python 3 (so they could not be dict keys or set members). Hash
        # on the same serialized form that __eq__ compares, keeping the
        # hash/equality contract consistent.
        return hash(str(self))
|
||
|
||
class Communicator(object):
    """ High-level wrapper for NCCL communication.
    """
    def __init__(self, size, rank, unique_id):
        """ Create a new NCCL communicator.

        Parameters
        ----------
        size : int
            The number of processes in the communicator.
        rank : int
            The rank of the current process in the communicator.
        unique_id : UniqueId
            The unique id of the root process (rank=0).

        Examples
        --------
        >>> from dgl.cuda.nccl import Communicator, UniqueId

        The root process will generate a unique NCCL id and communicate it
        to the other processes.

        >>> uid = UniqueId()
        >>> store.set('nccl_root_id', str(uid))

        And all other processes create unique ids from the root process's.

        >>> uid = UniqueId(store.get('nccl_root_id'))

        Then, all processes should create the communicator.

        >>> comm = Communicator(world_size, rank, uid)
        """
        assert rank < size, "The rank of a process must be less than the " \
            "size of the communicator."
        self._handle = _CAPI_DGLNCCLCreateComm(size, rank, unique_id.get())
        self._rank = rank
        self._size = size

    def sparse_all_to_all_push(self, idx, value, partition):
        """ Perform an all-to-all-v operation, whereby all processors send
        out a set of indices and corresponding values. Indices and values
        corresponding to the current process will be copied into the output
        arrays.

        Parameters
        ----------
        idx : tensor
            The 1D set of indices to send to other processors.
        value : tensor
            The multi-dimension set of values to send to other processors.
            The 0th dimension must match that of `idx`.
        partition : NDArrayPartition
            The object containing information for assigning indices to
            processors.

        Returns
        -------
        tensor
            The 1D tensor of the received indices.
        tensor
            The set of received values.

        Examples
        --------
        To perform a sparse_all_to_all_push(), a partition object must be
        provided. A partition of a homogeneous graph, where the vertices
        are striped across processes, can be generated via:

        >>> from dgl.partition import NDArrayPartition
        >>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' )

        With this partition, each processor can send values to be associated
        with vertices in the graph. So if we have an array `global_idxs` of all
        of the neighbors updated during mini-batch processing, and an array
        `global_values` containing the new values associated with the
        neighbors, we communicate them to the owning processes via:

        >>> my_idxs, my_values = comm.sparse_all_to_all_push(global_idxs, global_values, part)

        This communication pattern is common when communicating gradient
        updates for node embeddings.

        Indices the current process owns do not need to be treated specially,
        as internally they will be copied to the output array. If we have a
        set of indices in process 0 of '[0, 3, 8, 9, 10]' and for process 1
        of '[0, 2, 4, 5, 8, 8, 9]', using a remainder partition will result
        in indices for process 0 of '[0, 8, 10, 0, 2, 4, 8, 8]', and for
        process 1 of '[3, 9, 5, 9]'.
        """
        out_idx, out_value = _CAPI_DGLNCCLSparseAllToAllPush(
            self.get(), F.zerocopy_to_dgl_ndarray(idx),
            F.zerocopy_to_dgl_ndarray(value),
            partition.get())
        return (F.zerocopy_from_dgl_ndarray(out_idx),
                F.zerocopy_from_dgl_ndarray(out_value))

    def sparse_all_to_all_pull(self, req_idx, value, partition):
        """ Perform an all-to-all-v operation, whereby all processors request
        the values corresponding to their set of indices.

        Parameters
        ----------
        req_idx : IdArray
            The set of indices this processor is requesting.
        value : NDArray
            The multi-dimension set of values that can be requested from
            this processor.
        partition : NDArrayPartition
            The object containing information for assigning indices to
            processors.

        Returns
        -------
        tensor
            The set of received values, corresponding to `req_idx`.

        Examples
        --------
        To perform a sparse_all_to_all_pull(), a partition object must be
        provided. A partition of a homogeneous graph, where the vertices
        are striped across processes, can be generated via:

        >>> from dgl.partition import NDArrayPartition
        >>> part = NDArrayPartition(g.num_nodes(), comm.size(), mode='remainder' )

        With this partition, each processor can request values/features
        associated with vertices in the graph. So in the case where we have
        a set of neighbors 'nbr_idxs' we need features for, and each process
        has a tensor 'node_feat' storing the features of nodes it owns in
        the partition, the features can be requested via:

        >>> nbr_values = comm.sparse_all_to_all_pull(nbr_idxs, node_feat, part)

        Then the two arrays 'nbr_idxs' and 'nbr_values' form the sparse
        set of features, where 'nbr_idxs[i]' is the global node id, and
        'nbr_values[i]' is the feature vector for that node. This
        communication pattern is useful for node features or node
        embeddings.
        """
        out_value = _CAPI_DGLNCCLSparseAllToAllPull(
            self.get(), F.zerocopy_to_dgl_ndarray(req_idx),
            F.zerocopy_to_dgl_ndarray(value),
            partition.get())
        return F.zerocopy_from_dgl_ndarray(out_value)

    def get(self):
        """ Get the C-Handle for this object.
        """
        return self._handle

    def rank(self):
        """ Get the rank of this process in this communicator.

        Returns
        -------
        int
            The rank of this process.
        """
        return self._rank

    def size(self):
        """ Get the size of this communicator.

        Returns
        -------
        int
            The number of processes in this communicator.
        """
        return self._size
|
||
_init_api("dgl.cuda.nccl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.