Skip to content

Commit

Permalink
[Distributed] Reduce memory consumption in graph partitioning (dmlc#1823
Browse files Browse the repository at this point in the history
)

* save mem.

* save mem.

* reduce mem

* fix test

* fix lint

* fix test

* fix.

* fix.

* fix.

* fix.

* fix lint.

* fix backend operator.

* fix tensorflow operators.

* fix.

* revert change in mxnet operator.

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
zheng-da and Ubuntu authored Jul 18, 2020
1 parent 7645c66 commit ea420c0
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 71 deletions.
3 changes: 0 additions & 3 deletions include/dgl/graph_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,6 @@ struct NegSubgraph: public Subgraph {
struct HaloSubgraph: public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes;

/*! \brief Indicate if an edge belongs to the partition. */
IdArray inner_edges;
};

// Define SubgraphRef
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def nonzero_1d(input):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
tmp = np.nonzero(tmp)[0]
return nd.array(tmp, ctx=input.context, dtype=input.dtype)
return nd.array(tmp, ctx=input.context, dtype=tmp.dtype)

def sort_1d(input):
# TODO: this isn't an ideal implementation.
Expand Down
5 changes: 3 additions & 2 deletions python/dgl/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def initialize_context():
tf.zeros(1)

def as_scalar(data):
return data.numpy().asscalar()
data = data.numpy()
return data if np.isscalar(data) else data.asscalar()


def get_preferred_sparse_format():
Expand Down Expand Up @@ -384,7 +385,7 @@ def full_1d(length, fill_value, dtype, ctx):


def nonzero_1d(input):
nonzero_bool = (input != False)
nonzero_bool = tf.cast(input, tf.bool)
return tf.reshape(tf.where(nonzero_bool), (-1, ))


Expand Down
4 changes: 2 additions & 2 deletions python/dgl/distributed/graph_partition_book.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def __init__(self, part_id, num_parts, node_map, edge_map, part_graph):
self._part_id = int(part_id)
self._num_partitions = int(num_parts)
self._nid2partid = F.tensor(node_map)
assert F.dtype(self._nid2partid) in (F.int32, F.int64), \
assert F.dtype(self._nid2partid) == F.int64, \
'the node map must be stored in an integer array'
self._eid2partid = F.tensor(edge_map)
assert F.dtype(self._eid2partid) in (F.int32, F.int64), \
assert F.dtype(self._eid2partid) == F.int64, \
'the edge map must be stored in an integer array'
# Get meta data of the partition book.
self._partition_meta_data = []
Expand Down
56 changes: 26 additions & 30 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,49 +254,36 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
Indicate whether to balance the edges.
'''
if num_parts == 1:
client_parts = {0: g}
parts = {0: g}
node_parts = F.zeros((g.number_of_nodes(),), F.int64, F.cpu())
g.ndata[NID] = F.arange(0, g.number_of_nodes())
g.edata[EID] = F.arange(0, g.number_of_edges())
g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int64, F.cpu())
g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int64, F.cpu())
g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int8, F.cpu())
g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int8, F.cpu())
if reshuffle:
g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
g.edata['orig_id'] = F.arange(0, g.number_of_edges())
elif part_method == 'metis':
node_parts = metis_partition_assignment(g, num_parts, balance_ntypes=balance_ntypes,
balance_edges=balance_edges)
client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
elif part_method == 'random':
node_parts = random_choice(num_parts, g.number_of_nodes())
client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
else:
raise Exception('Unknown partitioning method: ' + part_method)

# Let's calculate edge assignment.
# TODO(zhengda) we should replace int64 with int16. int16 should be sufficient.
start = time.time()
if not reshuffle:
start = time.time()
# We only optimize for reshuffled case. So it's fine to use int64 here.
edge_parts = np.zeros((g.number_of_edges(),), dtype=np.int64) - 1
num_edges = 0
num_nodes = 0
lnodes_list = [] # The node ids of each partition
ledges_list = [] # The edge Ids of each partition
for part_id in range(num_parts):
part = client_parts[part_id]
# To get the edges in the input graph, we should use original node Ids.
data_name = 'orig_id' if reshuffle else NID
local_nodes = F.boolean_mask(part.ndata[data_name], part.ndata['inner_node'])
local_edges = g.in_edges(local_nodes, form='eid')
if not reshuffle:
for part_id in parts:
part = parts[part_id]
# To get the edges in the input graph, we should use original node Ids.
local_edges = F.boolean_mask(part.edata[EID], part.edata['inner_edge'])
edge_parts[F.asnumpy(local_edges)] = part_id
num_edges += len(local_edges)
num_nodes += len(local_nodes)
lnodes_list.append(local_nodes)
ledges_list.append(local_edges)
assert num_edges == g.number_of_edges()
assert num_nodes == g.number_of_nodes()
print('Calculate edge assignment: {:.3f} seconds'.format(time.time() - start))
print('Calculate edge assignment: {:.3f} seconds'.format(time.time() - start))

os.makedirs(out_path, mode=0o775, exist_ok=True)
tot_num_inner_edges = 0
Expand All @@ -314,8 +301,14 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
# With reshuffling, we can ensure that all nodes and edges are reshuffled
# and are in contiguous Id space.
if num_parts > 1:
node_map_val = np.cumsum([len(lnodes) for lnodes in lnodes_list]).tolist()
edge_map_val = np.cumsum([len(ledges) for ledges in ledges_list]).tolist()
node_map_val = [F.as_scalar(F.sum(F.astype(parts[i].ndata['inner_node'], F.int64),
0)) for i in parts]
node_map_val = np.cumsum(node_map_val).tolist()
assert node_map_val[-1] == g.number_of_nodes()
edge_map_val = [F.as_scalar(F.sum(F.astype(parts[i].edata['inner_edge'], F.int64),
0)) for i in parts]
edge_map_val = np.cumsum(edge_map_val).tolist()
assert edge_map_val[-1] == g.number_of_edges()
else:
node_map_val = [g.number_of_nodes()]
edge_map_val = [g.number_of_edges()]
Expand All @@ -330,14 +323,17 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
'node_map': node_map_val,
'edge_map': edge_map_val}
for part_id in range(num_parts):
part = client_parts[part_id]
part = parts[part_id]

# Get the node/edge features of each partition.
node_feats = {}
edge_feats = {}
if num_parts > 1:
local_nodes = lnodes_list[part_id]
local_edges = ledges_list[part_id]
# To get the edges in the input graph, we should use original node Ids.
ndata_name = 'orig_id' if reshuffle else NID
edata_name = 'orig_id' if reshuffle else EID
local_nodes = F.boolean_mask(part.ndata[ndata_name], part.ndata['inner_node'])
local_edges = F.boolean_mask(part.edata[edata_name], part.edata['inner_edge'])
print('part {} has {} nodes and {} edges.'.format(
part_id, part.number_of_nodes(), part.number_of_edges()))
print('{} nodes and {} edges are inside the partition'.format(
Expand Down
6 changes: 1 addition & 5 deletions python/dgl/graph_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ def node_halo_subgraph(self, v, num_hops):
v_array = v.todgltensor()
subg = _CAPI_DGLGetSubgraphWithHalo(self, v_array, num_hops)
inner_nodes = _CAPI_GetHaloSubgraphInnerNodes(subg)
inner_edges = _CAPI_GetHaloSubgraphInnerEdges(subg)
return subg, inner_nodes, inner_edges
return subg, inner_nodes

def node_subgraphs(self, vs_arr):
"""Return the induced node subgraphs.
Expand Down Expand Up @@ -1297,7 +1296,4 @@ def create_graph_index(graph_data, readonly):
def _get_halo_subgraph_inner_node(halo_subg):
return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)

def _get_halo_subgraph_inner_edge(halo_subg):
return _CAPI_GetHaloSubgraphInnerEdges(halo_subg)

_init_api("dgl.graph_index")
25 changes: 18 additions & 7 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,35 +962,46 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
# that all edges in a partition are in the contiguous Id space.
orig_eids = _CAPI_DGLReassignEdges(g._graph, True)
orig_eids = utils.toindex(orig_eids)
g.edata['orig_id'] = orig_eids.tousertensor()
orig_eids = orig_eids.tousertensor()
orig_nids = g.ndata['orig_id']
print('Reshuffle nodes and edges: {:.3f} seconds'.format(time.time() - start))

start = time.time()
subgs = _CAPI_DGLPartitionWithHalo(g._graph, node_part.todgltensor(), extra_cached_hops)
# g is no longer needed. Free memory.
g = None
print('Split the graph: {:.3f} seconds'.format(time.time() - start))
subg_dict = {}
node_part = node_part.tousertensor()
start = time.time()

# This creaets a subgraph from subgraphs returned from the CAPI above.
def create_subgraph(subg, induced_nodes, induced_edges):
subg1 = DGLGraph(graph_data=subg.graph, readonly=True)
subg1.ndata[NID] = induced_nodes.tousertensor()
subg1.edata[EID] = induced_edges.tousertensor()
return subg1

for i, subg in enumerate(subgs):
inner_node = _get_halo_subgraph_inner_node(subg)
subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
subg = create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
subg.ndata['inner_node'] = inner_node
subg.ndata['part_id'] = F.gather_row(node_part, subg.parent_nid)
subg.ndata['part_id'] = F.gather_row(node_part, subg.ndata[NID])
if reshuffle:
subg.ndata['orig_id'] = F.gather_row(g.ndata['orig_id'], subg.ndata[NID])
subg.edata['orig_id'] = F.gather_row(g.edata['orig_id'], subg.edata[EID])
subg.ndata['orig_id'] = F.gather_row(orig_nids, subg.ndata[NID])
subg.edata['orig_id'] = F.gather_row(orig_eids, subg.edata[EID])

if extra_cached_hops >= 1:
inner_edge = F.zeros((subg.number_of_edges(),), F.int64, F.cpu())
inner_edge = F.zeros((subg.number_of_edges(),), F.int8, F.cpu())
inner_nids = F.nonzero_1d(subg.ndata['inner_node'])
# TODO(zhengda) we need to fix utils.toindex() to avoid the dtype cast below.
inner_nids = F.astype(inner_nids, F.int64)
inner_eids = subg.in_edges(inner_nids, form='eid')
inner_edge = F.scatter_row(inner_edge, inner_eids,
F.ones((len(inner_eids),), F.dtype(inner_edge), F.cpu()))
else:
inner_edge = F.ones((subg.number_of_edges(),), F.int64, F.cpu())
inner_edge = F.ones((subg.number_of_edges(),), F.int8, F.cpu())
subg.edata['inner_edge'] = inner_edge
subg_dict[i] = subg
print('Construct subgraphs: {:.3f} seconds'.format(time.time() - start))
Expand Down
13 changes: 1 addition & 12 deletions src/graph/graph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
auto orig_nodes = all_nodes;

std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;
std::vector<int> inner_edges;

// When we deal with in-edges, we need to do two things:
// * find the edges inside the partition and the edges between partitions.
Expand All @@ -436,7 +435,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
inner_edges.push_back(it1 != orig_nodes.end());
}
// We need to expand only if the node hasn't been seen before.
auto it = all_nodes.find(src_data[i]);
Expand All @@ -463,7 +461,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
edge_eid.push_back(eid_data[i]);
inner_edges.push_back(false);
// If we haven't seen this node.
auto it = all_nodes.find(src_data[i]);
if (it == all_nodes.end()) {
Expand Down Expand Up @@ -502,8 +499,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
halo_subg.graph = subg;
halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);
halo_subg.induced_edges = aten::VecToIdArray(edge_eid);
// TODO(zhengda) we need to switch to 8 bytes afterwards.
halo_subg.inner_nodes = aten::VecToIdArray<int>(inner_nodes, 32);
halo_subg.inner_edges = aten::VecToIdArray<int>(inner_edges, 32);
return halo_subg;
}

Expand Down Expand Up @@ -603,14 +600,6 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes")
*rv = gptr->inner_nodes;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
CHECK(gptr) << "The input graph has to be immutable graph";
*rv = gptr->inner_edges;
});

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0];
Expand Down
4 changes: 2 additions & 2 deletions tests/compute/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,13 @@ def test_partition_with_halo():
for i in range(nf.num_layers):
layer_nids1 = F.asnumpy(nf.layer_parent_nid(i))
layer_nids2 = lnf.layer_parent_nid(i)
layer_nids2 = F.asnumpy(F.gather_row(subg.parent_nid, layer_nids2))
layer_nids2 = F.asnumpy(F.gather_row(subg.ndata[dgl.NID], layer_nids2))
assert np.all(np.sort(layer_nids1) == np.sort(layer_nids2))

for i in range(nf.num_blocks):
block_eids1 = F.asnumpy(nf.block_parent_eid(i))
block_eids2 = lnf.block_parent_eid(i)
block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
block_eids2 = F.asnumpy(F.gather_row(subg.edata[dgl.EID], block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2))

subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
Expand Down
8 changes: 1 addition & 7 deletions tests/graph_index/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def create_large_graph_index(num_nodes):
def test_node_subgraph_with_halo():
gi = create_large_graph_index(1000)
nodes = np.random.choice(gi.number_of_nodes(), 100, replace=False)
halo_subg, inner_node, inner_edge = gi.node_halo_subgraph(toindex(nodes), 2)
halo_subg, inner_node = gi.node_halo_subgraph(toindex(nodes), 2)

# Check if edges in the subgraph are in the original graph.
for s, d, e in zip(*halo_subg.graph.edges()):
Expand All @@ -111,12 +111,6 @@ def test_node_subgraph_with_halo():
inner_node_ids = halo_subg.induced_nodes.tonumpy()[inner_node_ids]
assert np.all(np.sort(inner_node_ids) == np.sort(nodes))

# Check if the inner edge labels are correct.
inner_edge = inner_edge.asnumpy()
inner_edge_ids = halo_subg.induced_edges.tonumpy()[inner_edge > 0]
subg = gi.node_subgraph(toindex(nodes))
assert np.all(np.sort(subg.induced_edges.tonumpy()) == np.sort(inner_edge_ids))

if __name__ == '__main__':
test_node_subgraph()
test_node_subgraph_with_halo()
Expand Down

0 comments on commit ea420c0

Please sign in to comment.