Skip to content

Commit

Permalink
[hotfix] node id validity check (dmlc#1073)
Browse files Browse the repository at this point in the history
* fix

* improve

* fix lint

* upd

* fix

* upd
  • Loading branch information
yzh119 authored Dec 5, 2019
1 parent bc4f435 commit fa0ee46
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 9 deletions.
61 changes: 52 additions & 9 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
'to_networkx',
]

def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
"""Create a graph with one type of nodes and edges.
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
Expand All @@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
card : int, optional
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
the largest node ID plus 1. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
Expand Down Expand Up @@ -101,24 +105,34 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
['follows']
>>> g.canonical_etypes
[('user', 'follows', 'user')]
Check if node ids are within cardinality
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})
"""
if card is not None:
urange, vrange = card, card
else:
urange, vrange = None, None
if isinstance(data, tuple):
u, v = data
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange)
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, list):
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange)
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, ntype, etype, ntype)
elif isinstance(data, nx.Graph):
return create_from_networkx(data, ntype, etype, **kwargs)
else:
raise DGLError('Unsupported graph data type:', type(data))

def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs):
"""Create a bipartite graph.
The result graph is directed and edges must be from ``utype`` nodes
Expand Down Expand Up @@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
card : pair of int, optional
Cardinality (number of nodes in the source and destination group). If None,
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
validate : bool, optional
If True, check if node ids are within cardinality, the check process may take
some time.
If False and card is not None, user would receive a warning. (Default: False)
kwargs : key-word arguments, optional
Other key word arguments. Only comes into effect when we are using a NetworkX
graph. It can consist of:
Expand Down Expand Up @@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
4
>>> g.edges()
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
Check if node ids are within cardinality
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
...
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
>>> g
Graph(num_nodes={'_U': 3, '_V': 4},
num_edges={('_U', '_E', '_V'): 3},
metagraph=[('_U', '_V')])
"""
if utype == vtype:
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
Expand All @@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
urange, vrange = None, None
if isinstance(data, tuple):
u, v = data
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, list):
return create_from_edge_list(data, utype, etype, vtype, urange, vrange)
return create_from_edge_list(data, utype, etype, vtype, urange, vrange, validate)
elif isinstance(data, sp.sparse.spmatrix):
return create_from_scipy(data, utype, etype, vtype)
elif isinstance(data, nx.Graph):
Expand Down Expand Up @@ -667,7 +695,7 @@ def to_homo(G):
# Internal APIs
############################################################

def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a graph from incident nodes with types.
utype could be equal to vtype
Expand All @@ -690,13 +718,22 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
DGLHeteroGraph
"""
u = utils.toindex(u)
v = utils.toindex(v)
if validate:
if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0)))))
if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0)))))
urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1)
vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1)
if utype == vtype:
Expand All @@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])

def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False):
"""Internal function to create a heterograph from a list of edge tuples with types.
utype could be equal to vtype
Expand All @@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
vrange : int, optional
The destination node ID range. If None, the value is the
maximum of the destination node IDs in the edge list plus 1. (Default: None)
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
Expand All @@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
u, v = zip(*elist)
u = list(u)
v = list(v)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)

def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
"""Internal function to create a heterograph from a scipy sparse matrix with types.
Expand All @@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
validate : bool, optional
If True, checks if node IDs are within range.
Returns
-------
Expand Down
32 changes: 32 additions & 0 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import itertools
import backend as F
import networkx as nx
from dgl import DGLError


def create_test_heterograph():
# test heterograph from the docstring, plus a user -- wishes -- game relation
Expand Down Expand Up @@ -93,6 +95,36 @@ def test_create():
assert g.number_of_nodes('l1') == 3
assert g.number_of_nodes('l2') == 4

# test if validate flag works
# homo graph
fail = False
try:
g = dgl.graph(
([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),
card=2,
validate=True
)
except DGLError:
fail = True
finally:
assert fail, "should catch a DGLError because node ID is out of bound."
# bipartite graph
def _test_validate_bipartite(card):
fail = False
try:
g = dgl.bipartite(
([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]),
card=card,
validate=True
)
except DGLError:
fail = True
finally:
assert fail, "should catch a DGLError because node ID is out of bound."

_test_validate_bipartite((3, 3))
_test_validate_bipartite((2, 4))

def test_query():
g = create_test_heterograph()

Expand Down

0 comments on commit fa0ee46

Please sign in to comment.