Skip to content

Commit

Permalink
[Feature] Simplify shared memory graph index (dmlc#1381)
Browse files Browse the repository at this point in the history
* simplify shared memory graph index.

* fix.

* remove edge_dir in SharedMemGraphStore.

* avoid creating shared-mem graph store with from_csr.

* simplify from_csr.

* add comments.

* fix lint.

* remove the test.

* fix compilation error.

* fix a bug.

* fix a bug.

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
zheng-da and Ubuntu authored Apr 5, 2020
1 parent 1b152bf commit 7c47d8c
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 160 deletions.
4 changes: 2 additions & 2 deletions examples/mxnet/sampling/run_store_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class GraphData:
def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False,
'in', dgl.contrib.graph_store._get_graph_path(graph_name))
self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False, 'in')
self.graph = self.graph.copyto_shared_mem(dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
Expand Down
22 changes: 12 additions & 10 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -886,13 +886,7 @@ class ImmutableGraph: public GraphInterface {
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);

static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &shared_mem_name);

static ImmutableGraphPtr CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);

/*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO(
Expand All @@ -918,12 +912,10 @@ class ImmutableGraph: public GraphInterface {

/*!
* \brief Copy data to shared memory.
* \param edge_dir the graph of the specific edge direction to be copied.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
static ImmutableGraphPtr CopyToSharedMem(
ImmutableGraphPtr g, const std::string &edge_dir, const std::string &name);
static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name);

/*!
* \brief Convert the graph to use the given number of bits for storage.
Expand Down Expand Up @@ -952,6 +944,14 @@ class ImmutableGraph: public GraphInterface {
GetOutCSR()->SortCSR();
}

bool HasInCSR() const {
return in_csr_ != NULL;
}

bool HasOutCSR() const {
return out_csr_ != NULL;
}

/*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const;

Expand Down Expand Up @@ -995,6 +995,8 @@ class ImmutableGraph: public GraphInterface {
// The name of shared memory for this graph.
// If it's empty, the graph isn't stored in shared memory.
std::string shared_mem_name_;
// We serialize the metadata of the graph index here for shared memory.
NDArray serialized_shared_meta_;
};

// inline implementations
Expand Down
53 changes: 12 additions & 41 deletions python/dgl/contrib/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .. import backend as F
from ..graph import DGLGraph
from .. import utils
from ..graph_index import GraphIndex, create_graph_index, from_csr, from_shared_mem_csr_matrix
from ..graph_index import GraphIndex, create_graph_index, from_shared_mem_graph_index
from .._ffi.ndarray import empty_shared_mem
from .._ffi.function import _init_api
from .. import ndarray as nd
Expand All @@ -25,9 +25,6 @@ def _get_ndata_path(graph_name, ndata_name):
def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name

def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name

def _get_graph_path(graph_name):
return "/" + graph_name

Expand Down Expand Up @@ -118,21 +115,6 @@ def __repr__(self):
data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame})

def _to_csr(graph_data, edge_dir, multigraph=None):
try:
indptr = graph_data.indptr
indices = graph_data.indices
return indptr, indices
except:
if isinstance(graph_data, scipy.sparse.spmatrix):
csr = graph_data.tocsr()
return csr.indptr, csr.indices
else:
idx = create_graph_index(graph_data=graph_data, readonly=True)
transpose = (edge_dir != 'in')
csr = idx.adjacency_matrix_scipy(transpose, 'csr')
return csr.indptr, csr.indices

class Barrier(object):
""" A barrier in the KVStore server used for one synchronization.
Expand Down Expand Up @@ -305,8 +287,6 @@ class SharedMemoryStoreServer(object):
----------
graph_data : graph data
Data to initialize graph.
edge_dir : string
the edge direction for the graph structure ("in" or "out")
graph_name : string
Define the name of the graph, so the client can use the name to access the graph.
multigraph : bool, optional
Expand All @@ -317,27 +297,23 @@ class SharedMemoryStoreServer(object):
port : int
The port that the server listens to.
"""
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
def __init__(self, graph_data, graph_name, multigraph, num_workers, port):
self.server = None
if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")

if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, readonly=True)
graph_data = graph_data.copyto_shared_mem(_get_graph_path(graph_name))
elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, readonly=True)
graph_data = graph_data._graph.copyto_shared_mem(_get_graph_path(graph_name))
else:
indptr, indices = _to_csr(graph_data, edge_dir)
graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices),
edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, readonly=True)
graph_data = create_graph_index(graph_data, readonly=True)
graph_data = graph_data.copyto_shared_mem(_get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, readonly=True)

self._num_workers = num_workers
self._graph_name = graph_name
self._edge_dir = edge_dir
self._registered_nworkers = 0

self._barrier = BarrierManager(num_workers)
Expand All @@ -358,8 +334,7 @@ def get_graph_info(graph_name):
assert graph_name == self._graph_name
# if the integers are larger than 2^31, xmlrpc can't handle them.
# we convert them to strings to send them to clients.
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \
True, edge_dir
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges())

# RPC command: initialize node embedding in the server.
def init_ndata(init, ndata_name, shape, dtype):
Expand Down Expand Up @@ -560,11 +535,10 @@ def __init__(self, graph_name, port):
self._worker_id, self._num_workers = self.proxy.register(graph_name)
if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, _, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges)

graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name),
num_nodes, num_edges, edge_dir)
graph_idx = from_shared_mem_graph_index(_get_graph_path(graph_name))
super(SharedMemoryDGLGraph, self).__init__(graph_idx)
self._init_manager = InitializerManager()

Expand Down Expand Up @@ -1059,7 +1033,7 @@ def destroy(self):


def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph=None, edge_dir='in', port=8000):
multigraph=None, port=8000):
"""Create the graph store server.
The server loads graph structure and node embeddings and edge embeddings.
Expand Down Expand Up @@ -1092,9 +1066,6 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph : bool, optional
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True)
edge_dir : string
the edge direction for the graph structure. The supported option is
"in" and "out".
port : int
The port that the server listens to.
Expand All @@ -1106,7 +1077,7 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
if multigraph is not None:
dgl_warning("multigraph is deprecated." \
"DGL treat all graphs as multigraph by default.")
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, None,
return SharedMemoryStoreServer(graph_data, graph_name, None,
num_workers, port)

def create_graph_from_store(graph_name, store_type, port=8000):
Expand Down
29 changes: 6 additions & 23 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,15 +881,13 @@ def copy_to(self, ctx):
"""
return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id)

def copyto_shared_mem(self, edge_dir, shared_mem_name):
def copyto_shared_mem(self, shared_mem_name):
"""Copy this immutable graph index to shared memory.
NOTE: this method only works for immutable graph index
Parameters
----------
edge_dir : string
Indicate which CSR should copy ("in", "out", "both").
shared_mem_name : string
The name of the shared memory.
Expand All @@ -898,7 +896,7 @@ def copyto_shared_mem(self, edge_dir, shared_mem_name):
GraphIndex
The graph index on the given device context.
"""
return _CAPI_DGLImmutableGraphCopyToSharedMem(self, edge_dir, shared_mem_name)
return _CAPI_DGLImmutableGraphCopyToSharedMem(self, shared_mem_name)

def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64).
Expand Down Expand Up @@ -1017,8 +1015,7 @@ def from_coo(num_nodes, src, dst, readonly):
gidx.add_edges(src, dst)
return gidx

def from_csr(indptr, indices,
direction, shared_mem_name=""):
def from_csr(indptr, indices, direction):
"""Load a graph from CSR arrays.
Parameters
Expand All @@ -1029,38 +1026,24 @@ def from_csr(indptr, indices,
column index array in the CSR format
direction : str
the edge direction. Either "in" or "out".
shared_mem_name : str
the name of shared memory
"""
indptr = utils.toindex(indptr)
indices = utils.toindex(indices)
gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(),
indices.todgltensor(),
shared_mem_name,
direction)
return gidx

def from_shared_mem_csr_matrix(shared_mem_name,
num_nodes, num_edges, edge_dir):
"""Load a graph from the shared memory in the CSR format.
def from_shared_mem_graph_index(shared_mem_name):
"""Load a graph index from the shared memory.
Parameters
----------
shared_mem_name : string
the name of shared memory
num_nodes : int
the number of nodes
num_edges : int
the number of edges
edge_dir : string
the edge direction. The supported option is "in" and "out".
"""
gidx = _CAPI_DGLGraphCSRCreateMMap(
shared_mem_name,
int(num_nodes), int(num_edges),
edge_dir)
return gidx
return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)

def from_networkx(nx_graph, readonly):
"""Convert from networkx graph.
Expand Down
16 changes: 3 additions & 13 deletions src/graph/graph_apis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0];
const IdArray indices = args[1];
const std::string shared_mem_name = args[2];
const std::string edge_dir = args[3];
const std::string edge_dir = args[2];

IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
if (shared_mem_name.empty()) {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, edge_dir, shared_mem_name));
}
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string shared_mem_name = args[0];
const int64_t num_vertices = args[1];
const int64_t num_edges = args[2];
const std::string edge_dir = args[3];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, edge_dir));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
Expand Down
Loading

0 comments on commit 7c47d8c

Please sign in to comment.