From 220a1e68fbdb1026c3c3efa71eeb2b9b0a23378d Mon Sep 17 00:00:00 2001 From: "Quan (Andy) Gan" Date: Tue, 19 Feb 2019 00:14:11 -0500 Subject: [PATCH] [BUG] Fixing bug in pickling readonly graphs (#397) * fixing bug in pickling readonly graphs * multigraph test coverage * fixing lint (??????) * another lint * is_readonly interface --- python/dgl/graph.py | 6 ++++++ python/dgl/graph_index.py | 19 ++++++++++--------- src/graph/graph.cc | 32 ++++++++++++++++++++++++-------- tests/compute/test_pickle.py | 17 +++++++++++++++++ 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/python/dgl/graph.py b/python/dgl/graph.py index b8f78031d594..fe9d2f7ae40e 100644 --- a/python/dgl/graph.py +++ b/python/dgl/graph.py @@ -422,6 +422,12 @@ def is_multigraph(self): """ return self._graph.is_multigraph() + @property + def is_readonly(self): + """True if the graph is readonly, False otherwise. + """ + return self._graph.is_readonly() + def number_of_edges(self): """Return the number of edges in the graph. 
diff --git a/python/dgl/graph_index.py b/python/dgl/graph_index.py index 2df27d9fcbb4..3f82be94a817 100644 --- a/python/dgl/graph_index.py +++ b/python/dgl/graph_index.py @@ -30,7 +30,8 @@ def __init__(self, handle=None, multigraph=None, readonly=None): def __del__(self): """Free this graph index object.""" - _CAPI_DGLGraphFree(self._handle) + if hasattr(self, '_handle'): + _CAPI_DGLGraphFree(self._handle) def __getstate__(self): src, dst, _ = self.edges() @@ -46,19 +47,19 @@ def __setstate__(self, state): """ n_nodes, multigraph, readonly, src, dst = state + self._cache = {} + self._multigraph = multigraph + self._readonly = readonly if readonly: - self._readonly = readonly - self._multigraph = multigraph - self.init(src, dst, F.arange(0, len(src)), n_nodes) + self._init(src, dst, utils.toindex(F.arange(0, len(src))), n_nodes) else: self._handle = _CAPI_DGLGraphCreateMutable(multigraph) - self._cache = {} self.clear() self.add_nodes(n_nodes) self.add_edges(src, dst) - def init(self, src_ids, dst_ids, edge_ids, num_nodes): + def _init(self, src_ids, dst_ids, edge_ids, num_nodes): """The actual init function""" assert len(src_ids) == len(dst_ids) assert len(src_ids) == len(edge_ids) @@ -746,7 +747,7 @@ def from_networkx(self, nx_graph): eid = utils.toindex(eid) src = utils.toindex(src) dst = utils.toindex(dst) - self.init(src, dst, eid, num_nodes) + self._init(src, dst, eid, num_nodes) def from_scipy_sparse_matrix(self, adj): @@ -763,7 +764,7 @@ def from_scipy_sparse_matrix(self, adj): src = utils.toindex(adj_coo.row) dst = utils.toindex(adj_coo.col) edge_ids = utils.toindex(F.arange(0, len(adj_coo.row))) - self.init(src, dst, edge_ids, num_nodes) + self._init(src, dst, edge_ids, num_nodes) def from_edge_list(self, elist): @@ -786,7 +787,7 @@ def from_edge_list(self, elist): if min_nodes != 0: raise DGLError('Invalid edge list. 
Nodes must start from 0.') edge_ids = utils.toindex(F.arange(0, len(src))) - self.init(src_ids, dst_ids, edge_ids, num_nodes) + self._init(src_ids, dst_ids, edge_ids, num_nodes) def line_graph(self, backtracking=True): """Return the line graph of this graph. diff --git a/src/graph/graph.cc b/src/graph/graph.cc index 3928f505fc72..8424263c35f6 100644 --- a/src/graph/graph.cc +++ b/src/graph/graph.cc @@ -20,8 +20,10 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, IdArray edge_ids, size_t num_node CHECK(IsValidIdArray(edge_ids)); this->AddVertices(num_nodes); num_edges_ = src_ids->shape[0]; - CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) << "vectors in COO must have the same length"; - CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0]) << "vectors in COO must have the same length"; + CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) + << "vectors in COO must have the same length"; + CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0]) + << "vectors in COO must have the same length"; const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data); const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data); const dgl_id_t *edge_data = static_cast<dgl_id_t*>(edge_ids->data); @@ -507,7 +509,10 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const uint64_t num_edges = NumEdges(); uint64_t num_nodes = NumVertices(); if (fmt == "coo") { - IdArray idx = IdArray::Empty({2 * static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); + IdArray idx = IdArray::Empty( + {2 * static_cast<int64_t>(num_edges)}, + DLDataType{kDLInt, 64, 1}, + DLContext{kDLCPU, 0}); int64_t *idx_data = static_cast<int64_t*>(idx->data); if (transpose) { std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data); @@ -516,17 +521,28 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges); } - IdArray eid = 
IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); + IdArray eid = IdArray::Empty( + {static_cast<int64_t>(num_edges)}, + DLDataType{kDLInt, 64, 1}, + DLContext{kDLCPU, 0}); int64_t *eid_data = static_cast<int64_t*>(eid->data); for (uint64_t eid = 0; eid < num_edges; ++eid) { eid_data[eid] = eid; } return std::vector<IdArray>{idx, eid}; } else if (fmt == "csr") { - IdArray indptr = IdArray::Empty({static_cast<int64_t>(num_nodes) + 1}, DLDataType{kDLInt, 64, 1}, - DLContext{kDLCPU, 0}); - IdArray indices = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); - IdArray eid = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); + IdArray indptr = IdArray::Empty( + {static_cast<int64_t>(num_nodes) + 1}, + DLDataType{kDLInt, 64, 1}, + DLContext{kDLCPU, 0}); + IdArray indices = IdArray::Empty( + {static_cast<int64_t>(num_edges)}, + DLDataType{kDLInt, 64, 1}, + DLContext{kDLCPU, 0}); + IdArray eid = IdArray::Empty( + {static_cast<int64_t>(num_edges)}, + DLDataType{kDLInt, 64, 1}, + DLContext{kDLCPU, 0}); int64_t *indptr_data = static_cast<int64_t*>(indptr->data); int64_t *indices_data = static_cast<int64_t*>(indices->data); int64_t *eid_data = static_cast<int64_t*>(eid->data); diff --git a/tests/compute/test_pickle.py b/tests/compute/test_pickle.py index a4b142cea3bf..b7f73141fcb5 100644 --- a/tests/compute/test_pickle.py +++ b/tests/compute/test_pickle.py @@ -60,6 +60,8 @@ def test_pickling_frame(): def _assert_is_identical(g, g2): + assert g.is_multigraph == g2.is_multigraph + assert g.is_readonly == g2.is_readonly assert g.number_of_nodes() == g2.number_of_nodes() src, dst = g.all_edges() src2, dst2 = g2.all_edges() @@ -140,6 +142,21 @@ def test_pickling_graph(): _assert_is_identical(g, new_g) _assert_is_identical(g2, new_g2) + # readonly graph + g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True) + new_g = _reconstruct_pickle(g) + _assert_is_identical(g, new_g) + + # multigraph + g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True) + new_g = 
_reconstruct_pickle(g) + _assert_is_identical(g, new_g) + + # readonly multigraph + g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True, readonly=True) + new_g = _reconstruct_pickle(g) + _assert_is_identical(g, new_g) + if __name__ == '__main__': test_pickling_index()