Skip to content

Commit

Permalink
[Feature] Max readout and consecutive labeling for networkx (dmlc#341)
Browse files Browse the repository at this point in the history
* Max readout and consecutive labeling

* Delete test_readout.py

* Delete test_basics.py

* Test case and fix

* Recover accidentally removed file

* Fix import order

* Fix test case

* Fix

* Fix

* Fix

* Fix

* Fix

* revert

* Fix

* Fix

* Fix
  • Loading branch information
mufeili authored and jermainewang committed Jan 10, 2019
1 parent 707334c commit 3a868eb
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/python/batch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ Graph Readout
sum_edges
mean_nodes
mean_edges
max_nodes
max_edges
93 changes: 92 additions & 1 deletion python/dgl/batched_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from . import utils

__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges']
'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
'max_nodes', 'max_edges']

class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs.
Expand Down Expand Up @@ -725,3 +726,93 @@ def mean_edges(graph, feat, weight=None):
sum_edges
"""
return _mean_on(graph, 'edges', feat, weight)

def _max_on(graph, typestr, feat):
    """Internal function to take the elementwise maximum over node or
    edge features.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    typestr : str
        'nodes' or 'edges'; used as the key into ``READOUT_ON_ATTRS``.
    feat : str
        The feature field name.

    Returns
    -------
    Tensor
        The maximum of the node or edge features, reduced along the
        first dimension. For a :class:`BatchedDGLGraph` the per-graph
        results are stacked along a new first dimension; a graph in the
        batch with no nodes/edges contributes a row of zeros.
    """
    data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
    values = getattr(graph, data_attr)[feat]

    # A graph that is not batched is reduced in a single call.
    if not isinstance(graph, BatchedDGLGraph):
        return F.max(values, 0)

    # The features of all graphs in the batch are stored back to back, so
    # walk over them with a running offset and reduce one slice per graph.
    per_graph_max = []
    offset = 0
    for count in getattr(graph, batch_num_objs_attr):
        if count != 0:
            per_graph_max.append(F.max(values[offset:offset + count], 0))
            offset += count
        else:
            # Nothing to reduce for this graph: emit a zero placeholder
            # with the shape of a single feature row instead.
            per_graph_max.append(F.zeros(F.shape(values)[1:],
                                         F.dtype(values),
                                         F.context(values)))
    return F.stack(per_graph_max, 0)

def max_nodes(graph, feat):
    """Compute the elementwise maximum of the node field :attr:`feat`
    over all nodes of :attr:`graph`.

    Parameters
    ----------
    graph : DGLGraph or BatchedDGLGraph
        The graph.
    feat : str
        The feature field.

    Returns
    -------
    tensor
        The tensor obtained.

    Notes
    -----
    When :attr:`graph` is a :class:`BatchedDGLGraph`, the result is a
    stacked tensor with an extra first dimension: row ``i`` holds the
    readout of the ``i``-th example in the batch. An example without
    any nodes yields a zero tensor of the same shape at its row.
    """
    node_readout = _max_on(graph, 'nodes', feat)
    return node_readout

def max_edges(graph, feat):
    """Compute the elementwise maximum of the edge field :attr:`feat`
    over all edges of :attr:`graph`.

    Parameters
    ----------
    graph : DGLGraph or BatchedDGLGraph
        The graph.
    feat : str
        The feature field.

    Returns
    -------
    tensor
        The tensor obtained.

    Notes
    -----
    When :attr:`graph` is a :class:`BatchedDGLGraph`, the result is a
    stacked tensor with an extra first dimension: row ``i`` holds the
    readout of the ``i``-th example in the batch. An example without
    any edges yields a zero tensor of the same shape at its row.
    """
    edge_readout = _max_on(graph, 'edges', feat)
    return edge_readout
23 changes: 22 additions & 1 deletion python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict

import dgl
import networkx as nx
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
Expand Down Expand Up @@ -1137,7 +1138,9 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
If the node labels of ``nx_graph`` are not consecutive
integers, its nodes will be relabeled using consecutive integers.
The new node ordering will inherit that of ``sorted(nx_graph.nodes())``
node_attrs : iterable of str, optional
The node attributes that need to be copied.
edge_attrs : iterable of str, optional
Expand Down Expand Up @@ -1165,6 +1168,16 @@ def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
[2., 2., 2., 2.],
[1., 1., 1., 1.]])
"""
# Relabel nodes using consecutive integers
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')
# With to_directed we will get a directed version of the original networkx
# graph, with the original nodes, edges and their attributes preserved.
# This is particularly helpful when we are also converting the edge attributes
# as the reversed edges (u, v) will be created with the same attributes as the
# original edges (v, u).
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()

self.clear()
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
Expand Down Expand Up @@ -1194,7 +1207,12 @@ def _batcher(lst):
# None here serves as placeholder to be replaced by feature with
# corresponding edge id
if has_edge_id:
num_edges = self.number_of_edges()
for _, _, attrs in nx_graph.edges(data=True):
if attrs['id'] >= num_edges:
raise DGLError('Expect the pre-specified edge ids to be'
' smaller than the number of edges --'
' {}, got {}.'.format(num_edges, attrs['id']))
for key in edge_attrs:
attr_dict[key][attrs['id']] = attrs[key]
else:
Expand All @@ -1204,6 +1222,9 @@ def _batcher(lst):
for key in edge_attrs:
attr_dict[key][eid] = attrs[key]
for attr in edge_attrs:
for val in attr_dict[attr]:
if val is None:
raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr])

def from_scipy_sparse_matrix(self, spmat):
Expand Down
5 changes: 4 additions & 1 deletion python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,10 @@ def from_networkx(self, nx_graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
if not nx_graph.is_directed():
# to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed()

num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes)
Expand Down
5 changes: 4 additions & 1 deletion python/dgl/immutable_graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,10 @@ def from_networkx(self, nx_graph):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph))
else:
nx_graph = nx_graph.to_directed()
if not nx_graph.is_directed():
# to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed()

assert nx_graph.number_of_edges() > 0, "can't create an empty immutable graph"

Expand Down
20 changes: 20 additions & 0 deletions tests/compute/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# readonly graph support.
import backend as F
import dgl
import networkx as nx
from dgl import DGLGraph
from collections import defaultdict as ddict

Expand Down Expand Up @@ -242,6 +243,25 @@ def _check_nx_feature(nxg, nf, ef):
edge_feat = F.cat(edge_feat, 0)
assert F.allclose(g.edata['e1'], edge_feat)

# Test converting from a networkx graph whose nodes are
# not labeled with consecutive-integers.
nxg = nx.cycle_graph(5)
nxg.remove_nodes_from([0, 4])
for u in nxg.nodes():
nxg.node[u]['h'] = F.tensor([u])
for u, v, d in nxg.edges(data=True):
d['h'] = F.tensor([u, v])

g = dgl.DGLGraph()
g.from_networkx(nxg, node_attrs=['h'], edge_attrs=['h'])
assert g.number_of_nodes() == 3
assert g.number_of_edges() == 4
assert g.has_edge_between(0, 1)
assert g.has_edge_between(1, 2)
assert F.allclose(g.ndata['h'], F.tensor([[1.], [2.], [3.]]))
assert F.allclose(g.edata['h'], F.tensor([[1., 2.], [1., 2.],
[2., 3.], [2., 3.]]))

def test_batch_send():
g = generate_graph()
def _fmsg(edges):
Expand Down
9 changes: 9 additions & 0 deletions tests/compute/test_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def test_simple_readout():
me1 = F.mean(e1, 0) # edge means
w1 = F.randn((3,))
w2 = F.randn((4,))
max1 = F.max(n1, 0)
max2 = F.max(n2, 0)
maxe1 = F.max(e1, 0)
ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
Expand All @@ -35,20 +38,26 @@ def test_simple_readout():
assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)

g = dgl.batch([g1, g2])
s = dgl.sum_nodes(g, 'x')
m = dgl.mean_nodes(g, 'x')
max_bg = dgl.max_nodes(g, 'x')
assert F.allclose(s, F.stack([s1, s2], 0))
assert F.allclose(m, F.stack([m1, m2], 0))
assert F.allclose(max_bg, F.stack([max1, max2], 0))
ws = dgl.sum_nodes(g, 'x', 'w')
wm = dgl.mean_nodes(g, 'x', 'w')
assert F.allclose(ws, F.stack([ws1, ws2], 0))
assert F.allclose(wm, F.stack([wm1, wm2], 0))
s = dgl.sum_edges(g, 'x')
m = dgl.mean_edges(g, 'x')
max_bg_e = dgl.max_edges(g, 'x')
assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))


if __name__ == '__main__':
Expand Down

0 comments on commit 3a868eb

Please sign in to comment.