[KG] partition a KG. (dmlc#1351)
* partition a KG.

* get tid.

* support builtin

* fix a minor bug.

Co-authored-by: Chao Ma <[email protected]>
zheng-da and aksnzhy authored Mar 15, 2020
1 parent 545cc06 commit eb71c80
Showing 2 changed files with 55 additions and 2 deletions.
3 changes: 1 addition & 2 deletions apps/kg/dataloader/KGDataset.py
@@ -387,8 +387,7 @@ def __init__(self, path, name, files, format, read_triple=True, only_train=False
        super(KGDatasetUDD, self).__init__(os.path.join(path, files[0]),
                                           os.path.join(path, files[1]),
                                           os.path.join(path, files[2]),
-                                          os.path.join(path, None),
-                                          os.path.join(path, None),
+                                          None, None,
                                           format=format,
                                           read_triple=read_triple,
                                           only_train=only_train)
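Side note on the KGDatasetUDD change above: replacing the two os.path.join(path, None) arguments with plain None values looks like the "minor bug" fix mentioned in the commit message, since os.path.join cannot take None components. A minimal sketch of the failure mode on Python 3 (not part of the commit):

import os

# Passing None as a path component is rejected outright on Python 3,
# so the old code would crash before the dataset constructor ever saw it.
try:
    os.path.join('data', None)
except TypeError as err:
    print('os.path.join rejected None:', err)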
54 changes: 54 additions & 0 deletions apps/kg/partition.py
@@ -0,0 +1,54 @@
from dataloader import get_dataset
import scipy as sp
import numpy as np
import argparse
import signal
import dgl
from dgl import backend as F
from dgl.data.utils import load_graphs, save_graphs

def main():
    parser = argparse.ArgumentParser(description='Partition a knowledge graph')
    parser.add_argument('--data_path', type=str, default='data',
                        help='root path of all datasets')
    parser.add_argument('--dataset', type=str, default='FB15k',
                        help='dataset name, under data_path')
    parser.add_argument('--data_files', type=str, default=None, nargs='+',
                        help='a list of data files, e.g. entity relation train valid test')
    parser.add_argument('--format', type=str, default='built_in',
                        help='the format of the dataset, it can be built_in, '\
                             'raw_udd_{htr} and udd_{htr}')
    parser.add_argument('-k', '--num-parts', required=True, type=int,
                        help='The number of partitions')
    args = parser.parse_args()
    num_parts = args.num_parts

    # Load the dataset; only the training triplets are used to build the graph.
    dataset = get_dataset(args.data_path, args.dataset, args.format, args.data_files)

    # Build a homogeneous DGLGraph from the training triplets and store the
    # relation type of every edge in edata['tid'].
    src, etype_id, dst = dataset.train
    coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
                               shape=[dataset.n_entities, dataset.n_entities])
    g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
    g.edata['tid'] = F.tensor(etype_id, F.int64)

    # Partition with METIS; the third argument caches one extra hop of halo
    # nodes/edges around each partition.
    part_dict = dgl.transform.metis_partition(g, num_parts, 1)

    tot_num_inner_edges = 0
    for part_id in part_dict:
        part = part_dict[part_id]

        # 'inner_node'/'inner_edge' mark the nodes/edges owned by this partition,
        # as opposed to the halo copied from neighboring partitions.
        num_inner_nodes = len(np.nonzero(F.asnumpy(part.ndata['inner_node']))[0])
        num_inner_edges = len(np.nonzero(F.asnumpy(part.edata['inner_edge']))[0])
        print('part {} has {} nodes and {} edges. {} nodes and {} edges are inside the partition'.format(
            part_id, part.number_of_nodes(), part.number_of_edges(),
            num_inner_nodes, num_inner_edges))
        tot_num_inner_edges += num_inner_edges

        # Pull node/edge features (including 'tid') from the full graph into the
        # partition, then save it to <data_path>/part_<part_id>.dgl.
        part.copy_from_parent()
        save_graphs(args.data_path + '/part_' + str(part_id) + '.dgl', [part])
    print('there are {} edges in the graph and {} edge cuts for {} partitions.'.format(
        g.number_of_edges(), g.number_of_edges() - tot_num_inner_edges, len(part_dict)))

if __name__ == '__main__':
    main()
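As a follow-up (not part of this commit), a minimal sketch of how one of the saved partitions could be inspected afterwards, assuming the default data path, a part_0.dgl produced by the script above, and a DGL version contemporary to this change:

import numpy as np
from dgl import backend as F
from dgl.data.utils import load_graphs

# load_graphs returns (list of graphs, dict of labels); the script saves one
# partition graph per file.
glist, _ = load_graphs('data/part_0.dgl')
part = glist[0]

# Owned vs. halo elements, plus the relation types copied in by copy_from_parent().
inner_nodes = np.nonzero(F.asnumpy(part.ndata['inner_node']))[0]
inner_edges = np.nonzero(F.asnumpy(part.edata['inner_edge']))[0]
num_rels = len(np.unique(F.asnumpy(part.edata['tid'])))
print('partition 0: {} inner nodes, {} inner edges, {} relation types'.format(
    len(inner_nodes), len(inner_edges), num_rels))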
