Skip to content

Commit

Permalink
[Feature] Add API to convert graph to bidirected graph (dmlc#598)
Browse files Browse the repository at this point in the history
* to_bidirected

* to_bidirected

* Fix style

* Fix

* Update

* Fix

* Fix

* Update

* Add examples
  • Loading branch information
mufeili authored Jun 10, 2019
1 parent a1513f7 commit fb9dcc5
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/api/python/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ Transform -- Graph Transformation

line_graph
reverse
to_simple_graph
to_bidirected
19 changes: 19 additions & 0 deletions include/dgl/graph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ class GraphOp {
* \return a new immutable simple graph with no multi-edge.
*/
static ImmutableGraph ToSimpleGraph(const GraphInterface* graph);

/*!
* \brief Convert the graph to a mutable bidirected graph.
*
* If the original graph has m edges for i -> j and n edges for
* j -> i, the new graph will have max(m, n) edges for both
* i -> j and j -> i.
*
* \param graph The input graph.
* \return a new mutable bidirected graph.
*/
static Graph ToBidirectedMutableGraph(const GraphInterface* graph);

/*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable.
* \param graph The input graph.
* \return a new immutable bidirected graph.
*/
static ImmutableGraph ToBidirectedImmutableGraph(const GraphInterface* graph);
};

} // namespace dgl
Expand Down
48 changes: 47 additions & 1 deletion python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .graph_index import GraphIndex
from .batched_graph import BatchedDGLGraph

__all__ = ['line_graph', 'reverse', 'to_simple_graph']
__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected']


def line_graph(g, backtracking=True, shared=False):
Expand Down Expand Up @@ -124,4 +124,50 @@ def to_simple_graph(g):
newgidx = GraphIndex(_CAPI_DGLToSimpleGraph(g._graph.handle))
return DGLGraph(newgidx, readonly=True)

def to_bidirected(g, readonly=True):
"""Convert the graph to a bidirected graph.
The function generates a new graph with no node/edge feature.
If g has m edges for i->j and n edges for j->i, then the
returned graph will have max(m, n) edges for both i->j and j->i.
Parameters
----------
g : DGLGraph
The input graph.
readonly : bool, default to be True
Whether the returned bidirected graph is readonly or not.
Returns
-------
DGLGraph
Examples
--------
The following two examples use PyTorch backend, one for non-multi graph
and one for multi-graph.
>>> # non-multi graph
>>> g = dgl.DGLGraph()
>>> g.add_nodes(2)
>>> g.add_edges([0, 0], [0, 1])
>>> bg1 = dgl.to_bidirected(g)
>>> bg1.edges()
(tensor([0, 1, 0]), tensor([0, 0, 1]))
>>> # multi-graph
>>> g.add_edges([0, 1], [1, 0])
>>> g.edges()
(tensor([0, 0, 0, 1]), tensor([0, 1, 1, 0]))
>>> bg2 = dgl.to_bidirected(g)
>>> bg2.edges()
(tensor([0, 1, 1, 0, 0]), tensor([0, 0, 0, 1, 1]))
"""
if readonly:
newgidx = GraphIndex(_CAPI_DGLToBidirectedImmutableGraph(g._graph.handle))
else:
newgidx = GraphIndex(_CAPI_DGLToBidirectedMutableGraph(g._graph.handle))
return DGLGraph(newgidx)

_init_api("dgl.transform")
18 changes: 18 additions & 0 deletions src/graph/graph_apis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,4 +572,22 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
*rv = ret;
});

DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
Graph* bgptr = new Graph();
*bgptr = GraphOp::ToBidirectedMutableGraph(ptr);
GraphHandle bghandle = bgptr;
*rv = bghandle;
});

DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
const GraphInterface *ptr = static_cast<const GraphInterface *>(ghandle);
GraphHandle bghandle = GraphOp::ToBidirectedImmutableGraph(ptr).Reset();
*rv = bghandle;
});

} // namespace dgl
71 changes: 71 additions & 0 deletions src/graph/graph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,4 +315,75 @@ ImmutableGraph GraphOp::ToSimpleGraph(const GraphInterface* graph) {
return ImmutableGraph(csr);
}

Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) {
n_e[u][v]++;
}
}

Graph bg;
bg.AddVertices(g->NumVertices());
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
if (new_n_e > 0) {
IdArray us = NewIdArray(new_n_e);
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
std::fill(us_data, us_data + new_n_e, u);
if (u == v) {
bg.AddEdges(us, us);
} else {
IdArray vs = NewIdArray(new_n_e);
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
std::fill(vs_data, vs_data + new_n_e, v);
bg.AddEdges(us, vs);
bg.AddEdges(vs, us);
}
}
}
}
return bg;
}

ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
std::unordered_map<int, std::unordered_map<int, int>> n_e;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
for (const dgl_id_t v : g->SuccVec(u)) {
n_e[u][v]++;
}
}

std::vector<dgl_id_t> srcs, dsts;
for (dgl_id_t u = 0; u < g->NumVertices(); ++u) {
std::unordered_set<dgl_id_t> hashmap;
std::vector<dgl_id_t> nbrs;
for (const dgl_id_t v : g->PredVec(u)) {
if (!hashmap.count(v)) {
nbrs.push_back(v);
hashmap.insert(v);
}
}
for (const dgl_id_t v : g->SuccVec(u)) {
if (!hashmap.count(v)) {
nbrs.push_back(v);
hashmap.insert(v);
}
}
for (const dgl_id_t v : nbrs) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
for (size_t i = 0; i < new_n_e; ++i) {
srcs.push_back(v);
dsts.push_back(u);
}
}
}

IdArray srcs_array = VecToIdArray(srcs);
IdArray dsts_array = VecToIdArray(dsts);
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()));
return ImmutableGraph(coo);
}

} // namespace dgl
18 changes: 18 additions & 0 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,27 @@ def test_simple_graph():
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == set(elist)

def test_bidirected_graph():
def _test(in_readonly, out_readonly):
elist = [(0, 0), (0, 1), (0, 1), (1, 0), (1, 1), (2, 1), (2, 2), (2, 2)]
g = dgl.DGLGraph(elist, readonly=in_readonly)
elist.append((1, 2))
elist = set(elist)
big = dgl.to_bidirected(g, out_readonly)
assert big.number_of_edges() == 10
src, dst = big.edges()
eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
assert eset == set(elist)

_test(True, True)
_test(True, False)
_test(False, True)
_test(False, False)

if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
test_reverse()
test_reverse_shared_frames()
test_simple_graph()
test_bidirected_graph()

0 comments on commit fb9dcc5

Please sign in to comment.