Skip to content

Commit

Permalink
Partition a graph with HALO nodes (dmlc#1076)
Browse files Browse the repository at this point in the history
* get subgraph with halo.

* add partition function.

* add comment.

* parallel partition.

* fix a compilation error.

* fix lint error.

* address comments.

* add comments.

* fix for TF.
  • Loading branch information
zheng-da authored Dec 23, 2019
1 parent 0d9acc9 commit e890a89
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/dgl/graph_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,15 @@ struct NegSubgraph: public Subgraph {
IdArray tail_nid;
};

/*! \brief Subgraph data structure for halo subgraph */
struct HaloSubgraph: public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes;

/*! \brief Indicate if an edge belongs to the partition. */
IdArray inner_edges;
};

// Define SubgraphRef
DGL_DEFINE_OBJECT_REF(SubgraphRef, Subgraph);

Expand Down
10 changes: 10 additions & 0 deletions include/dgl/graph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ class GraphOp {
* \return a new immutable bidirected graph.
*/
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);

/*!
* \brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within `num_hops`.
* \param graph The input graph.
* \param nodes The input nodes that form the core of the induced subgraph.
* \param num_hops The number of hops to reach.
* \return the induced subgraph with HALO nodes.
*/
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops);
};

} // namespace dgl
Expand Down
32 changes: 32 additions & 0 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,32 @@ def node_subgraph(self, v):
v_array = v.todgltensor()
return _CAPI_DGLGraphVertexSubgraph(self, v_array)

def node_halo_subgraph(self, v, num_hops):
"""Return an induced subgraph with halo nodes.
Parameters
----------
v : utils.Index
The nodes.
num_hops : int
The number of hops in which a HALO node can be accessed.
Returns
-------
SubgraphIndex
The subgraph index.
DGLTensor
Indicate if a node belongs to a partition.
DGLTensor
Indicate if an edge belongs to a partition.
"""
v_array = v.todgltensor()
subg = _CAPI_DGLGetSubgraphWithHalo(self, v_array, num_hops)
inner_nodes = _CAPI_GetHaloSubgraphInnerNodes(subg)
inner_edges = _CAPI_GetHaloSubgraphInnerEdges(subg)
return subg, inner_nodes, inner_edges

def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
Expand Down Expand Up @@ -1282,4 +1308,10 @@ def create_graph_index(graph_data, multigraph, readonly):
% type(graph_data))
return gidx

def _get_halo_subgraph_inner_node(halo_subg):
return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)

def _get_halo_subgraph_inner_edge(halo_subg):
return _CAPI_GetHaloSubgraphInnerEdges(halo_subg)

_init_api("dgl.graph_index")
42 changes: 42 additions & 0 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from scipy import sparse
from ._ffi.function import _init_api
from .graph import DGLGraph
from .subgraph import DGLSubGraph
from . import backend as F
from .graph_index import from_coo
from .graph_index import _get_halo_subgraph_inner_node
from .graph_index import _get_halo_subgraph_inner_edge
from .batched_graph import BatchedDGLGraph, unbatch
from .convert import graph, bipartite
from . import utils


__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
Expand Down Expand Up @@ -520,4 +524,42 @@ def remove_self_loop(g):
new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
return new_g

def partition_graph_with_halo(g, node_part, num_hops):
'''
This is to partition a graph. Each partition contains HALO nodes
so that we can generate NodeFlow in each partition correctly.
Parameters
------------
g: DGLGraph
The graph to be partitioned
node_part: 1D tensor
Specify which partition a node is assigned to. The length of this tensor
needs to be the same as the number of nodes of the graph. Each element
indicates the partition Id of a node.
num_hops: int
The number of hops a HALO node can be accessed.
Returns
--------
a dict of DGLGraphs
The key is the partition Id and the value is the DGLGraph of the partition.
'''
assert len(node_part) == g.number_of_nodes()
node_part = utils.toindex(node_part)
subgs = _CAPI_DGLPartitionWithHalo(g._graph, node_part.todgltensor(), num_hops)
subg_dict = {}
for i, subg in enumerate(subgs):
inner_node = _get_halo_subgraph_inner_node(subg)
inner_edge = _get_halo_subgraph_inner_edge(subg)
subg = DGLSubGraph(g, subg)
inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
subg.ndata['inner_node'] = inner_node
inner_edge = F.zerocopy_from_dlpack(inner_edge.to_dlpack())
subg.edata['inner_edge'] = inner_edge
subg_dict[i] = subg
return subg_dict

_init_api("dgl.transform")
182 changes: 182 additions & 0 deletions src/graph/graph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* \brief Graph operation implementation
*/
#include <dgl/graph_op.h>
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
Expand Down Expand Up @@ -399,6 +400,187 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph());
}

HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) {
const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data);
const auto id_len = nodes->shape[0];
std::unordered_map<dgl_id_t, bool> all_nodes;
std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops);
for (int64_t i = 0; i < id_len; i++)
all_nodes[nid[i]] = true;
auto orig_nodes = all_nodes;

std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;
std::vector<int> inner_edges;

// When we deal with in-edges, we need to do two things:
// * find the edges inside the partition and the edges between partitions.
// * find the nodes outside the partition that connect the partition.
EdgeArray in_edges = g->InEdges(nodes);
auto src = in_edges.src;
auto dst = in_edges.dst;
auto eid = in_edges.id;
auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
for (int64_t i = 0; i < num_edges; i++) {
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
// We check if the source node is in the original node.
auto it1 = orig_nodes.find(src_data[i]);
inner_edges.push_back(it1 != orig_nodes.end());
// We need to expand only if the node hasn't been seen before.
auto it = all_nodes.find(src_data[i]);
if (it == all_nodes.end()) {
all_nodes[src_data[i]] = false;
outer_nodes[0].push_back(src_data[i]);
}
}

// Now we need to traverse the graph with the in-edges to access nodes
// and edges more hops away.
for (int k = 1; k < num_hops; k++) {
const std::vector<dgl_id_t> &nodes = outer_nodes[k-1];
EdgeArray in_edges = g->InEdges(aten::VecToIdArray(nodes));
auto src = in_edges.src;
auto dst = in_edges.dst;
auto eid = in_edges.id;
auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
for (int64_t i = 0; i < num_edges; i++) {
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
inner_edges.push_back(false);
// If we haven't seen this node.
auto it = all_nodes.find(src_data[i]);
if (it == all_nodes.end()) {
all_nodes[src_data[i]] = false;
outer_nodes[k].push_back(src_data[i]);
}
}
}

// We assign new Ids to the nodes in the subgraph. We ensure that nodes
// with smaller Ids in the original graph will also get smaller Ids in
// the subgraph.

// Move all nodes to a vector.
std::vector<dgl_id_t> old_node_ids;
old_node_ids.reserve(all_nodes.size());
for (auto it = all_nodes.begin(); it != all_nodes.end(); it++) {
old_node_ids.push_back(it->first);
}
std::sort(old_node_ids.begin(), old_node_ids.end());
std::unordered_map<dgl_id_t, dgl_id_t> old2new;
for (size_t i = 0; i < old_node_ids.size(); i++) {
old2new[old_node_ids[i]] = i;
}

num_edges = edge_src.size();
IdArray new_src = IdArray::Empty({num_edges}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray new_dst = IdArray::Empty({num_edges}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);
dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);
for (size_t i = 0; i < edge_src.size(); i++) {
new_src_data[i] = old2new[edge_src[i]];
new_dst_data[i] = old2new[edge_dst[i]];
}

std::vector<int> inner_nodes(old_node_ids.size());
for (size_t i = 0; i < old_node_ids.size(); i++) {
dgl_id_t old_nid = old_node_ids[i];
inner_nodes[i] = all_nodes[old_nid];
}

GraphPtr subg = ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst);
HaloSubgraph halo_subg;
halo_subg.graph = subg;
halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);
halo_subg.induced_edges = aten::VecToIdArray(edge_eid);
halo_subg.inner_nodes = aten::VecToIdArray<int>(inner_nodes, 32);
halo_subg.inner_edges = aten::VecToIdArray<int>(inner_edges, 32);
return halo_subg;
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0];
IdArray node_parts = args[1];
int num_hops = args[2];

const dgl_id_t *part_data = static_cast<dgl_id_t *>(node_parts->data);
int64_t num_nodes = node_parts->shape[0];
std::unordered_map<int, std::vector<dgl_id_t> > part_map;
for (int64_t i = 0; i < num_nodes; i++) {
dgl_id_t part_id = part_data[i];
auto it = part_map.find(part_id);
if (it == part_map.end()) {
std::vector<dgl_id_t> vec;
vec.push_back(i);
part_map[part_id] = vec;
} else {
it->second.push_back(i);
}
}
std::vector<int> part_ids;
std::vector<std::vector<dgl_id_t> > part_nodes;
int max_part_id = 0;
for (auto it = part_map.begin(); it != part_map.end(); it++) {
max_part_id = std::max(it->first, max_part_id);
part_ids.push_back(it->first);
part_nodes.push_back(it->second);
}
auto graph_ptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr());
// When we construct subgraphs, we only access in-edges.
// We need to make sure the in-CSR exists. Otherwise, we'll
// try to construct in-CSR in openmp for loop, which will lead
// to some unexpected results.
graph_ptr->GetInCSR();
std::vector<std::shared_ptr<HaloSubgraph> > subgs(max_part_id + 1);
int num_partitions = part_nodes.size();
#pragma omp parallel for
for (int i = 0; i < num_partitions; i++) {
auto nodes = aten::VecToIdArray(part_nodes[i]);
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
int part_id = part_ids[i];
subgs[part_id] = subg_ptr;
}
List<SubgraphRef> ret_list;
for (size_t i = 0; i < subgs.size(); i++) {
ret_list.push_back(SubgraphRef(subgs[i]));
}
*rv = ret_list;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGetSubgraphWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0];
IdArray nodes = args[1];
int num_hops = args[2];
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
*rv = SubgraphRef(subg_ptr);
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
*rv = gptr->inner_nodes;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
*rv = gptr->inner_edges;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0];
Expand Down
37 changes: 37 additions & 0 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from scipy import sparse as spsp
import networkx as nx
import numpy as np
import dgl
import dgl.function as fn
import backend as F
from dgl.graph_index import from_scipy_sparse_matrix

D = 5

Expand Down Expand Up @@ -194,6 +196,40 @@ def test_remove_self_loop():
assert F.allclose(new_g.edges()[0], F.tensor([0]))
assert F.allclose(new_g.edges()[1], F.tensor([1]))

def create_large_graph_index(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
return from_scipy_sparse_matrix(spm, True)

def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids)
expand_factor = g.number_of_nodes()
sampler = dgl.contrib.sampling.NeighborSampler(g, batch_size,
expand_factor=expand_factor, num_hops=num_layers,
seed_nodes=node_ids)
return next(iter(sampler))

def test_partition():
g = dgl.DGLGraph(create_large_graph_index(1000), readonly=True)
node_part = np.random.choice(4, g.number_of_nodes())
subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2)
for part_id, subg in subgs.items():
node_ids = np.nonzero(node_part == part_id)[0]
lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
nf = get_nodeflow(g, node_ids, 2)
lnf = get_nodeflow(subg, lnode_ids, 2)
for i in range(nf.num_layers):
layer_nids1 = F.asnumpy(nf.layer_parent_nid(i))
layer_nids2 = lnf.layer_parent_nid(i)
layer_nids2 = F.asnumpy(F.gather_row(subg.parent_nid, layer_nids2))
assert np.all(np.sort(layer_nids1) == np.sort(layer_nids2))

for i in range(nf.num_blocks):
block_eids1 = F.asnumpy(nf.block_parent_eid(i))
block_eids2 = lnf.block_parent_eid(i)
block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2))


if __name__ == '__main__':
Expand All @@ -208,3 +244,4 @@ def test_remove_self_loop():
test_laplacian_lambda_max()
test_remove_self_loop()
test_add_self_loop()
test_partition()
Loading

0 comments on commit e890a89

Please sign in to comment.