[Distributed] Fix bugs in partitioning on heterogeneous graphs. (dmlc…#3085)

* fix bugs in partitioning on heterogeneous graphs.

* fix.

* fix.

* fix example.

* fix.

* fix test.

* fix.

* fix.

* fix.

* fix tests.

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Zheng <[email protected]>
3 people authored Jul 2, 2021
1 parent 04dce1e commit 0884d02
Showing 4 changed files with 158 additions and 47 deletions.
26 changes: 21 additions & 5 deletions examples/pytorch/rgcn/experimental/entity_classify_dist.py
@@ -424,6 +424,7 @@ def run(args, device, data):
backward_time = 0
update_time = 0
number_train = 0
number_input = 0

step_time = []
iter_t = []
@@ -441,6 +442,7 @@ def run(args, device, data):
for step, sample_data in enumerate(dataloader):
seeds, blocks = sample_data
number_train += seeds.shape[0]
number_input += np.sum([blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes])
tic_step = time.time()
sample_time += tic_step - start
sample_t.append(tic_step - start)
@@ -484,8 +486,8 @@ def run(args, device, data):
np.sum(backward_t[-args.log_every:]), np.sum(update_t[-args.log_every:])))
start = time.time()

print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #number_train: {}'.format(
g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), number_train))
print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}'.format(
g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), number_train, number_input))
epoch += 1

start = time.time()
@@ -505,9 +507,23 @@ def main(args):
print('rank:', g.rank())

pb = g.get_partition_book()
train_nid = dgl.distributed.node_split(g.nodes['paper'].data['train_mask'], pb, ntype='paper', force_even=True)
val_nid = dgl.distributed.node_split(g.nodes['paper'].data['val_mask'], pb, ntype='paper', force_even=True)
test_nid = dgl.distributed.node_split(g.nodes['paper'].data['test_mask'], pb, ntype='paper', force_even=True)
if 'trainer_id' in g.nodes['paper'].data:
train_nid = dgl.distributed.node_split(g.nodes['paper'].data['train_mask'],
pb, ntype='paper', force_even=True,
node_trainer_ids=g.nodes['paper'].data['trainer_id'])
val_nid = dgl.distributed.node_split(g.nodes['paper'].data['val_mask'],
pb, ntype='paper', force_even=True,
node_trainer_ids=g.nodes['paper'].data['trainer_id'])
test_nid = dgl.distributed.node_split(g.nodes['paper'].data['test_mask'],
pb, ntype='paper', force_even=True,
node_trainer_ids=g.nodes['paper'].data['trainer_id'])
else:
train_nid = dgl.distributed.node_split(g.nodes['paper'].data['train_mask'],
pb, ntype='paper', force_even=True)
val_nid = dgl.distributed.node_split(g.nodes['paper'].data['val_mask'],
pb, ntype='paper', force_even=True)
test_nid = dgl.distributed.node_split(g.nodes['paper'].data['test_mask'],
pb, ntype='paper', force_even=True)
local_nid = pb.partid2nids(pb.partid, 'paper').detach().numpy()
print('part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})'.format(
g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)),
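
The three node_split calls added above share one pattern. Below is a minimal sketch of a helper that factors it out, assuming g is an already-initialized dgl.distributed.DistGraph and pb its partition book; the helper name split_nodes is illustrative and not part of this change.

import dgl

def split_nodes(g, pb, mask_name, ntype='paper'):
    # Split a boolean node mask into the node IDs this trainer should own.
    # When the partitioned graph carries a 'trainer_id' node feature
    # (produced by partitioning with num_trainers_per_machine > 1), pass it
    # to node_split so each node lands on the trainer it was assigned to.
    mask = g.nodes[ntype].data[mask_name]
    if 'trainer_id' in g.nodes[ntype].data:
        return dgl.distributed.node_split(
            mask, pb, ntype=ntype, force_even=True,
            node_trainer_ids=g.nodes[ntype].data['trainer_id'])
    return dgl.distributed.node_split(mask, pb, ntype=ntype, force_even=True)

With such a helper, main() above would reduce to calls like train_nid = split_nodes(g, pb, 'train_mask').
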
10 changes: 8 additions & 2 deletions examples/pytorch/rgcn/experimental/partition_graph.py
@@ -68,6 +68,11 @@ def load_ogb(dataset):
help='turn the graph into an undirected graph.')
argparser.add_argument('--balance_edges', action='store_true',
help='balance the number of edges in each partition.')
argparser.add_argument('--num_trainers_per_machine', type=int, default=1,
help='the number of trainers per machine. The trainer ids are stored\
in the node feature \'trainer_id\'')
argparser.add_argument('--output', type=str, default='data',
help='Output path of partitioned graph.')
args = argparser.parse_args()

start = time.time()
@@ -84,7 +89,8 @@ def load_ogb(dataset):
else:
balance_ntypes = None

dgl.distributed.partition_graph(g, args.dataset, args.num_parts, 'data',
dgl.distributed.partition_graph(g, args.dataset, args.num_parts, args.output,
part_method=args.part_method,
balance_ntypes=balance_ntypes,
balance_edges=args.balance_edges)
balance_edges=args.balance_edges,
num_trainers_per_machine=args.num_trainers_per_machine)
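
A minimal, self-contained sketch of the partition call the updated script now issues, exercising the new num_trainers_per_machine and output arguments; the random toy graph, graph name, and paths are illustrative placeholders rather than values from the example.

import dgl
import torch as th

# A small random graph standing in for the OGB dataset loaded by the script.
src = th.randint(0, 1000, (5000,))
dst = th.randint(0, 1000, (5000,))
g = dgl.graph((src, dst), num_nodes=1000)

# Partition into 2 parts, each served by 4 trainers. The per-node trainer
# assignment is saved with each partition as the 'trainer_id' node feature.
dgl.distributed.partition_graph(
    g, 'toy', num_parts=2, out_path='data',
    part_method='metis',
    balance_edges=True,
    num_trainers_per_machine=4)
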
128 changes: 94 additions & 34 deletions python/dgl/distributed/partition.py
@@ -213,6 +213,84 @@ def load_partition_book(part_config, part_id, graph=None):
return BasicPartitionBook(part_id, num_parts, node_map, edge_map, graph), \
part_metadata['graph_name'], ntypes, etypes

def _get_orig_ids(g, sim_g, reshuffle, orig_nids, orig_eids):
'''Convert/construct the original node IDs and edge IDs.
It handles multiple cases:
* If the graph has been reshuffled and it's a homogeneous graph, we just return
the original node IDs and edge IDs in the inputs.
* If the graph has been reshuffled and it's a heterogeneous graph, we need to
split the original node IDs and edge IDs in the inputs based on the node types
and edge types.
* If the graph is not shuffled, the original node IDs and edge IDs don't change.
Parameters
----------
g : DGLGraph
The input graph for partitioning.
sim_g : DGLGraph
The homogeneous version of the input graph.
reshuffle : bool
Whether the input graph is reshuffled during partitioning.
orig_nids : tensor or None
The original node IDs after the input graph is reshuffled.
orig_eids : tensor or None
The original edge IDs after the input graph is reshuffled.
Returns
-------
tensor or dict of tensors, tensor or dict of tensors
'''
is_hetero = len(g.etypes) > 1 or len(g.ntypes) > 1
if reshuffle and is_hetero:
# Get the type IDs
orig_ntype = F.gather_row(sim_g.ndata[NTYPE], orig_nids)
orig_etype = F.gather_row(sim_g.edata[ETYPE], orig_eids)
# Mapping between shuffled global IDs to original per-type IDs
orig_nids = F.gather_row(sim_g.ndata[NID], orig_nids)
orig_eids = F.gather_row(sim_g.edata[EID], orig_eids)
orig_nids = {ntype: F.boolean_mask(orig_nids, orig_ntype == g.get_ntype_id(ntype)) \
for ntype in g.ntypes}
orig_eids = {etype: F.boolean_mask(orig_eids, orig_etype == g.get_etype_id(etype)) \
for etype in g.etypes}
elif not reshuffle and not is_hetero:
orig_nids = F.arange(0, sim_g.number_of_nodes())
orig_eids = F.arange(0, sim_g.number_of_edges())
elif not reshuffle:
orig_nids = {ntype: F.arange(0, g.number_of_nodes(ntype)) for ntype in g.ntypes}
orig_eids = {etype: F.arange(0, g.number_of_edges(etype)) for etype in g.etypes}
return orig_nids, orig_eids

def _set_trainer_ids(g, sim_g, node_parts):
'''Set the trainer IDs for each node and edge on the input graph.
The trainer IDs will be stored as node data and edge data in the input graph.
Parameters
----------
g : DGLGraph
The input graph for partitioning.
sim_g : DGLGraph
The homogeneous version of the input graph.
node_parts : tensor
The node partition ID for each node in `sim_g`.
'''
if len(g.etypes) == 1:
g.ndata['trainer_id'] = node_parts
# An edge is assigned to a partition based on its destination node.
g.edata['trainer_id'] = F.gather_row(node_parts, g.edges()[1])
else:
for ntype_id, ntype in enumerate(g.ntypes):
type_idx = sim_g.ndata[NTYPE] == ntype_id
orig_nid = F.boolean_mask(sim_g.ndata[NID], type_idx)
trainer_id = F.zeros((len(orig_nid),), F.dtype(node_parts), F.cpu())
F.scatter_row_inplace(trainer_id, orig_nid, F.boolean_mask(node_parts, type_idx))
g.nodes[ntype].data['trainer_id'] = trainer_id
for _, etype, dst_type in g.canonical_etypes:
# An edge is assigned to a partition based on its destination node.
trainer_id = F.gather_row(g.nodes[dst_type].data['trainer_id'], g.edges(etype=etype)[1])
g.edges[etype].data['trainer_id'] = trainer_id

def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
reshuffle=True, balance_ntypes=None, balance_edges=False, return_mapping=False,
num_trainers_per_machine=1):
@@ -420,7 +498,7 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
'''
def get_homogeneous(g, balance_ntypes):
if len(g.etypes) == 1:
sim_g = g
sim_g = to_homogeneous(g)
if isinstance(balance_ntypes, dict):
assert len(balance_ntypes) == 1
bal_ntypes = list(balance_ntypes.values())[0]
@@ -459,7 +537,7 @@ def get_homogeneous(g, balance_ntypes):
"For heterogeneous graphs, reshuffle must be enabled.")

if num_parts == 1:
sim_g = to_homogeneous(g)
sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)
assert num_trainers_per_machine >= 1
if num_trainers_per_machine > 1:
# First partition the whole graph to each trainer and save the trainer ids in
@@ -469,21 +547,19 @@ def get_homogeneous(g, balance_ntypes):
balance_ntypes=balance_ntypes,
balance_edges=balance_edges,
mode='k-way')
g.ndata['trainer_id'] = node_parts
g.edata['trainer_id'] = node_parts[g.edges()[1]]
_set_trainer_ids(g, sim_g, node_parts)

node_parts = F.zeros((sim_g.number_of_nodes(),), F.int64, F.cpu())
parts = {}
parts = {0: sim_g.clone()}
orig_nids = parts[0].ndata[NID] = F.arange(0, sim_g.number_of_nodes())
orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges())
# For one partition, we don't really shuffle nodes and edges. We just need to simulate
# it and set node data and edge data of orig_id.
if reshuffle:
parts[0] = sim_g.clone()
parts[0].ndata[NID] = parts[0].ndata['orig_id'] = F.arange(0, sim_g.number_of_nodes())
parts[0].edata[EID] = parts[0].edata['orig_id'] = F.arange(0, sim_g.number_of_edges())
orig_nids = parts[0].ndata['orig_id']
orig_eids = parts[0].edata['orig_id']
else:
parts[0] = sim_g.clone()
orig_nids = parts[0].ndata[NID] = F.arange(0, sim_g.number_of_nodes())
orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges())
parts[0].ndata['orig_id'] = orig_nids
parts[0].edata['orig_id'] = orig_eids
if return_mapping:
orig_nids, orig_eids = _get_orig_ids(g, sim_g, False, orig_nids, orig_eids)
parts[0].ndata['inner_node'] = F.ones((sim_g.number_of_nodes(),), F.int8, F.cpu())
parts[0].edata['inner_edge'] = F.ones((sim_g.number_of_edges(),), F.int8, F.cpu())
elif part_method in ('metis', 'random'):
@@ -498,11 +574,11 @@ def get_homogeneous(g, balance_ntypes):
balance_ntypes=balance_ntypes,
balance_edges=balance_edges,
mode='k-way')
g.ndata['trainer_id'] = node_parts
_set_trainer_ids(g, sim_g, node_parts)

# And then coalesce the partitions of trainers on the same machine into one
# larger partition.
node_parts = node_parts // num_trainers_per_machine
node_parts = F.floor_div(node_parts, num_trainers_per_machine)
else:
node_parts = metis_partition_assignment(sim_g, num_parts,
balance_ntypes=balance_ntypes,
@@ -511,24 +587,8 @@ def get_homogeneous(g, balance_ntypes):
node_parts = random_choice(num_parts, sim_g.number_of_nodes())
parts, orig_nids, orig_eids = partition_graph_with_halo(sim_g, node_parts, num_hops,
reshuffle=reshuffle)
is_hetero = len(g.etypes) > 1 or len(g.ntypes) > 1
if reshuffle and return_mapping and is_hetero:
# Get the type IDs
orig_ntype = F.gather_row(sim_g.ndata[NTYPE], orig_nids)
orig_etype = F.gather_row(sim_g.edata[ETYPE], orig_eids)
# Mapping between shuffled global IDs to original per-type IDs
orig_nids = F.gather_row(sim_g.ndata[NID], orig_nids)
orig_eids = F.gather_row(sim_g.edata[EID], orig_eids)
orig_nids = {ntype: F.boolean_mask(orig_nids, orig_ntype == g.get_ntype_id(ntype)) \
for ntype in g.ntypes}
orig_eids = {etype: F.boolean_mask(orig_eids, orig_etype == g.get_etype_id(etype)) \
for etype in g.etypes}
elif not reshuffle and not is_hetero and return_mapping:
orig_nids = F.arange(0, sim_g.number_of_nodes())
orig_eids = F.arange(0, sim_g.number_of_edges())
elif not reshuffle and return_mapping:
orig_nids = {ntype: F.arange(0, g.number_of_nodes(ntype)) for ntype in g.ntypes}
orig_eids = {etype: F.arange(0, g.number_of_edges(etype)) for etype in g.etypes}
if return_mapping:
orig_nids, orig_eids = _get_orig_ids(g, sim_g, reshuffle, orig_nids, orig_eids)
else:
raise Exception('Unknown partitioning method: ' + part_method)

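
The coalescing step above, node_parts = F.floor_div(node_parts, num_trainers_per_machine), maps per-trainer assignments onto machine-level partitions. A minimal sketch of that relation with hand-written values (illustrative, not METIS output):

import torch as th

num_trainers_per_machine = 4
trainer_ids = th.tensor([0, 3, 4, 7, 5, 1])            # per-node trainer assignment
node_parts = trainer_ids // num_trainers_per_machine   # machine-level partition per node
print(node_parts)                                       # tensor([0, 0, 1, 1, 1, 0])

Every block of num_trainers_per_machine consecutive trainer IDs collapses onto one partition, which is the invariant the updated tests below verify with F.floor_div.
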
41 changes: 35 additions & 6 deletions tests/distributed/test_partition.py
@@ -159,16 +159,16 @@ def verify_graph_feats(g, gpb, part, node_feats, edge_feats):
edata = F.gather_row(edge_feats[etype + '/' + name], local_eids)
assert np.all(F.asnumpy(edata == true_feats))

def check_hetero_partition(hg, part_method):
def check_hetero_partition(hg, part_method, num_parts=4, num_trainers_per_machine=1):
hg.nodes['n1'].data['labels'] = F.arange(0, hg.number_of_nodes('n1'))
hg.nodes['n1'].data['feats'] = F.tensor(np.random.randn(hg.number_of_nodes('n1'), 10), F.float32)
hg.edges['r1'].data['feats'] = F.tensor(np.random.randn(hg.number_of_edges('r1'), 10), F.float32)
hg.edges['r1'].data['labels'] = F.arange(0, hg.number_of_edges('r1'))
num_parts = 4
num_hops = 1

orig_nids, orig_eids = partition_graph(hg, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
part_method=part_method, reshuffle=True, return_mapping=True)
part_method=part_method, reshuffle=True, return_mapping=True,
num_trainers_per_machine=num_trainers_per_machine)
assert len(orig_nids) == len(hg.ntypes)
assert len(orig_eids) == len(hg.etypes)
for ntype in hg.ntypes:
@@ -180,6 +180,18 @@ def check_hetero_partition(hg, part_method):
shuffled_elabels = []
for i in range(num_parts):
part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
if num_trainers_per_machine > 1:
for ntype in hg.ntypes:
name = ntype + '/trainer_id'
assert name in node_feats
part_ids = F.floor_div(node_feats[name], num_trainers_per_machine)
assert np.all(F.asnumpy(part_ids) == i)

for etype in hg.etypes:
name = etype + '/trainer_id'
assert name in edge_feats
part_ids = F.floor_div(edge_feats[name], num_trainers_per_machine)
assert np.all(F.asnumpy(part_ids) == i)
# Verify the mapping between the reshuffled IDs and the original IDs.
# These are partition-local IDs.
part_src_ids, part_dst_ids = part_g.edges()
@@ -224,22 +236,34 @@ def check_hetero_partition(hg, part_method):
assert np.all(orig_labels == F.asnumpy(hg.nodes['n1'].data['labels']))
assert np.all(orig_elabels == F.asnumpy(hg.edges['r1'].data['labels']))

def check_partition(g, part_method, reshuffle):
def check_partition(g, part_method, reshuffle, num_parts=4, num_trainers_per_machine=1):
g.ndata['labels'] = F.arange(0, g.number_of_nodes())
g.ndata['feats'] = F.tensor(np.random.randn(g.number_of_nodes(), 10), F.float32)
g.edata['feats'] = F.tensor(np.random.randn(g.number_of_edges(), 10), F.float32)
g.update_all(fn.copy_src('feats', 'msg'), fn.sum('msg', 'h'))
g.update_all(fn.copy_edge('feats', 'msg'), fn.sum('msg', 'eh'))
num_parts = 4
num_hops = 2

orig_nids, orig_eids = partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
part_method=part_method, reshuffle=reshuffle, return_mapping=True)
part_method=part_method, reshuffle=reshuffle, return_mapping=True,
num_trainers_per_machine=num_trainers_per_machine)
part_sizes = []
shuffled_labels = []
shuffled_edata = []
for i in range(num_parts):
part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
if num_trainers_per_machine > 1:
for ntype in g.ntypes:
name = ntype + '/trainer_id'
assert name in node_feats
part_ids = F.floor_div(node_feats[name], num_trainers_per_machine)
assert np.all(F.asnumpy(part_ids) == i)

for etype in g.etypes:
name = etype + '/trainer_id'
assert name in edge_feats
part_ids = F.floor_div(edge_feats[name], num_trainers_per_machine)
assert np.all(F.asnumpy(part_ids) == i)

# Check the metadata
assert gpb._num_nodes() == g.number_of_nodes()
@@ -355,13 +379,18 @@ def test_partition():
g = create_random_graph(1000)
check_partition(g, 'metis', False)
check_partition(g, 'metis', True)
check_partition(g, 'metis', True, 4, 8)
check_partition(g, 'metis', True, 1, 8)
check_partition(g, 'random', False)
check_partition(g, 'random', True)

@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_hetero_partition():
hg = create_random_hetero()
check_hetero_partition(hg, 'metis')
check_hetero_partition(hg, 'metis', 1, 8)
check_hetero_partition(hg, 'metis', 4, 8)
check_hetero_partition(hg, 'random')


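
The new assertions all check one invariant: for every node and edge stored in a partition, its trainer ID floor-divided by num_trainers_per_machine equals the partition ID. A minimal sketch of checking this on a partition produced with num_trainers_per_machine > 1, assuming the same /tmp/partition/test.json layout used above; the value 8 is a placeholder.

import numpy as np
from dgl import backend as F
from dgl.distributed import load_partition

num_trainers_per_machine = 8
part_id = 0
part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition(
    '/tmp/partition/test.json', part_id)
for name, feat in node_feats.items():
    if name.endswith('/trainer_id'):
        # Every local node must belong to one of this partition's trainers.
        assert np.all(F.asnumpy(feat) // num_trainers_per_machine == part_id)
for name, feat in edge_feats.items():
    if name.endswith('/trainer_id'):
        assert np.all(F.asnumpy(feat) // num_trainers_per_machine == part_id)
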
