Skip to content

Commit

Permalink
[Graph] Add API to convert graph to simple graph (dmlc#587)
Browse files Browse the repository at this point in the history
* to simple

* WIP: multigraph flag

* graph index refactor; pass basic testing

* graph index refactor; pass basic testing

* fix bug in to_simple; pass torch test

* fix mx utest

* fix example

* fix lint

* fix ci

* poke ci

* poke ci

* WIP

* poke ci

* poke ci

* poke ci

* change ci workspace

* poke ci

* poke ci

* poke ci

* poke ci

* delete ci

* use enum for multigraph flag
  • Loading branch information
jermainewang authored Jun 2, 2019
1 parent 372203f commit 01a4cc5
Show file tree
Hide file tree
Showing 16 changed files with 429 additions and 416 deletions.
7 changes: 3 additions & 4 deletions examples/mxnet/sampling/run_store_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
self.graph.from_csr_matrix(dgl.utils.toindex(csr.indptr),
dgl.utils.toindex(csr.indices), "in",
dgl.contrib.graph_store._get_graph_path(graph_name))
self.graph = dgl.graph_index.from_csr_matrix(
dgl.utils.toindex(csr.indptr), dgl.utils.toindex(csr.indices), False,
"in", dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
Expand Down
7 changes: 7 additions & 0 deletions include/dgl/graph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ class GraphOp {
* \return an expanded Id array.
*/
static IdArray ExpandIds(IdArray ids, IdArray offset);

/*!
* \brief Convert the graph to a simple graph.
* \param graph The input graph.
* \return a new immutable simple graph with no multi-edge.
*/
static ImmutableGraph ToSimpleGraph(const GraphInterface* graph);
};

} // namespace dgl
Expand Down
28 changes: 28 additions & 0 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,17 @@ class CSR : public GraphInterface {
// that has the given number of verts and edges.
CSR(const std::string &shared_mem_name,
int64_t num_vertices, int64_t num_edges, bool is_multigraph);

// Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);

// Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
bool is_multigraph);

// Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
Expand Down Expand Up @@ -892,6 +900,26 @@ class ImmutableGraph: public GraphInterface {
COOPtr coo_;
};

// inline implementations

/*!
 * \brief Build a CSR graph by copying its three arrays out of iterators.
 *
 * Freshly allocates indptr (num_vertices + 1 entries), indices and edge ids
 * (num_edges entries each) with NewIdArray, then fills them element by
 * element from the supplied iterators.
 *
 * \param num_vertices Number of vertices; indptr receives num_vertices + 1 values.
 * \param num_edges Number of edges; indices and edge ids receive num_edges values each.
 * \param indptr_begin Iterator positioned at the first indptr value.
 * \param indices_begin Iterator positioned at the first indices value.
 * \param edge_ids_begin Iterator positioned at the first edge id value.
 * \param is_multigraph Stored as the multigraph flag of this CSR.
 */
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
         IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
         bool is_multigraph): is_multigraph_(is_multigraph) {
  const int64_t indptr_len = num_vertices + 1;

  indptr_ = NewIdArray(indptr_len);
  indices_ = NewIdArray(num_edges);
  edge_ids_ = NewIdArray(num_edges);

  dgl_id_t* const indptr_out = static_cast<dgl_id_t*>(indptr_->data);
  dgl_id_t* const indices_out = static_cast<dgl_id_t*>(indices_->data);
  dgl_id_t* const eids_out = static_cast<dgl_id_t*>(edge_ids_->data);

  // Row offsets: one entry per vertex plus the trailing end offset.
  for (int64_t row = 0; row < indptr_len; ++row) {
    indptr_out[row] = *indptr_begin;
    ++indptr_begin;
  }

  // Column indices and edge ids are consumed in lock-step, one pair per edge,
  // so single-pass iterators are read in the same order as before.
  for (int64_t e = 0; e < num_edges; ++e) {
    indices_out[e] = *indices_begin;
    ++indices_begin;
    eids_out[e] = *edge_ids_begin;
    ++edge_ids_begin;
  }
}

} // namespace dgl

#endif // DGL_IMMUTABLE_GRAPH_H_
11 changes: 5 additions & 6 deletions python/dgl/contrib/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .. import backend as F
from ..graph import DGLGraph
from .. import utils
from ..graph_index import GraphIndex, create_graph_index
from ..graph_index import GraphIndex, create_graph_index, from_csr, from_shared_mem_csr_matrix
from .._ffi.ndarray import empty_shared_mem
from .._ffi.function import _init_api
from .. import ndarray as nd
Expand Down Expand Up @@ -309,10 +309,9 @@ def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, po
if isinstance(graph_data, GraphIndex):
graph_idx = graph_data
else:
graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
graph_idx.from_csr_matrix(utils.toindex(indptr), utils.toindex(indices),
edge_dir, _get_graph_path(graph_name))
graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices),
multigraph, edge_dir, _get_graph_path(graph_name))

self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True)
self._num_workers = num_workers
Expand Down Expand Up @@ -541,8 +540,8 @@ def __init__(self, graph_name, port):
num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges)

graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name),
num_nodes, num_edges, edge_dir, multigraph)
super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph)
self._init_manager = InitializerManager()

Expand Down
14 changes: 8 additions & 6 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme
from .graph_index import create_graph_index
from . import graph_index
from .runtime import ir, scheduler, Runtime
from . import utils
from .view import NodeView, EdgeView
Expand Down Expand Up @@ -765,7 +765,8 @@ class DGLGraph(DGLBaseGraph):
edge_frame : FrameRef, optional
Edge feature storage.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
Whether the graph would be a multigraph. If None, the flag is determined
by scanning the whole graph. (default: None)
readonly : bool, optional
Whether the graph structure is read-only (default: False).
Expand Down Expand Up @@ -894,10 +895,11 @@ def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
multigraph=False,
multigraph=None,
readonly=False):
# graph
super(DGLGraph, self).__init__(create_graph_index(graph_data, multigraph, readonly))
gidx = graph_index.create_graph_index(graph_data, multigraph, readonly)
super(DGLGraph, self).__init__(gidx)

# node and edge frame
if node_frame is None:
Expand Down Expand Up @@ -1225,7 +1227,7 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
nx_graph = nx_graph.to_directed()

self.clear()
self._graph.from_networkx(nx_graph)
self._graph = graph_index.from_networkx(nx_graph, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
Expand Down Expand Up @@ -1291,7 +1293,7 @@ def from_scipy_sparse_matrix(self, spmat):
>>> g.from_scipy_sparse_matrix(a)
"""
self.clear()
self._graph.from_scipy_sparse_matrix(spmat)
self._graph = graph_index.from_scipy_sparse_matrix(spmat, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
Expand Down
Loading

0 comments on commit 01a4cc5

Please sign in to comment.