Skip to content

Commit

Permalink
[BUG] Fixing bug in pickling readonly graphs (dmlc#397)
Browse files Browse the repository at this point in the history
* fixing bug in pickling readonly graphs

* multigraph test coverage

* fixing lint (??????)

* another lint

* is_readonly interface
  • Loading branch information
BarclayII authored Feb 19, 2019
1 parent 788d8dd commit 220a1e6
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 17 deletions.
6 changes: 6 additions & 0 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ def is_multigraph(self):
"""
return self._graph.is_multigraph()

@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()

def number_of_edges(self):
"""Return the number of edges in the graph.
Expand Down
19 changes: 10 additions & 9 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(self, handle=None, multigraph=None, readonly=None):

def __del__(self):
"""Free this graph index object."""
_CAPI_DGLGraphFree(self._handle)
if hasattr(self, '_handle'):
_CAPI_DGLGraphFree(self._handle)

def __getstate__(self):
src, dst, _ = self.edges()
Expand All @@ -46,19 +47,19 @@ def __setstate__(self, state):
"""
n_nodes, multigraph, readonly, src, dst = state

self._cache = {}
self._multigraph = multigraph
self._readonly = readonly
if readonly:
self._readonly = readonly
self._multigraph = multigraph
self.init(src, dst, F.arange(0, len(src)), n_nodes)
self._init(src, dst, utils.toindex(F.arange(0, len(src))), n_nodes)
else:
self._handle = _CAPI_DGLGraphCreateMutable(multigraph)
self._cache = {}

self.clear()
self.add_nodes(n_nodes)
self.add_edges(src, dst)

def init(self, src_ids, dst_ids, edge_ids, num_nodes):
def _init(self, src_ids, dst_ids, edge_ids, num_nodes):
"""The actual init function"""
assert len(src_ids) == len(dst_ids)
assert len(src_ids) == len(edge_ids)
Expand Down Expand Up @@ -746,7 +747,7 @@ def from_networkx(self, nx_graph):
eid = utils.toindex(eid)
src = utils.toindex(src)
dst = utils.toindex(dst)
self.init(src, dst, eid, num_nodes)
self._init(src, dst, eid, num_nodes)


def from_scipy_sparse_matrix(self, adj):
Expand All @@ -763,7 +764,7 @@ def from_scipy_sparse_matrix(self, adj):
src = utils.toindex(adj_coo.row)
dst = utils.toindex(adj_coo.col)
edge_ids = utils.toindex(F.arange(0, len(adj_coo.row)))
self.init(src, dst, edge_ids, num_nodes)
self._init(src, dst, edge_ids, num_nodes)


def from_edge_list(self, elist):
Expand All @@ -786,7 +787,7 @@ def from_edge_list(self, elist):
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
edge_ids = utils.toindex(F.arange(0, len(src)))
self.init(src_ids, dst_ids, edge_ids, num_nodes)
self._init(src_ids, dst_ids, edge_ids, num_nodes)

def line_graph(self, backtracking=True):
"""Return the line graph of this graph.
Expand Down
32 changes: 24 additions & 8 deletions src/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, IdArray edge_ids, size_t num_node
CHECK(IsValidIdArray(edge_ids));
this->AddVertices(num_nodes);
num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) << "vectors in COO must have the same length";
CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0]) << "vectors in COO must have the same length";
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
<< "vectors in COO must have the same length";
CHECK(static_cast<int64_t>(num_edges_) == edge_ids->shape[0])
<< "vectors in COO must have the same length";
const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data);
const dgl_id_t *edge_data = static_cast<dgl_id_t*>(edge_ids->data);
Expand Down Expand Up @@ -507,7 +509,10 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
uint64_t num_edges = NumEdges();
uint64_t num_nodes = NumVertices();
if (fmt == "coo") {
IdArray idx = IdArray::Empty({2 * static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data);
if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
Expand All @@ -516,17 +521,28 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
}
IdArray eid = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid;
}
return std::vector<IdArray>{idx, eid};
} else if (fmt == "csr") {
IdArray indptr = IdArray::Empty({static_cast<int64_t>(num_nodes) + 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
IdArray indices = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty({static_cast<int64_t>(num_edges)}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data);
Expand Down
17 changes: 17 additions & 0 deletions tests/compute/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def test_pickling_frame():


def _assert_is_identical(g, g2):
assert g.is_multigraph == g2.is_multigraph
assert g.is_readonly == g2.is_readonly
assert g.number_of_nodes() == g2.number_of_nodes()
src, dst = g.all_edges()
src2, dst2 = g2.all_edges()
Expand Down Expand Up @@ -140,6 +142,21 @@ def test_pickling_graph():
_assert_is_identical(g, new_g)
_assert_is_identical(g2, new_g2)

# readonly graph
g = dgl.DGLGraph([(0, 1), (1, 2)], readonly=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)

# multigraph
g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)

# readonly multigraph
g = dgl.DGLGraph([(0, 1), (0, 1), (1, 2)], multigraph=True, readonly=True)
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)


if __name__ == '__main__':
test_pickling_index()
Expand Down

0 comments on commit 220a1e6

Please sign in to comment.