Skip to content

Commit

Permalink
node/edge filtering (dmlc#80)
Browse files Browse the repository at this point in the history
* node/edge filtering

* changing to tensor operations (what did i do???)

* ???
  • Loading branch information
BarclayII authored Oct 17, 2018
1 parent 16da76c commit b2c1c4f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/dgl/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ def unpack(a, split_size_or_sections=None):

def shape(a):
    """Return the shape of array ``a`` as reported by its ``shape`` attribute."""
    return a.shape

def nonzero_1d(a):
    """Return a 1D array with the indices of the nonzero elements of a 1D array.

    Parameters
    ----------
    a : numpy.ndarray
        A one-dimensional array (e.g. a boolean mask).

    Returns
    -------
    numpy.ndarray
        Indices ``i`` for which ``a[i]`` is nonzero, in increasing order.
    """
    # The helper is defined for vectors only; the previous check demanded
    # ndim == 2, which rejected every valid 1D input.
    assert a.ndim == 1
    # np.nonzero returns one index array per dimension; a 1D input has one.
    return np.nonzero(a)[0]
5 changes: 5 additions & 0 deletions python/dgl/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,8 @@ def zerocopy_from_numpy(np_data):
arr.ctx = get_context(data)
return arr
'''

def nonzero_1d(arr):
    """Return a 1D tensor holding the indices of the nonzero entries of a 1D tensor."""
    assert arr.dim() == 1
    # th.nonzero yields one row per nonzero entry and one column per
    # dimension; for a vector the single column holds the indices.
    index_rows = th.nonzero(arr)
    return index_rows[:, 0]
55 changes: 55 additions & 0 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import

import networkx as nx
import numpy as np

import dgl
from .base import ALL, is_all, __MSG__, __REPR__
Expand Down Expand Up @@ -1285,3 +1286,57 @@ def line_graph(self, backtracking=True, shared=False):
graph_data = self._graph.line_graph(backtracking)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)

def filter_nodes(self, predicate, nodes=ALL):
    """Return a tensor of node IDs that satisfy the given predicate.

    Parameters
    ----------
    predicate : callable
        Receives the node representations of the queried nodes (the
        result of ``get_n_repr(nodes)``, a dict of tensors whose values
        are the node representations concatenated by node ID) and
        returns a boolean tensor with one element per queried node
        indicating which nodes satisfy the predicate.
    nodes : container or tensor
        The nodes to filter on. Defaults to all nodes.

    Returns
    -------
    tensor
        The filtered nodes
    """
    mask = predicate(self.get_n_repr(nodes))

    # When every node is queried, positions in the mask are the node IDs.
    if is_all(nodes):
        return F.nonzero_1d(mask)

    # Otherwise the mask selects from the explicitly supplied IDs.
    candidate_ids = F.Tensor(nodes)
    return candidate_ids[mask]

def filter_edges(self, predicate, edges=ALL):
    """Return a tensor of edge IDs that satisfy the given predicate.

    Parameters
    ----------
    predicate : callable
        Receives the edge representations of the queried edges (the
        result of ``get_e_repr_by_id(edges)``, a dict of tensors whose
        values are the edge representations concatenated by edge ID) and
        returns a boolean tensor with one element per queried edge
        indicating which edges satisfy the predicate.
    edges : container or tensor
        The edges to filter on. Defaults to all edges.

    Returns
    -------
    tensor
        The filtered edges
    """
    mask = predicate(self.get_e_repr_by_id(edges))

    # When every edge is queried, positions in the mask are the edge IDs.
    if is_all(edges):
        return F.nonzero_1d(mask)

    # Otherwise the mask selects from the explicitly supplied IDs.
    candidate_ids = F.Tensor(edges)
    return candidate_ids[mask]
39 changes: 39 additions & 0 deletions tests/pytorch/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch as th
import numpy as np
from dgl.graph import DGLGraph

def test_filter():
    """Check filter_nodes / filter_edges on a 4-node directed cycle.

    Nodes and edges 1 and 3 carry all-ones features, every other row is
    zero, and the predicate keeps rows whose maximum feature is positive.
    """
    graph = DGLGraph()
    graph.add_nodes(4)
    graph.add_edges([0, 1, 2, 3], [1, 2, 3, 0])

    # Rows 1 and 3 are the only ones the predicate should select.
    marked = [1, 3]
    node_feat = th.zeros(4, 5)
    node_feat[marked] = 1
    edge_feat = th.zeros(4, 5)
    edge_feat[marked] = 1

    graph.set_n_repr({'a': node_feat})
    graph.set_e_repr({'a': edge_feat})

    def predicate(r):
        row_max, _ = r['a'].max(1)
        return row_max > 0

    # Filtering over everything returns both marked IDs.
    all_nodes = graph.filter_nodes(predicate)
    assert set(all_nodes.numpy()) == {1, 3}

    # Restricting to IDs 0 and 1 leaves only ID 1.
    some_nodes = graph.filter_nodes(predicate, [0, 1])
    assert set(some_nodes.numpy()) == {1}

    # Same two checks for edges.
    all_edges = graph.filter_edges(predicate)
    assert set(all_edges.numpy()) == {1, 3}

    some_edges = graph.filter_edges(predicate, [0, 1])
    assert set(some_edges.numpy()) == {1}


# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_filter()

0 comments on commit b2c1c4f

Please sign in to comment.