Skip to content

Commit

Permalink
[Feature] add metis partitioning to DGL (dmlc#1308)
Browse files Browse the repository at this point in the history
* add metis.

* add test.

* construct partition id.

* link to METIS github repo.

* update metis.

* add a tool for partitioning a graph.

* update metis.

* update.

* update.

* fix metis.

* fix lint

* fix indent.

* another way of building metis.

* disable metis in windows.

* test windows

* fix.

* disable metis for windows properly.

* fix for tensorflow.

* skip test for gpu.

* make graph symmetric

* address comments.

* more comments.

* fix compile

* fix a bug.

* add test.

* change the default #hops of HALO nodes.

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
zheng-da and Ubuntu authored Mar 8, 2020
1 parent a9520f7 commit 0e153c4
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
[submodule "third_party/minigun"]
path = third_party/minigun
url = https://github.com/jermainewang/minigun.git
[submodule "third_party/METIS"]
path = third_party/METIS
url = https://github.com/KarypisLab/METIS.git
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ endif(USE_CUDA)
# include directories
include_directories("include")
include_directories("third_party/dlpack/include")
include_directories("third_party/METIS/include/")
include_directories("third_party/dmlc-core/include")
include_directories("third_party/minigun/minigun")
include_directories("third_party/minigun/third_party/moderngpu/src")
Expand Down Expand Up @@ -85,6 +86,10 @@ if(USE_OPENMP)
endif(OPENMP_FOUND)
endif(USE_OPENMP)

# To compile METIS correctly for DGL.
# The type-width definitions are appended to the global C and C++ flags, so they
# apply to every target configured after this point.
# NOTE(review): presumably IDXTYPEWIDTH=64 makes METIS's idx_t 64-bit, which
# src/graph/metis_partition.cc requires (it checks sizeof(idx_t) == sizeof(dgl_id_t));
# confirm against metis.h.
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DIDXTYPEWIDTH=64 -DREALTYPEWIDTH=32")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIDXTYPEWIDTH=64 -DREALTYPEWIDTH=32")

# configure minigun
add_definitions(-DENABLE_PARTIAL_FRONTIER=0) # disable minigun partial frontier compile
# Source file lists
Expand Down Expand Up @@ -124,6 +129,15 @@ add_subdirectory("third_party/dmlc-core")
list(APPEND DGL_LINKER_LIBS dmlc)
set(GOOGLE_TEST 0) # Turn off dmlc-core test

if(NOT MSVC)
# Compile METIS from the third_party/METIS submodule and link it into dgl.
# This whole section is skipped when building with MSVC.
# GKlib ships inside the METIS tree; its system cmake file is included and its
# directory is added to the include path before libmetis is configured.
set(GKLIB_PATH "third_party/METIS/GKlib")
include(${GKLIB_PATH}/GKlibSystem.cmake)
include_directories(${GKLIB_PATH})
add_subdirectory("third_party/METIS/libmetis/")
# The `metis` library is appended to the libraries linked into the dgl target below.
list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC)

target_link_libraries(dgl ${DGL_LINKER_LIBS} ${DGL_RUNTIME_LINKER_LIBS})

# Installation rules
Expand Down
9 changes: 9 additions & 0 deletions include/dgl/graph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ class GraphOp {
* \return the induced subgraph with HALO nodes.
*/
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops);

/*!
 * \brief Partition a graph with Metis.
 *
 * The partitioning algorithm assigns each vertex to a partition.
 * The implementation (src/graph/metis_partition.cc) requires the input to be an
 * immutable graph and treats it as symmetric: only the in-CSR is passed to Metis.
 * The definition is compiled only when _WIN32 is not defined.
 *
 * \param graph The input graph. It must be an immutable, symmetric graph.
 * \param k The number of partitions.
 * \return The partition assignments of all vertices: an array with one
 *         partition id per vertex of the input graph.
 */
static IdArray MetisPartition(GraphPtr graph, int32_t k);
};

} // namespace dgl
Expand Down
40 changes: 40 additions & 0 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,46 @@ def partition_graph_with_halo(g, node_part, num_hops):
subg_dict[i] = subg
return subg_dict

def metis_partition(g, k, extra_cached_hops=0):
    '''Partition a graph with Metis and build one subgraph per partition.

    Metis assigns every vertex to exactly one partition. The input graph is first
    converted to a bidirected (symmetric) read-only graph, Metis is run on it, and
    the subgraphs are then constructed with `partition_graph_with_halo`.
    Each returned DGLGraph carries a `part_id` node data field that holds, for
    every node in the subgraph, the partition the node was assigned to.

    Parameters
    ----------
    g : DGLGraph
        The graph to be partitioned.
    k : int
        The number of partitions.
    extra_cached_hops : int
        The number of hops of HALO nodes to include in each partition. It is
        forwarded as `num_hops` to `partition_graph_with_halo`.

    Returns
    -------
    dict of DGLGraphs, or None
        The key is the partition Id and the value is the DGLGraph of the partition.
        None is returned when the C API yields an empty assignment array.
    '''
    # METIS works only on symmetric graphs.
    sym_g = to_bidirected(g, readonly=True)
    assignment = _CAPI_DGLMetisPartition(sym_g._graph, k)
    # An empty assignment array means partitioning did not produce a result.
    if len(assignment) == 0:
        return None

    assignment = utils.toindex(assignment)
    subgraphs = partition_graph_with_halo(sym_g, assignment, extra_cached_hops)
    assignment_tensor = assignment.tousertensor()
    # Attach the partition assignment of every node (looked up by its parent id).
    for subg in subgraphs.values():
        subg.ndata['part_id'] = F.gather_row(assignment_tensor, subg.parent_nid)
    return subgraphs

def compact_graphs(graphs, always_preserve=None):
"""Given a list of graphs with the same set of nodes, find and eliminate the common
isolated nodes across all graphs.
Expand Down
13 changes: 8 additions & 5 deletions src/graph/graph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,17 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
for (int64_t i = 0; i < num_edges; i++) {
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
// We check if the source node is in the original node.
auto it1 = orig_nodes.find(src_data[i]);
inner_edges.push_back(it1 != orig_nodes.end());
if (it1 != orig_nodes.end() || num_hops > 0) {
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
inner_edges.push_back(it1 != orig_nodes.end());
}
// We need to expand only if the node hasn't been seen before.
auto it = all_nodes.find(src_data[i]);
if (it == all_nodes.end()) {
if (it == all_nodes.end() && num_hops > 0) {
all_nodes[src_data[i]] = false;
outer_nodes[0].push_back(src_data[i]);
}
Expand Down Expand Up @@ -653,4 +655,5 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
});


} // namespace dgl
86 changes: 86 additions & 0 deletions src/graph/metis_partition.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*!
* Copyright (c) 2020 by Contributors
* \file graph/metis_partition.cc
* \brief Call Metis partitioning
*/

#include <metis.h>
#include <dgl/graph_op.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"

using namespace dgl::runtime;

namespace dgl {

#if !defined(_WIN32)

/*!
 * \brief Partition a graph into k parts with METIS_PartGraphKway.
 *
 * The graph must be an ImmutableGraph and is treated as symmetric: only its
 * in-CSR is handed to Metis as the adjacency structure. No vertex weights,
 * vertex sizes, edge weights, target partition weights, imbalance tolerances
 * or options are supplied (all are passed as null pointers).
 *
 * \param g The input graph.
 * \param k The number of partitions.
 * \return An id array with one partition id per vertex on success. If Metis
 *         reports an error, a fatal error is logged; should that return, an
 *         array of 0 elements is returned to indicate the error.
 */
IdArray GraphOp::MetisPartition(GraphPtr g, int k) {
  // The index type of Metis needs to be compatible with DGL index type.
  CHECK_EQ(sizeof(idx_t), sizeof(dgl_id_t));
  ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);
  CHECK(ig) << "The input graph must be an immutable graph.";
  // This is a symmetric graph, so in-csr and out-csr are the same.
  const auto mat = ig->GetInCSR()->ToCSRMatrix();

  idx_t nvtxs = g->NumVertices();
  idx_t ncon = 1;        // # balancing constraints.
  idx_t *xadj = static_cast<idx_t*>(mat.indptr->data);
  idx_t *adjncy = static_cast<idx_t*>(mat.indices->data);
  idx_t nparts = k;
  IdArray part_arr = aten::NewIdArray(nvtxs);
  idx_t objval = 0;
  idx_t *part = static_cast<idx_t*>(part_arr->data);
  const int ret = METIS_PartGraphKway(
      &nvtxs,    // The number of vertices
      &ncon,     // The number of balancing constraints.
      xadj,      // indptr
      adjncy,    // indices
      nullptr,   // the weights of the vertices
      nullptr,   // The size of the vertices for computing
                 // the total communication volume
      nullptr,   // The weights of the edges
      &nparts,   // The number of partitions.
      nullptr,   // the desired weight for each partition and constraint
      nullptr,   // the allowed load imbalance tolerance
      nullptr,   // the array of options
      &objval,   // the edge-cut or the total communication volume of
                 // the partitioning solution
      part);
  switch (ret) {
    case METIS_OK:
      // Report the edge cut only when Metis succeeded; objval is not
      // meaningful otherwise.
      LOG(INFO) << "Partition a graph with " << g->NumVertices()
                << " nodes and " << g->NumEdges()
                << " edges into " << k
                << " parts and get " << objval << " edge cuts";
      return part_arr;
    // Each error case ends with an explicit break so that the cases stay
    // independent even if LOG(FATAL) were to return.
    case METIS_ERROR_INPUT:
      LOG(FATAL) << "Error in Metis partitioning: input error";
      break;
    case METIS_ERROR_MEMORY:
      LOG(FATAL) << "Error in Metis partitioning: cannot allocate memory";
      break;
    default:
      LOG(FATAL) << "Error in Metis partitioning: other errors";
      break;
  }
  // return an array of 0 elements to indicate the error.
  return aten::NullArray();
}

// C API entry point: args[0] is the graph, args[1] is the number of partitions.
// The result of GraphOp::MetisPartition is stored in the return value.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef graph_ref = args[0];
    const int num_parts = args[1];
    *rv = GraphOp::MetisPartition(graph_ref.sptr(), num_parts);
  });

#else

// Windows fallback for the C API entry point: logs a warning and returns a
// null array instead of calling Metis.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    // The arguments are still read with the same types as the real
    // implementation, but they are not used.
    GraphRef graph_ref = args[0];
    int num_parts = args[1];
    LOG(WARNING) << "DGL doesn't support METIS partitioning in Windows";
    *rv = aten::NullArray();
  });

#endif // !defined(_WIN32)

} // namespace dgl
31 changes: 30 additions & 1 deletion tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def get_nodeflow(g, node_ids, num_layers):
seed_nodes=node_ids)
return next(iter(sampler))

def test_partition():
def test_partition_with_halo():
g = dgl.DGLGraph(create_large_graph_index(1000), readonly=True)
node_part = np.random.choice(4, g.number_of_nodes())
subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2)
Expand All @@ -232,6 +232,35 @@ def test_partition():
block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2))

@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_metis_partition():
# Exercise dgl.transform.metis_partition with 4 partitions on a read-only graph,
# first with 0 hops of HALO nodes and then with 1 hop.
g = dgl.DGLGraph(create_large_graph_index(1000), readonly=True)
# Case 1: no HALO nodes.
subgs = dgl.transform.metis_partition(g, 4, 0)
num_inner_nodes = 0
num_inner_edges = 0
# metis_partition returns None when the C API yields an empty assignment,
# in which case the per-partition checks are skipped.
if subgs is not None:
for part_id, subg in subgs.items():
# Every node and edge must be flagged as inner, and every node's part_id
# must equal the key of the partition it is stored under.
assert np.all(F.asnumpy(subg.ndata['inner_node']) == 1)
assert np.all(F.asnumpy(subg.edata['inner_edge']) == 1)
assert np.all(F.asnumpy(subg.ndata['part_id']) == part_id)
num_inner_nodes += subg.number_of_nodes()
num_inner_edges += subg.number_of_edges()
# The partitions together must contain as many nodes as g.
assert num_inner_nodes == g.number_of_nodes()
# Printed for information only: edge count of g minus the edges inside partitions.
print(g.number_of_edges() - num_inner_edges)

# Case 2: one hop of HALO nodes.
subgs = dgl.transform.metis_partition(g, 4, 1)
num_inner_nodes = 0
num_inner_edges = 0
if subgs is not None:
for part_id, subg in subgs.items():
# Local ids of the nodes/edges whose inner flag is non-zero.
lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
ledge_ids = np.nonzero(F.asnumpy(subg.edata['inner_edge']))[0]
num_inner_nodes += len(lnode_ids)
num_inner_edges += len(ledge_ids)
# The number of nodes assigned to this partition must equal the number of inner nodes.
assert np.sum(F.asnumpy(subg.ndata['part_id']) == part_id) == len(lnode_ids)
# The inner nodes of all partitions together must add up to the node count of g.
assert num_inner_nodes == g.number_of_nodes()
# Printed for information only: edge count of g minus the inner edges of all partitions.
print(g.number_of_edges() - num_inner_edges)

@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_in_subgraph():
g1 = dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)], 'user', 'follow')
Expand Down
1 change: 1 addition & 0 deletions third_party/METIS
Submodule METIS added at ffaa70
56 changes: 56 additions & 0 deletions tools/partition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import argparse
import signal
import dgl
from dgl import backend as F
from dgl.data.utils import load_graphs, save_graphs

def main():
    """Partition a graph stored in the DGL format and save one file per partition.

    The first graph in the file given by ``--data`` is partitioned into
    ``--num-parts`` parts with the chosen ``--method`` (``random`` or ``metis``),
    including ``--num-hops`` hops of HALO nodes. Every partition is saved as
    ``<output>/<part_id>.dgl`` and a summary of nodes, edges and edge cuts is printed.

    Raises
    ------
    Exception
        If the partitioning method is unknown, or if METIS partitioning
        returned no result.
    """
    # Imported locally so the block does not depend on a new file-level import.
    import os

    parser = argparse.ArgumentParser(description='Partition a graph')
    parser.add_argument('--data', required=True, type=str,
                        help='The file path of the input graph in the DGL format.')
    parser.add_argument('-k', '--num-parts', required=True, type=int,
                        help='The number of partitions')
    parser.add_argument('--num-hops', type=int, default=1,
                        help='The number of hops of HALO nodes we include in a partition')
    parser.add_argument('-m', '--method', required=True, type=str,
                        help='The partitioning method: random, metis')
    parser.add_argument('-o', '--output', required=True, type=str,
                        help='The output directory of the partitioned results')
    args = parser.parse_args()
    num_parts = args.num_parts
    num_hops = args.num_hops
    output = args.output

    glist, _ = load_graphs(args.data)
    g = glist[0]

    if args.method == 'metis':
        part_dict = dgl.transform.metis_partition(g, num_parts, num_hops)
        # metis_partition returns None when the C API yields an empty assignment;
        # fail with a clear message instead of a TypeError in the loop below.
        if part_dict is None:
            raise Exception('METIS partitioning did not return any partition')
    elif args.method == 'random':
        node_parts = np.random.choice(num_parts, g.number_of_nodes())
        part_dict = dgl.transform.partition_graph_with_halo(g, node_parts, num_hops)
    else:
        raise Exception('unknown partitioning method: ' + args.method)

    # Make sure the output directory exists before any partition is written.
    os.makedirs(output, exist_ok=True)

    # NOTE(review): metis_partition symmetrizes the graph before partitioning, so
    # for the metis method the edge counts below may refer to the symmetrized
    # graph, not to g -- confirm that the edge-cut summary is what is intended.
    tot_num_inner_edges = 0
    for part_id, part in part_dict.items():
        num_inner_nodes = len(np.nonzero(F.asnumpy(part.ndata['inner_node']))[0])
        num_inner_edges = len(np.nonzero(F.asnumpy(part.edata['inner_edge']))[0])
        print('part {} has {} nodes and {} edges. {} nodes and {} edges are inside the partition'.format(
            part_id, part.number_of_nodes(), part.number_of_edges(),
            num_inner_nodes, num_inner_edges))
        tot_num_inner_edges += num_inner_edges

        # TODO I duplicate some node features.
        part.copy_from_parent()
        save_graphs(os.path.join(output, str(part_id) + '.dgl'), [part])
    print('there are {} edges in the graph and {} edge cuts for {} partitions.'.format(
        g.number_of_edges(), g.number_of_edges() - tot_num_inner_edges, len(part_dict)))

if __name__ == '__main__':
    main()

0 comments on commit 0e153c4

Please sign in to comment.