Skip to content

Commit

Permalink
[Feature] Remove nodes/edges. (dmlc#599)
Browse files Browse the repository at this point in the history
* upd

* upd

* reformat

* upd

* upd

* add test

* fix arange

* fix slight bug

* upd

* trigger

* upd docs

* upd

* upd

* upd

* change subgraph to be raw data wrapper

* upd

* fix test
  • Loading branch information
yzh119 authored Jun 8, 2019
1 parent e7389d7 commit baa1623
Show file tree
Hide file tree
Showing 14 changed files with 444 additions and 130 deletions.
9 changes: 9 additions & 0 deletions docs/source/api/python/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ Querying graph structure
DGLGraph.out_degree
DGLGraph.out_degrees

Removing nodes and edges
------------------------

.. autosummary::
:toctree: ../../generated/

DGLGraph.remove_nodes
DGLGraph.remove_edges

Transforming graph
------------------

Expand Down
2 changes: 1 addition & 1 deletion include/dgl/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class Graph: public GraphInterface {
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids) const override;
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;

/*!
* \brief Return a new graph with all the edges reversed.
Expand Down
2 changes: 1 addition & 1 deletion include/dgl/graph_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ class GraphInterface {
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
virtual Subgraph EdgeSubgraph(IdArray eids) const = 0;
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;

/*!
* \brief Return a new graph with all the edges reversed.
Expand Down
6 changes: 3 additions & 3 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class CSR : public GraphInterface {

Subgraph VertexSubgraph(IdArray vids) const override;

Subgraph EdgeSubgraph(IdArray eids) const override {
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
<< " Please use COO graph instead.";
return {};
Expand Down Expand Up @@ -409,7 +409,7 @@ class COO : public GraphInterface {
return {};
}

Subgraph EdgeSubgraph(IdArray eids) const override;
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;

GraphPtr Reverse() const override {
return Transpose();
Expand Down Expand Up @@ -810,7 +810,7 @@ class ImmutableGraph: public GraphInterface {
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids) const override;
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;

/*!
* \brief Return a new graph with all the edges reversed.
Expand Down
75 changes: 69 additions & 6 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,57 @@ def add_edges(self, u, v, data=None):
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)

def remove_nodes(self, vids):
"""Remove multiple nodes.
Parameters
----------
vids: list, tensor
The id of nodes to remove.
"""
if self.is_readonly:
raise DGLError("remove_nodes is not supported by read-only graph.")
induced_nodes = utils.set_diff(utils.toindex(self.nodes()), utils.toindex(vids))
sgi = self._graph.node_subgraph(induced_nodes)

if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes]))
else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)

if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges]))
else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)

self._graph = sgi.graph

def remove_edges(self, eids):
"""Remove multiple edges.
Parameters
----------
eids: list, tensor
The id of edges to remove.
"""
if self.is_readonly:
raise DGLError("remove_edges is not supported by read-only graph.")
induced_edges = utils.set_diff(
utils.toindex(range(self.number_of_edges())), utils.toindex(eids))
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=True)

if isinstance(self._node_frame, FrameRef):
self._node_frame = FrameRef(Frame(self._node_frame[sgi.induced_nodes]))
else:
self._node_frame = FrameRef(self._node_frame, sgi.induced_nodes)

if isinstance(self._edge_frame, FrameRef):
self._edge_frame = FrameRef(Frame(self._edge_frame[sgi.induced_edges]))
else:
self._edge_frame = FrameRef(self._edge_frame, sgi.induced_edges)

self._graph = sgi.graph

def clear(self):
"""Remove all nodes and edges, as well as their features, from the
graph.
Expand Down Expand Up @@ -2813,7 +2864,7 @@ def subgraph(self, nodes):
from . import subgraph
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
return subgraph.DGLSubGraph(self, sgi)

def subgraphs(self, nodes):
"""Return a list of subgraphs, each induced in the corresponding given
Expand Down Expand Up @@ -2841,17 +2892,20 @@ def subgraphs(self, nodes):
from . import subgraph
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
return [subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
for sgi in sgis]
return [subgraph.DGLSubGraph(self, sgi) for sgi in sgis]

def edge_subgraph(self, edges):
def edge_subgraph(self, edges, preserve_nodes=False):
"""Return the subgraph induced on given edges.
Parameters
----------
edges : list, or iterable
An edge ID array to construct subgraph.
All edges must exist in the subgraph.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns
-------
Expand Down Expand Up @@ -2880,6 +2934,15 @@ def edge_subgraph(self, edges):
tensor([0, 1, 4])
>>> SG.parent_eid
tensor([0, 4])
>>> SG = G.edge_subgraph([0, 4], preserve_nodes=True)
>>> SG.nodes()
tensor([0, 1, 2, 3, 4])
>>> SG.edges()
(tensor([0, 4]), tensor([1, 0]))
>>> SG.parent_nid
tensor([0, 1, 2, 3, 4])
>>> SG.parent_eid
tensor([0, 4])
See Also
--------
Expand All @@ -2888,8 +2951,8 @@ def edge_subgraph(self, edges):
"""
from . import subgraph
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi)

def adjacency_matrix_scipy(self, transpose=False, fmt='csr'):
"""Return the scipy adjacency matrix representation of this graph.
Expand Down
68 changes: 20 additions & 48 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,8 @@ def node_subgraph(self, v):
v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array)
induced_edges = utils.toindex(rst(2))
return SubgraphIndex(rst(0), self, v, induced_edges)
gidx = GraphIndex(rst(0))
return SubgraphIndex(gidx, self, v, induced_edges)

def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
Expand All @@ -536,23 +537,28 @@ def node_subgraphs(self, vs_arr):
gis.append(self.node_subgraph(v))
return gis

def edge_subgraph(self, e):
def edge_subgraph(self, e, preserve_nodes=False):
"""Return the induced edge subgraph.
Parameters
----------
e : utils.Index
The edges.
preserve_nodes : bool
Indicates whether to preserve all nodes or not.
If true, keep the nodes which have no edge connected in the subgraph;
If false, all nodes without edge connected to it would be removed.
Returns
-------
SubgraphIndex
The subgraph index.
"""
e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array)
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array, preserve_nodes)
induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e)
gidx = GraphIndex(rst(0))
return SubgraphIndex(gidx, self, induced_nodes, e)

@utils.cached_member(cache='_cache', prefix='scipy_adj')
def adjacency_matrix_scipy(self, transpose, fmt):
Expand Down Expand Up @@ -870,59 +876,25 @@ def asbits(self, bits):
handle = _CAPI_DGLImmutableGraphAsNumBits(self._handle, int(bits))
return GraphIndex(handle)

class SubgraphIndex(GraphIndex):
"""Graph index for subgraph.
class SubgraphIndex(object):
"""Internal subgraph data structure.
Parameters
----------
handle : GraphIndexHandle
The capi handle.
paranet : GraphIndex
graph : GraphIndex
The graph structure of this subgraph.
parent : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def __init__(self, handle, parent, induced_nodes, induced_edges):
super(SubgraphIndex, self).__init__(handle)
self._parent = parent
self._induced_nodes = induced_nodes
self._induced_edges = induced_edges

def add_nodes(self, num):
"""Add nodes. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')

def add_edge(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')

def add_edges(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.')

@property
def induced_nodes(self):
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return self._induced_nodes

@property
def induced_edges(self):
"""Return parent edge ids.
Returns
-------
utils.Index
The parent edge ids.
"""
return self._induced_edges
def __init__(self, graph, parent, induced_nodes, induced_edges):
self.graph = graph
self.parent = parent
self.induced_nodes = induced_nodes
self.induced_edges = induced_edges

def __getstate__(self):
raise NotImplementedError(
Expand Down
25 changes: 11 additions & 14 deletions python/dgl/subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .frame import Frame, FrameRef
from .graph import DGLGraph
from . import utils
from .base import DGLError
from .graph_index import map_to_subgraph_nid

class DGLSubGraph(DGLGraph):
Expand Down Expand Up @@ -32,35 +33,31 @@ class DGLSubGraph(DGLGraph):
----------
parent : DGLGraph
The parent graph
parent_nid : utils.Index
The induced parent node ids in this subgraph.
parent_eid : utils.Index
The induced parent edge ids in this subgraph.
graph_idx : GraphIndex
The graph index.
sgi : SubgraphIndex
Internal subgraph data structure.
shared : bool, optional
Whether the subgraph shares node/edge features with the parent graph.
"""
def __init__(self, parent, parent_nid, parent_eid, graph_idx, shared=False):
super(DGLSubGraph, self).__init__(graph_data=graph_idx,
readonly=graph_idx.is_readonly())
def __init__(self, parent, sgi, shared=False):
super(DGLSubGraph, self).__init__(graph_data=sgi.graph,
readonly=True)
if shared:
raise DGLError('Shared mode is not yet supported.')
self._parent = parent
self._parent_nid = parent_nid
self._parent_eid = parent_eid
self._parent_nid = sgi.induced_nodes
self._parent_eid = sgi.induced_edges

# override APIs
def add_nodes(self, num, data=None):
"""Add nodes. Disabled because BatchedDGLGraph is read-only."""
"""Add nodes. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')

def add_edge(self, u, v, data=None):
"""Add one edge. Disabled because BatchedDGLGraph is read-only."""
"""Add one edge. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')

def add_edges(self, u, v, data=None):
"""Add many edges. Disabled because BatchedDGLGraph is read-only."""
"""Add many edges. Disabled because subgraph is read-only."""
raise DGLError('Readonly graph. Mutation is not allowed.')

@property
Expand Down
23 changes: 23 additions & 0 deletions python/dgl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,29 @@ def zero_index(size):
"""
return Index(F.zeros((size,), dtype=F.int64, ctx=F.cpu()))

def set_diff(ar1, ar2):
"""Find the set difference of two index arrays.
Return the unique values in ar1 that are not in ar2.
Parameters
----------
ar1: utils.Index
Input index array.
ar2: utils.Index
Input comparison index array.
Returns
-------
setdiff:
Array of values in ar1 that are not in ar2.
"""
ar1_np = ar1.tonumpy()
ar2_np = ar2.tonumpy()
setdiff = np.setdiff1d(ar1_np, ar2_np)
setdiff = toindex(setdiff)
return setdiff

class LazyDict(Mapping):
"""A readonly dictionary that does not materialize the storage."""
def __init__(self, fn, keys):
Expand Down
Loading

0 comments on commit baa1623

Please sign in to comment.