Skip to content

Commit

Permalink
Fix the performance issue of graph partitioning in new DGLGraph (dmlc…
Browse files Browse the repository at this point in the history
…#1934)

* fix perf.

* fix.

* accelerate metis.

* fix lint.

* use gklib.

* fix perf.

* fix.

* update metis.

* update launch script

* handle synchronized API.

* fix.

* fix example.

* fix dataloader.

* temp fix.

* temp fix omp.

* distinguish roles.

* initialize iterator of DistDataloader correctly.

* check the correctness of launch script.

* move feature copy to sampler.

* measure mem/network copy time.

* remove

* Revert "measure mem/network copy time."

This reverts commit 86cefdc.

* fix.

* fix

* fix.

* fix cmake.

* disable metis in windows.

* disable metis tests in windows.

* remove test for multigraph.

* fix test.

* fix.

* fix cmake.

* fix.

* revert.

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
3 people authored Aug 10, 2020
1 parent 4097fa2 commit 729ff2e
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 72 deletions.
12 changes: 7 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,13 @@ list(APPEND DGL_LINKER_LIBS dmlc)
set(GOOGLE_TEST 0) # Turn off dmlc-core test

# Compile METIS
set(GKLIB_PATH "${CMAKE_SOURCE_DIR}/third_party/METIS/GKlib")
include(${GKLIB_PATH}/GKlibSystem.cmake)
include_directories(${GKLIB_PATH})
add_subdirectory("third_party/METIS/libmetis/")
list(APPEND DGL_LINKER_LIBS metis)
if(NOT MSVC)
set(GKLIB_PATH "${CMAKE_SOURCE_DIR}/third_party/METIS/GKlib")
include(${GKLIB_PATH}/GKlibSystem.cmake)
include_directories(${GKLIB_PATH})
add_subdirectory("third_party/METIS/libmetis/")
list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC)

# support PARALLEL_ALGORITHMS
if (LIBCXX_ENABLE_PARALLEL_ALGORITHMS)
Expand Down
10 changes: 0 additions & 10 deletions include/dgl/graph_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,6 @@ class GraphOp {
* \return the graph with reordered node Ids
*/
static GraphPtr ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order);

/*!
* \brief Partition a graph with Metis.
* The partitioning algorithm assigns each vertex to a partition.
* \param graph The input graph
* \param k The number of partitions.
* \param vwgt the vertex weight array.
* \return The partition assignments of all vertices.
*/
static IdArray MetisPartition(GraphPtr graph, int32_t k, NDArray vwgt);
};

} // namespace dgl
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False):
'''
# METIS works only on symmetric graphs.
# The METIS runs on the symmetric graph to generate the node assignment to partitions.
from .transform import to_bidirected # avoid cyclic import
start = time.time()
sym_g = to_bidirected(g)
sym_gidx = _CAPI_DGLMakeSymmetric_Hetero(g._graph)
sym_g = DGLHeteroGraph(gidx=sym_gidx)
print('Convert a graph into a bidirected graph: {:.3f} seconds'.format(
time.time() - start))
vwgt = []
Expand Down
4 changes: 0 additions & 4 deletions src/graph/gk_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ namespace dgl {

#if !defined(_WIN32)

namespace {

/*!
* Convert DGL CSR to GKLib CSR.
* GKLib CSR actually stores a CSR object and a CSC object of a graph.
Expand Down Expand Up @@ -97,8 +95,6 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
return aten::CSRMatrix(gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
}

} // namespace

#endif // !defined(_WIN32)

GraphPtr GraphOp::ToBidirectedSimpleImmutableGraph(ImmutableGraphPtr ig) {
Expand Down
15 changes: 13 additions & 2 deletions src/graph/metis_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ using namespace dgl::runtime;

namespace dgl {

IdArray GraphOp::MetisPartition(GraphPtr g, int k, NDArray vwgt_arr) {
#if !defined(_WIN32)

IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr) {
// The index type of Metis needs to be compatible with DGL index type.
CHECK_EQ(sizeof(idx_t), sizeof(dgl_id_t));
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g);
Expand Down Expand Up @@ -44,6 +46,9 @@ IdArray GraphOp::MetisPartition(GraphPtr g, int k, NDArray vwgt_arr) {
idx_t options[METIS_NOPTIONS];
METIS_SetDefaultOptions(options);
options[METIS_OPTION_ONDISK] = 1;
options[METIS_OPTION_NITER] = 1;
options[METIS_OPTION_NIPARTS] = 1;
options[METIS_OPTION_DROPEDGES] = 1;

int ret = METIS_PartGraphKway(&nvtxs, // The number of vertices
&ncon, // The number of balancing constraints.
Expand Down Expand Up @@ -78,12 +83,18 @@ IdArray GraphOp::MetisPartition(GraphPtr g, int k, NDArray vwgt_arr) {
return aten::NullArray();
}

#endif // !defined(_WIN32)

DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int k = args[1];
NDArray vwgt = args[2];
*rv = GraphOp::MetisPartition(g.sptr(), k, vwgt);
#if !defined(_WIN32)
*rv = MetisPartition(g.sptr(), k, vwgt);
#else
LOG(FATAL) << "Metis partition does not support Windows.";
#endif // !defined(_WIN32)
});

} // namespace dgl
13 changes: 12 additions & 1 deletion src/graph/transform/metis_partition_hetero.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ namespace dgl {

namespace transform {

#if !defined(_WIN32)

IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr) {
// The index type of Metis needs to be compatible with DGL index type.
CHECK_EQ(sizeof(idx_t), sizeof(int64_t))
<< "Metis only supports int64 graph for now";
// This is a symmetric graph, so in-csr and out-csr are the same.
const auto mat = g->GetCSRMatrix(0);
const auto mat = g->GetCSCMatrix(0);
// const auto mat = g->GetInCSR()->ToCSRMatrix();

idx_t nvtxs = g->NumVertices(0);
Expand All @@ -48,6 +50,9 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr) {
idx_t options[METIS_NOPTIONS];
METIS_SetDefaultOptions(options);
options[METIS_OPTION_ONDISK] = 1;
options[METIS_OPTION_NITER] = 1;
options[METIS_OPTION_NIPARTS] = 1;
options[METIS_OPTION_DROPEDGES] = 1;

int ret = METIS_PartGraphKway(
&nvtxs, // The number of vertices
Expand Down Expand Up @@ -82,6 +87,8 @@ IdArray MetisPartition(UnitGraphPtr g, int k, NDArray vwgt_arr) {
return aten::NullArray();
}

#endif // !defined(_WIN32)

DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0];
Expand All @@ -92,7 +99,11 @@ DGL_REGISTER_GLOBAL("partition._CAPI_DGLMetisPartition_Hetero")
auto ugptr = hgptr->relation_graphs()[0];
int k = args[1];
NDArray vwgt = args[2];
#if !defined(_WIN32)
*rv = MetisPartition(ugptr, k, vwgt);
#else
LOG(FATAL) << "Metis partition does not support Windows.";
#endif // !defined(_WIN32)
});
} // namespace transform
} // namespace dgl
55 changes: 51 additions & 4 deletions src/graph/transform/partition_hetero.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
* \brief Call Metis partitioning
*/

#if !defined(_WIN32)
#include <GKlib.h>
#endif // !defined(_WIN32)

#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>

Expand All @@ -14,6 +18,11 @@ using namespace dgl::runtime;

namespace dgl {

#if !defined(_WIN32)
gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row);
aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row);
#endif // !defined(_WIN32)

namespace transform {

class HaloHeteroSubgraph : public HeteroSubgraph {
Expand All @@ -22,11 +31,21 @@ class HaloHeteroSubgraph : public HeteroSubgraph {
};

HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) {
auto format = ug->GetCreatedFormats();
// We only need to reorder one of the graph structure.
// Only to in_csr for now
auto csrmat = ug->GetCSRMatrix(0);
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat);
if (format & csc_code) {
auto cscmat = ug->GetCSCMatrix(0);
auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order);
return UnitGraph::CreateFromCSC(ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats());
} else if (format & csr_code) {
auto csrmat = ug->GetCSRMatrix(0);
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats());
} else {
auto coomat = ug->GetCOOMatrix(0);
auto new_coomat = aten::COOReorder(coomat, new_order, new_order);
return UnitGraph::CreateFromCOO(ug->NumVertexTypes(), new_coomat, ug->GetAllowedFormats());
}
}

HaloHeteroSubgraph GetSubgraphWithHalo(std::shared_ptr<HeteroGraph> hg,
Expand Down Expand Up @@ -256,5 +275,33 @@ DGL_REGISTER_GLOBAL("partition._CAPI_GetHaloSubgraphInnerNodes_Hetero")
*rv = gptr->inner_nodes[0];
});


DGL_REGISTER_GLOBAL("partition._CAPI_DGLMakeSymmetric_Hetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef g = args[0];
auto hgptr = std::dynamic_pointer_cast<HeteroGraph>(g.sptr());
CHECK(hgptr) << "Invalid HeteroGraph object";
CHECK_EQ(hgptr->relation_graphs().size(), 1)
<< "Metis partition only supports homogeneous graph";
auto ugptr = hgptr->relation_graphs()[0];

#if !defined(_WIN32)
// TODO(zhengda) should we get whatever CSR exists in the graph.
gk_csr_t *gk_csr = Convert2GKCsr(ugptr->GetCSCMatrix(0), true);
gk_csr_t *sym_gk_csr = gk_csr_MakeSymmetric(gk_csr, GK_CSR_SYM_SUM);
auto mat = Convert2DGLCsr(sym_gk_csr, true);
gk_csr_Free(&gk_csr);
gk_csr_Free(&sym_gk_csr);

auto new_ugptr = UnitGraph::CreateFromCSC(ugptr->NumVertexTypes(), mat,
ugptr->GetAllowedFormats());
std::vector<HeteroGraphPtr> rel_graphs = {new_ugptr};
*rv = HeteroGraphRef(std::make_shared<HeteroGraph>(
hgptr->meta_graph(), rel_graphs, hgptr->NumVerticesPerType()));
#else
LOG(FATAL) << "The fast version of making symmetric graph is not supported in Windows.";
#endif // !defined(_WIN32)
});

} // namespace transform
} // namespace dgl
51 changes: 8 additions & 43 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from scipy import sparse as spsp
import networkx as nx
import numpy as np
import os
import dgl
import dgl.function as fn
import backend as F
Expand Down Expand Up @@ -275,15 +276,6 @@ def test_to_bidirected():
big = dgl.to_bidirected(g, copy_ndata=True)
assert F.array_equal(g.nodes['user'].data['h'], big.nodes['user'].data['h'])

# test multigraph
g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])))
raise_error = False
try:
big = dgl.to_bidirected(g)
except:
raise_error = True
assert raise_error

def test_add_reverse_edges():
# homogeneous graph
g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])))
Expand Down Expand Up @@ -490,12 +482,12 @@ def test_laplacian_lambda_max():
assert l_max < 2 + eps
'''

def create_large_graph_index(num_nodes):
def create_large_graph(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))

return from_scipy_sparse_matrix(spm, True)
return dgl.graph(spm)

def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids)
Expand All @@ -507,46 +499,19 @@ def get_nodeflow(g, node_ids, num_layers):

@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_partition_with_halo():
g = dgl.DGLGraphStale(create_large_graph_index(1000), readonly=True)
g = create_large_graph(1000)
node_part = np.random.choice(4, g.number_of_nodes())
subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2)
for part_id, subg in subgs.items():
node_ids = np.nonzero(node_part == part_id)[0]
lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
nf = get_nodeflow(g, node_ids, 2)
lnf = get_nodeflow(subg, lnode_ids, 2)
for i in range(nf.num_layers):
layer_nids1 = F.asnumpy(nf.layer_parent_nid(i))
layer_nids2 = lnf.layer_parent_nid(i)
layer_nids2 = F.asnumpy(F.gather_row(subg.ndata[dgl.NID], layer_nids2))
assert np.all(np.sort(layer_nids1) == np.sort(layer_nids2))

for i in range(nf.num_blocks):
block_eids1 = F.asnumpy(nf.block_parent_eid(i))
block_eids2 = lnf.block_parent_eid(i)
block_eids2 = F.asnumpy(F.gather_row(subg.edata[dgl.EID], block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2))

subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
for part_id, subg in subgs.items():
node_ids = np.nonzero(node_part == part_id)[0]
lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
assert np.all(np.sort(F.asnumpy(subg.ndata['orig_id'])[lnode_ids]) == node_ids)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_metis_partition():
# TODO(zhengda) Metis fails to partition a small graph.
g = dgl.DGLGraphStale(create_large_graph_index(1000), readonly=True)
check_metis_partition(g, 0)
check_metis_partition(g, 1)
check_metis_partition(g, 2)
check_metis_partition_with_constraint(g)

@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_hetero_metis_partition():
# TODO(zhengda) Metis fails to partition a small graph.
g = dgl.DGLGraphStale(create_large_graph_index(1000), readonly=True)
g = dgl.as_heterograph(g)
g = create_large_graph(1000)
check_metis_partition(g, 0)
check_metis_partition(g, 1)
check_metis_partition(g, 2)
Expand Down Expand Up @@ -635,10 +600,10 @@ def check_metis_partition(g, extra_hops):

@unittest.skipIf(F._default_context_str == 'gpu', reason="It doesn't support GPU")
def test_reorder_nodes():
g = dgl.DGLGraphStale(create_large_graph_index(1000), readonly=True)
g = create_large_graph(1000)
new_nids = np.random.permutation(g.number_of_nodes())
# TODO(zhengda) we need to test both CSR and COO.
new_g = dgl.transform.reorder_nodes(g, new_nids)
new_g = dgl.partition.reorder_nodes(g, new_nids)
new_in_deg = new_g.in_degrees()
new_out_deg = new_g.out_degrees()
in_deg = g.in_degrees()
Expand Down
3 changes: 3 additions & 0 deletions tests/distributed/test_dist_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def check_server_client(shared_mem, num_servers, num_clients):

print('clients have terminated')

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_server_client():
os.environ['DGL_DIST_MODE'] = 'distributed'
Expand Down Expand Up @@ -256,6 +257,7 @@ def test_standalone():
part_config='/tmp/dist_graph/{}.json'.format(graph_name))
check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split():
#prepare_dist()
g = create_random_graph(10000)
Expand Down Expand Up @@ -310,6 +312,7 @@ def set_roles(num_clients):
edges5 = F.cat([edges3, edges4], 0)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_split_even():
#prepare_dist(1)
g = create_random_graph(10000)
Expand Down
2 changes: 2 additions & 0 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ def check_partition(g, part_method, reshuffle):
assert F.dtype(eid2pid) in (F.int32, F.int64)
assert np.all(F.asnumpy(eid2pid) == edge_map)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_partition():
g = create_random_graph(10000)
check_partition(g, 'metis', True)
check_partition(g, 'metis', False)
check_partition(g, 'random', True)
check_partition(g, 'random', False)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
def test_hetero_partition():
g = create_random_graph(10000)
check_partition(g, 'metis', True)
Expand Down

0 comments on commit 729ff2e

Please sign in to comment.