
Commit cf19254

[Dist] enable to partition many chunks into less partitions via pipeline (dmlc#4620)

* [Dist] enable to partition many chunks into less partitions via pipeline

* refine

* add meta file for num_parts, add more tests, refine docstring

* remove args.num_parts

* create pydantic class for partition metadata

* refine

* rename json file
Rhett-Ying authored Sep 28, 2022
1 parent 6c1500d commit cf19254
Showing 8 changed files with 189 additions and 72 deletions.
52 changes: 26 additions & 26 deletions tests/tools/test_dist_part.py
@@ -1,23 +1,26 @@
import argparse
import dgl
import json
import numpy as np
import os
import sys
import tempfile
import torch

import pytest, unittest
import dgl
from dgl.data.utils import load_tensors, load_graphs

from chunk_graph import chunk_graph

def test_part_pipeline():
@pytest.mark.parametrize("num_chunks", [1, 2, 3, 4, 8])
@pytest.mark.parametrize("num_parts", [1, 2, 3, 4, 8])
def test_part_pipeline(num_chunks, num_parts):
if num_chunks < num_parts:
# num_parts should be less than or equal to num_chunks
return

# Step0: prepare chunked graph data format

# A synthetic mini MAG240
num_institutions = 20
num_authors = 100
num_papers = 600
num_institutions = 1200
num_authors = 1200
num_papers = 1200

def rand_edges(num_src, num_dst, num_edges):
eids = np.random.choice(num_src * num_dst, num_edges, replace=False)
@@ -26,9 +29,9 @@ def rand_edges(num_src, num_dst, num_edges):

return src, dst

num_cite_edges = 2000
num_write_edges = 1000
num_affiliate_edges = 200
num_cite_edges = 24 * 1000
num_write_edges = 12 * 1000
num_affiliate_edges = 2400

# Structure
data_dict = {
@@ -85,7 +88,6 @@ def rand_edges(num_src, num_dst, num_edges):
np.save(f, write_year)

output_dir = os.path.join(root_dir, 'chunked-data')
num_chunks = 2
chunk_graph(
g,
'mag240m',
@@ -159,23 +161,23 @@ def rand_edges(num_src, num_dst, num_edges):

# Step1: graph partition
in_dir = os.path.join(root_dir, 'chunked-data')
output_dir = os.path.join(root_dir, '2parts')
output_dir = os.path.join(root_dir, 'parted_data')
os.system('python3 tools/partition_algo/random_partition.py '\
'--in_dir {} --out_dir {} --num_partitions {}'.format(
in_dir, output_dir, num_chunks))
in_dir, output_dir, num_parts))
for ntype in ['author', 'institution', 'paper']:
fname = os.path.join(output_dir, '{}.txt'.format(ntype))
with open(fname, 'r') as f:
header = f.readline().rstrip()
assert isinstance(int(header), int)

# Step2: data dispatch
partition_dir = os.path.join(root_dir, '2parts')
partition_dir = os.path.join(root_dir, 'parted_data')
out_dir = os.path.join(root_dir, 'partitioned')
ip_config = os.path.join(root_dir, 'ip_config.txt')
with open(ip_config, 'w') as f:
f.write('127.0.0.1\n')
f.write('127.0.0.2\n')
for i in range(num_parts):
f.write(f'127.0.0.{i + 1}\n')

cmd = 'python3 tools/dispatch_data.py'
cmd += f' --in-dir {in_dir}'
@@ -195,19 +197,19 @@ def rand_edges(num_src, num_dst, num_edges):

all_etypes = ['affiliated_with', 'writes', 'cites', 'rev_writes']
for etype in all_etypes:
assert len(meta_data['edge_map'][etype]) == num_chunks
assert len(meta_data['edge_map'][etype]) == num_parts
assert meta_data['etypes'].keys() == set(all_etypes)
assert meta_data['graph_name'] == 'mag240m'

all_ntypes = ['author', 'institution', 'paper']
for ntype in all_ntypes:
assert len(meta_data['node_map'][ntype]) == num_chunks
assert len(meta_data['node_map'][ntype]) == num_parts
assert meta_data['ntypes'].keys() == set(all_ntypes)
assert meta_data['num_edges'] == 4200
assert meta_data['num_nodes'] == 720
assert meta_data['num_parts'] == num_chunks
assert meta_data['num_edges'] == g.num_edges()
assert meta_data['num_nodes'] == g.num_nodes()
assert meta_data['num_parts'] == num_parts

for i in range(num_chunks):
for i in range(num_parts):
sub_dir = 'part-' + str(i)
assert meta_data[sub_dir]['node_feats'] == 'part{}/node_feat.dgl'.format(i)
assert meta_data[sub_dir]['edge_feats'] == 'part{}/edge_feat.dgl'.format(i)
@@ -251,5 +253,3 @@ def rand_edges(num_src, num_dst, num_edges):
orig_eids = load_tensors(fname)
assert len(orig_eids.keys()) == 4

if __name__ == '__main__':
test_part_pipeline()
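The metadata assertions above boil down to a handful of invariants. A minimal sketch, assuming meta is the already-loaded metadata dict (the test's meta_data) and g is the synthetic input graph; the helper name is illustrative, not part of the test:

def check_partition_metadata(meta, g, num_parts):
    """Sketch of the invariants asserted on the dispatched metadata:
    one ID range per partition for every node/edge type, totals that
    match the input graph, and per-partition feature file paths."""
    assert meta['graph_name'] == 'mag240m'
    assert meta['num_parts'] == num_parts
    assert meta['num_nodes'] == g.num_nodes()
    assert meta['num_edges'] == g.num_edges()
    for ntype, ranges in meta['node_map'].items():
        assert len(ranges) == num_parts, ntype
    for etype, ranges in meta['edge_map'].items():
        assert len(ranges) == num_parts, etype
    for i in range(num_parts):
        part = meta[f'part-{i}']
        assert part['node_feats'] == f'part{i}/node_feat.dgl'
        assert part['edge_feats'] == f'part{i}/edge_feat.dgl'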
20 changes: 17 additions & 3 deletions tools/dispatch_data.py
@@ -4,6 +4,7 @@
import argparse
import logging
import json
from partition_algo.base import load_partition_meta

INSTALL_DIR = os.path.abspath(os.path.join(__file__, '..'))
LAUNCH_SCRIPT = "distgraphlaunch.py"
@@ -33,16 +34,29 @@ def get_launch_cmd(args) -> str:


def submit_jobs(args) -> str:
wrapper_command = os.path.join(INSTALL_DIR, LAUNCH_SCRIPT)

#read the json file and get the remaining arguments here.
schema_path = "metadata.json"
with open(os.path.join(args.in_dir, schema_path)) as schema:
schema_map = json.load(schema)

num_parts = len(schema_map["num_nodes_per_chunk"][0])
graph_name = schema_map["graph_name"]

# retrieve num_parts
num_chunks = len(schema_map["num_nodes_per_chunk"][0])
num_parts = num_chunks
partition_path = os.path.join(args.partitions_dir, "partition_meta.json")
if os.path.isfile(partition_path):
part_meta = load_partition_meta(partition_path)
num_parts = part_meta.num_parts
if num_parts > num_chunks:
raise Exception('Number of partitions should be less than or equal to the number of chunks.')

# verify ip_config
with open(args.ip_config, 'r') as f:
num_ips = len(f.readlines())
assert num_ips == num_parts, \
f'The number of lines[{args.ip_config}] should be equal to num_parts[{num_parts}].'

argslist = ""
argslist += "--world-size {} ".format(num_parts)
argslist += "--partitions-dir {} ".format(os.path.abspath(args.partitions_dir))
8 changes: 4 additions & 4 deletions tools/distpartitioning/convert_partition.py
@@ -104,10 +104,11 @@ def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
"""
#create auxiliary data structures from the schema object
memory_snapshot("CreateDGLObj_Begin", part_id)
ntid_dict, global_nid_ranges = get_idranges(schema[constants.STR_NODE_TYPE],
schema[constants.STR_NUM_NODES_PER_CHUNK])
_, global_nid_ranges = get_idranges(schema[constants.STR_NODE_TYPE],
schema[constants.STR_NUM_NODES_PER_CHUNK])
memory_snapshot("CreateDGLObj_Begin", part_id)

etid_dict, global_eid_ranges = get_idranges(schema[constants.STR_EDGE_TYPE],
_, global_eid_ranges = get_idranges(schema[constants.STR_EDGE_TYPE],
schema[constants.STR_NUM_EDGES_PER_CHUNK])

id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)
@@ -119,7 +120,6 @@ def create_dgl_object(schema, part_id, node_data, edge_data, edgeid_offset,
ntypes_map = {e: i for i, e in enumerate(ntypes)}
etypes = [(key, global_eid_ranges[key][0, 0]) for key in global_eid_ranges]
etypes.sort(key=lambda e: e[1])
etype_offset_np = np.array([e[1] for e in etypes])
etypes = [e[0] for e in etypes]
etypes_map = {e.split(":")[1]: i for i, e in enumerate(etypes)}

9 changes: 5 additions & 4 deletions tools/distpartitioning/data_shuffle.py
@@ -97,7 +97,8 @@ def gen_node_data(rank, world_size, id_lookup, ntid_ntype_map, schema_map):
}

type_nid_dict, global_nid_dict = get_idranges(schema_map[constants.STR_NODE_TYPE],
schema_map[constants.STR_NUM_NODES_PER_CHUNK])
schema_map[constants.STR_NUM_NODES_PER_CHUNK],
num_chunks=world_size)

for ntype_id, ntype_name in ntid_ntype_map.items():
type_start, type_end = type_nid_dict[ntype_name][0][0], type_nid_dict[ntype_name][-1][1]
@@ -290,7 +291,7 @@ def exchange_node_features(rank, world_size, node_feature_tids, ntype_gnid_map,
return own_node_features, own_global_nids

def exchange_graph_data(rank, world_size, node_features, node_feat_tids, edge_data,
id_lookup, ntypes_ntypeid_map, ntypes_gnid_range_map, ntid_ntype_map, schema_map):
id_lookup, ntypes_gnid_range_map, ntid_ntype_map, schema_map):
"""
Wrapper function which is used to shuffle graph data on all the processes.
@@ -556,8 +557,8 @@ def gen_dist_partitions(rank, world_size, params):
#and return the aggregated data
ntypes_gnid_range_map = get_gnid_range_map(node_tids)
node_data, rcvd_node_features, rcvd_global_nids, edge_data = \
exchange_graph_data(rank, world_size, node_features, node_feat_tids, \
edge_data, id_lookup, ntypes_ntypeid_map, ntypes_gnid_range_map, \
exchange_graph_data(rank, world_size, node_features, node_feat_tids,
edge_data, id_lookup, ntypes_gnid_range_map,
ntypeid_ntypes_map, schema_map)
gc.collect()
logging.info(f'[Rank: {rank}] Done with data shuffling...')
58 changes: 36 additions & 22 deletions tools/distpartitioning/dataset_utils.py
@@ -110,18 +110,18 @@ def get_dataset(input_dir, graph_name, rank, world_size, schema_map):
#where key: feature_name, value: dictionary in which keys are "format", "data"
node_feature_tids[ntype_name] = []
for feat_name, feat_data in ntype_feature_data.items():
assert len(feat_data[constants.STR_DATA]) == world_size
assert feat_data[constants.STR_FORMAT][constants.STR_NAME] == constants.STR_NUMPY
my_feat_data_fname = feat_data[constants.STR_DATA][rank] #this will be just the file name
if (os.path.isabs(my_feat_data_fname)):
logging.info(f'Loading numpy from {my_feat_data_fname}')
node_features[ntype_name+'/'+feat_name] = \
torch.from_numpy(np.load(my_feat_data_fname))
else:
numpy_path = os.path.join(input_dir, my_feat_data_fname)
logging.info(f'Loading numpy from {numpy_path}')
node_features[ntype_name+'/'+feat_name] = \
torch.from_numpy(np.load(numpy_path))
num_chunks = len(feat_data[constants.STR_DATA])
read_list = np.array_split(np.arange(num_chunks), world_size)
nfeat = []
for idx in read_list[rank]:
nfeat_file = feat_data[constants.STR_DATA][idx]
if not os.path.isabs(nfeat_file):
nfeat_file = os.path.join(input_dir, nfeat_file)
logging.info(f'Loading node feature[{feat_name}] of ntype[{ntype_name}] from {nfeat_file}')
nfeat.append(np.load(nfeat_file))
nfeat = np.concatenate(nfeat)
node_features[ntype_name + '/' + feat_name] = torch.from_numpy(nfeat)

node_feature_tids[ntype_name].append([feat_name, -1, -1])

@@ -152,7 +152,8 @@ def get_dataset(input_dir, graph_name, rank, world_size, schema_map):

#read my nodes for each node type
node_tids, ntype_gnid_offset = get_idranges(schema_map[constants.STR_NODE_TYPE],
schema_map[constants.STR_NUM_NODES_PER_CHUNK])
schema_map[constants.STR_NUM_NODES_PER_CHUNK],
num_chunks=world_size)
for ntype_name in schema_map[constants.STR_NODE_TYPE]:
if ntype_name in node_feature_tids:
for item in node_feature_tids[ntype_name]:
@@ -211,8 +212,9 @@ def get_dataset(input_dir, graph_name, rank, world_size, schema_map):
#read my edges for each edge type
etype_names = schema_map[constants.STR_EDGE_TYPE]
etype_name_idmap = {e : idx for idx, e in enumerate(etype_names)}
edge_tids, _ = get_idranges(schema_map[constants.STR_EDGE_TYPE],
schema_map[constants.STR_NUM_EDGES_PER_CHUNK])
edge_tids, _ = get_idranges(schema_map[constants.STR_EDGE_TYPE],
schema_map[constants.STR_NUM_EDGES_PER_CHUNK],
num_chunks=world_size)

edge_datadict = {}
edge_data = schema_map[constants.STR_EDGES]
@@ -226,26 +228,38 @@ def get_dataset(input_dir, graph_name, rank, world_size, schema_map):
assert etype_info[constants.STR_FORMAT][constants.STR_NAME] == constants.STR_CSV

edge_info = etype_info[constants.STR_DATA]
assert len(edge_info) == world_size

#edgetype strings are in canonical format, src_node_type:edge_type:dst_node_type
tokens = etype_name.split(":")
assert len(tokens) == 3

src_ntype_name = tokens[0]
rel_name = tokens[1]
dst_ntype_name = tokens[2]

logging.info(f'Reading csv files from {edge_info[rank]}')
data_df = csv.read_csv(edge_info[rank], read_options=pyarrow.csv.ReadOptions(autogenerate_column_names=True),
parse_options=pyarrow.csv.ParseOptions(delimiter=' '))
num_chunks = len(edge_info)
read_list = np.array_split(np.arange(num_chunks), world_size)
src_ids = []
dst_ids = []
for idx in read_list[rank]:
edge_file = edge_info[idx]
if not os.path.isabs(edge_file):
edge_file = os.path.join(input_dir, edge_file)
logging.info(f'Loading edges of etype[{etype_name}] from {edge_file}')
data_df = csv.read_csv(edge_file,
read_options=pyarrow.csv.ReadOptions(autogenerate_column_names=True),
parse_options=pyarrow.csv.ParseOptions(delimiter=' '))
src_ids.append(data_df['f0'].to_numpy())
dst_ids.append(data_df['f1'].to_numpy())
src_ids = np.concatenate(src_ids)
dst_ids = np.concatenate(dst_ids)

#currently these are just type_edge_ids... which will be converted to global ids
edge_datadict[constants.GLOBAL_SRC_ID].append(data_df['f0'].to_numpy() + ntype_gnid_offset[src_ntype_name][0, 0])
edge_datadict[constants.GLOBAL_DST_ID].append(data_df['f1'].to_numpy() + ntype_gnid_offset[dst_ntype_name][0, 0])
edge_datadict[constants.GLOBAL_SRC_ID].append(src_ids + ntype_gnid_offset[src_ntype_name][0, 0])
edge_datadict[constants.GLOBAL_DST_ID].append(dst_ids + ntype_gnid_offset[dst_ntype_name][0, 0])
edge_datadict[constants.GLOBAL_TYPE_EID].append(np.arange(edge_tids[etype_name][rank][0],\
edge_tids[etype_name][rank][1] ,dtype=np.int64))
edge_datadict[constants.ETYPE_ID].append(etype_name_idmap[etype_name] * \
np.ones(shape=(data_df['f0'].to_numpy().shape), dtype=np.int64))
np.ones(shape=(src_ids.shape), dtype=np.int64))

#stitch together to create the final data on the local machine
for col in [constants.GLOBAL_SRC_ID, constants.GLOBAL_DST_ID, constants.GLOBAL_TYPE_EID, constants.ETYPE_ID]:
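The same chunk-grouping pattern now handles both node features and edge files: each rank receives a contiguous slice of chunk indices from np.array_split and concatenates what it reads, so the number of workers may be smaller than the number of chunk files. A standalone sketch of that assignment (the helper name is illustrative):

import numpy as np

def chunks_for_rank(num_chunks, world_size, rank):
    """Illustrative helper: return the chunk-file indices a given rank reads,
    matching the read_list = np.array_split(...) pattern above."""
    read_list = np.array_split(np.arange(num_chunks), world_size)
    return read_list[rank]

# Example: 8 chunk files read by 3 workers.
for rank in range(3):
    print(rank, chunks_for_rank(8, 3, rank).tolist())
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]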
37 changes: 24 additions & 13 deletions tools/distpartitioning/utils.py
@@ -85,15 +85,9 @@ def get_ntype_featnames(ntype_name, schema_map):
list :
a list of feature names for a given node_type
"""
ntype_featdict = schema_map[constants.STR_NODE_DATA]
if (ntype_name in ntype_featdict):
featnames = []
ntype_info = ntype_featdict[ntype_name]
for k, v in ntype_info.items():
featnames.append(k)
return featnames
else:
return []
node_data = schema_map[constants.STR_NODE_DATA]
feats = node_data.get(ntype_name, {})
return [feat for feat in feats]

def get_node_types(schema_map):
"""
@@ -393,7 +387,7 @@ def write_dgl_objects(graph_obj, node_features, edge_features,
orig_eids_file = os.path.join(part_dir, 'orig_eids.dgl')
dgl.data.utils.save_tensors(orig_eids_file, orig_eids)

def get_idranges(names, counts):
def get_idranges(names, counts, num_chunks=None):
"""
Utility function to compute type_id/global_id ranges for both nodes and edges.
@@ -403,6 +397,11 @@ def get_idranges(names, counts):
list of node/edge types as strings
counts : list of lists
each list contains no. of nodes/edges in a given chunk
num_chunks : int, optional
In the distributed partition pipeline, ID ranges are grouped into chunks.
In some scenarios, we'd like to merge the ID ranges into a specific number
of chunks. This parameter indicates the expected number of chunks.
If not specified, no merge is applied.
Returns:
--------
@@ -418,21 +417,33 @@ def get_idranges(names, counts):
gnid_end = gnid_start
tid_dict = {}
gid_dict = {}
orig_num_chunks = 0
for idx, typename in enumerate(names):
type_counts = counts[idx]
tid_start = np.cumsum([0] + type_counts[:-1])
tid_end = np.cumsum(type_counts)
tid_ranges = list(zip(tid_start, tid_end))

type_start = tid_ranges[0][0]
type_end = tid_ranges[-1][1]

gnid_end += tid_ranges[-1][1]

tid_dict[typename] = tid_ranges
gid_dict[typename] = np.array([gnid_start, gnid_end]).reshape([1,2])

gnid_start = gnid_end
orig_num_chunks = len(tid_start)

if num_chunks is None:
return tid_dict, gid_dict

assert num_chunks <= orig_num_chunks, \
'Specified number of chunks should be less than or equal to the original number of ID ranges.'
chunk_list = np.array_split(np.arange(orig_num_chunks), num_chunks)
for typename in tid_dict:
orig_tid_ranges = tid_dict[typename]
tid_ranges = []
for idx in chunk_list:
tid_ranges.append((orig_tid_ranges[idx[0]][0], orig_tid_ranges[idx[-1]][-1]))
tid_dict[typename] = tid_ranges

return tid_dict, gid_dict
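To make the new merge path concrete, a minimal standalone walk-through with made-up counts: four per-chunk ID ranges for one type collapse into two, mirroring the np.array_split grouping added to get_idranges() above.

import numpy as np

# Per-chunk type-ID ranges for a hypothetical node type split into 4 chunks.
type_counts = [100, 100, 100, 100]              # one entry per chunk
tid_start = np.cumsum([0] + type_counts[:-1])
tid_end = np.cumsum(type_counts)
tid_ranges = list(zip(tid_start, tid_end))      # [(0,100), (100,200), (200,300), (300,400)]

num_chunks = 2                                  # merge 4 ranges into 2
chunk_list = np.array_split(np.arange(len(tid_ranges)), num_chunks)
merged = [(int(tid_ranges[idx[0]][0]), int(tid_ranges[idx[-1]][-1]))
          for idx in chunk_list]
print(merged)                                   # [(0, 200), (200, 400)]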
