[KG] reduce memory consumption. (dmlc#902)
* reduce memory consumption.

* fix a bug.

* fix a bug.

* fix.
zheng-da authored Dec 29, 2019
1 parent 655d756 commit f818415
Showing 2 changed files with 65 additions and 47 deletions.
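Note: the commit stores each split as three int64 numpy arrays (heads, rels, tails) instead of a Python list of (h, r, t) tuples, and reuses a single training graph across partitions. A rough, illustrative sketch of why the array layout is cheaper (my own example, not code from this commit):

    # Compare the memory held by n triples as a list of tuples vs. three int64 arrays.
    import sys
    import numpy as np

    n = 1000000
    triples = [(i, i % 1345, i + 1) for i in range(n)]
    # list object + tuple objects (the int objects themselves are not even counted here)
    list_bytes = sys.getsizeof(triples) + sum(sys.getsizeof(t) for t in triples)

    heads = np.array([t[0] for t in triples], dtype=np.int64)
    rels = np.array([t[1] for t in triples], dtype=np.int64)
    tails = np.array([t[2] for t in triples], dtype=np.int64)
    array_bytes = heads.nbytes + rels.nbytes + tails.nbytes  # 3 * 8 bytes per triple

    print('list of tuples: ~{:.0f} MB'.format(list_bytes / 1e6))
    print('numpy arrays:   ~{:.0f} MB'.format(array_bytes / 1e6))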
29 changes: 22 additions & 7 deletions apps/kg/dataloader/KGDataset.py
@@ -1,4 +1,5 @@
 import os
+import numpy as np
 
 def _download_and_extract(url, path, filename):
     import shutil, zipfile
@@ -71,13 +72,20 @@ def __init__(self, path, name):
 
     def read_triple(self, path, mode):
         # mode: train/valid/test
-        triples = []
+        heads = []
+        tails = []
+        rels = []
         with open(os.path.join(path, '{}.txt'.format(mode))) as f:
             for line in f:
                 h, r, t = line.strip().split('\t')
-                triples.append((self.entity2id[h], self.relation2id[r], self.entity2id[t]))
+                heads.append(self.entity2id[h])
+                rels.append(self.relation2id[r])
+                tails.append(self.entity2id[t])
+        heads = np.array(heads, dtype=np.int64)
+        tails = np.array(tails, dtype=np.int64)
+        rels = np.array(rels, dtype=np.int64)
 
-        return triples
+        return (heads, rels, tails)
 
 
 class KGDataset2:
@@ -115,16 +123,23 @@ def __init__(self, path, name):
         self.test = self.read_triple(self.path, 'test')
 
     def read_triple(self, path, mode, skip_first_line=False):
-        triples = []
+        heads = []
+        tails = []
+        rels = []
         print('Reading {} triples....'.format(mode))
         with open(os.path.join(path, '{}.txt'.format(mode))) as f:
             if skip_first_line:
                 _ = f.readline()
             for line in f:
                 h, t, r = line.strip().split('\t')
-                triples.append((int(h), int(r), int(t)))
-        print('Finished. Read {} {} triples.'.format(len(triples), mode))
-        return triples
+                heads.append(int(h))
+                tails.append(int(t))
+                rels.append(int(r))
+        heads = np.array(heads, dtype=np.int64)
+        tails = np.array(tails, dtype=np.int64)
+        rels = np.array(rels, dtype=np.int64)
+        print('Finished. Read {} {} triples.'.format(len(heads), mode))
+        return (heads, rels, tails)
 
 
 def get_dataset(data_path, data_name, format_str):
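With read_triple now returning a (heads, rels, tails) tuple of int64 arrays, a subset of edges can be selected by plain numpy indexing with an edge-ID array, which is the form the updated partition functions in sampler.py return below. A toy sketch of that pattern (hypothetical data, not from the commit):

    import numpy as np

    # hypothetical triples in the new (heads, rels, tails) layout
    heads = np.array([0, 1, 2, 3, 4, 5], dtype=np.int64)
    rels = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64)
    tails = np.array([1, 2, 3, 4, 5, 0], dtype=np.int64)

    part = np.array([0, 3, 5], dtype=np.int64)  # edge IDs owned by one partition
    # selecting a partition is fancy indexing; no per-triple Python objects are created
    print(heads[part], rels[part], tails[part])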
83 changes: 43 additions & 40 deletions apps/kg/dataloader/sampler.py
@@ -11,9 +11,9 @@
 # This partitions a list of edges based on relations to make sure
 # each partition has roughly the same number of edges and relations.
 def RelationPartition(edges, n):
-    print('relation partition {} edges into {} parts'.format(len(edges), n))
-    rel = np.array([r for h, r, t in edges])
-    uniq, cnts = np.unique(rel, return_counts=True)
+    heads, rels, tails = edges
+    print('relation partition {} edges into {} parts'.format(len(heads), n))
+    uniq, cnts = np.unique(rels, return_counts=True)
     idx = np.flip(np.argsort(cnts))
     cnts = cnts[idx]
     uniq = uniq[idx]
@@ -30,35 +30,39 @@ def RelationPartition(edges, n):
         rel_cnts[idx] += 1
     for i, edge_cnt in enumerate(edge_cnts):
         print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
 
     parts = []
     for _ in range(n):
         parts.append([])
-    for h, r, t in edges:
-        idx = rel_dict[r]
-        parts[idx].append((h, r, t))
+    # let's store the edge index to each partition first.
+    for i, r in enumerate(rels):
+        part_idx = rel_dict[r]
+        parts[part_idx].append(i)
+    for i, part in enumerate(parts):
+        parts[i] = np.array(part, dtype=np.int64)
     return parts
 
 def RandomPartition(edges, n):
-    print('random partition {} edges into {} parts'.format(len(edges), n))
-    idx = np.random.permutation(len(edges))
+    heads, rels, tails = edges
+    print('random partition {} edges into {} parts'.format(len(heads), n))
+    idx = np.random.permutation(len(heads))
     part_size = int(math.ceil(len(idx) / n))
     parts = []
     for i in range(n):
         start = part_size * i
         end = min(part_size * (i + 1), len(idx))
-        parts.append([edges[i] for i in idx[start:end]])
+        parts.append(idx[start:end])
         print('part {} has {} edges'.format(i, len(parts[-1])))
     return parts
 
-def ConstructGraph(edges, n_entities, i, args):
-    pickle_name = 'graph_train_{}.pickle'.format(i)
+def ConstructGraph(edges, n_entities, args):
+    pickle_name = 'graph_train.pickle'
     if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
         with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
             g = pickle.load(graph_file)
         print('Load pickled graph.')
     else:
-        src = [t[0] for t in edges]
-        etype_id = [t[1] for t in edges]
-        dst = [t[2] for t in edges]
+        src, etype_id, dst = edges
         coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[n_entities, n_entities])
         g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
         g.ndata['id'] = F.arange(0, g.number_of_nodes())
@@ -71,26 +75,23 @@ def ConstructGraph(edges, n_entities, i, args):
 class TrainDataset(object):
     def __init__(self, dataset, args, weighting=False, ranks=64):
         triples = dataset.train
-        print('|Train|:', len(triples))
+        self.g = ConstructGraph(triples, dataset.n_entities, args)
+        num_train = len(triples[0])
+        print('|Train|:', num_train)
         if ranks > 1 and args.rel_part:
-            triples_list = RelationPartition(triples, ranks)
+            self.edge_parts = RelationPartition(triples, ranks)
         elif ranks > 1:
-            triples_list = RandomPartition(triples, ranks)
+            self.edge_parts = RandomPartition(triples, ranks)
         else:
-            triples_list = [triples]
-        self.graphs = []
-        for i, triples in enumerate(triples_list):
-            g = ConstructGraph(triples, dataset.n_entities, i, args)
-            if weighting:
-                # TODO: weight to be added
-                count = self.count_freq(triples)
-                subsampling_weight = np.vectorize(
-                    lambda h, r, t: np.sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
-                )
-                weight = subsampling_weight(src, etype_id, dst)
-                g.edata['weight'] = F.zerocopy_from_numpy(weight)
-                # to be added
-            self.graphs.append(g)
+            self.edge_parts = [np.arange(num_train)]
+        if weighting:
+            # TODO: weight to be added
+            count = self.count_freq(triples)
+            subsampling_weight = np.vectorize(
+                lambda h, r, t: np.sqrt(1 / (count[(h, r)] + count[(t, -r - 1)]))
+            )
+            weight = subsampling_weight(src, etype_id, dst)
+            self.g.edata['weight'] = F.zerocopy_from_numpy(weight)
 
     def count_freq(self, triples, start=4):
         count = {}
@@ -109,7 +110,8 @@ def count_freq(self, triples, start=4):
     def create_sampler(self, batch_size, neg_sample_size=2, mode='head', num_workers=5,
                        shuffle=True, exclude_positive=False, rank=0):
         EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
-        return EdgeSampler(self.graphs[rank],
+        return EdgeSampler(self.g,
+                           seed_edges=F.tensor(self.edge_parts[rank]),
                            batch_size=batch_size,
                            neg_sample_size=neg_sample_size,
                            negative_mode=mode,
@@ -118,6 +120,7 @@ def create_sampler(self, batch_size, neg_sample_size=2, mode='head', num_workers=5,
                            exclude_positive=exclude_positive,
                            return_false_neg=False)
 
+
 class PBGNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
     def __init__(self, subg, num_chunks, chunk_size,
                  neg_sample_size, neg_head):
@@ -203,17 +206,17 @@ def reset(self):
 
 class EvalDataset(object):
     def __init__(self, dataset, args):
-        triples = dataset.train + dataset.valid + dataset.test
         pickle_name = 'graph_all.pickle'
         if args.pickle_graph and os.path.exists(os.path.join(args.data_path, args.dataset, pickle_name)):
             with open(os.path.join(args.data_path, args.dataset, pickle_name), 'rb') as graph_file:
                 g = pickle.load(graph_file)
             print('Load pickled graph.')
         else:
-            src = [t[0] for t in triples]
-            etype_id = [t[1] for t in triples]
-            dst = [t[2] for t in triples]
-            coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)), shape=[dataset.n_entities, dataset.n_entities])
+            src = np.concatenate((dataset.train[0], dataset.valid[0], dataset.test[0]))
+            etype_id = np.concatenate((dataset.train[1], dataset.valid[1], dataset.test[1]))
+            dst = np.concatenate((dataset.train[2], dataset.valid[2], dataset.test[2]))
+            coo = sp.sparse.coo_matrix((np.ones(len(src)), (src, dst)),
+                                       shape=[dataset.n_entities, dataset.n_entities])
             g = dgl.DGLGraph(coo, readonly=True, sort_csr=True)
             g.ndata['id'] = F.arange(0, g.number_of_nodes())
             g.edata['id'] = F.tensor(etype_id, F.int64)
@@ -222,9 +225,9 @@ def __init__(self, dataset, args):
                 pickle.dump(g, graph_file)
         self.g = g
 
-        self.num_train = len(dataset.train)
-        self.num_valid = len(dataset.valid)
-        self.num_test = len(dataset.test)
+        self.num_train = len(dataset.train[0])
+        self.num_valid = len(dataset.valid[0])
+        self.num_test = len(dataset.test[0])
 
         if args.eval_percent < 1:
             self.valid = np.random.randint(0, self.num_valid,
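With these changes TrainDataset builds one training graph and hands each rank an array of edge IDs (passed to the sampler as seed_edges) instead of constructing a separate DGLGraph per partition, so the per-rank overhead is an index array rather than a graph copy. A simplified, illustrative mirror of the random partitioning over edge IDs (not the commit's exact code):

    import math
    import numpy as np

    def random_edge_partition(num_edges, n):
        # shuffle edge IDs and cut them into n roughly equal chunks
        idx = np.random.permutation(num_edges)
        part_size = int(math.ceil(num_edges / n))
        return [idx[i * part_size:min((i + 1) * part_size, num_edges)] for i in range(n)]

    edge_parts = random_edge_partition(10, 4)
    for rank, part in enumerate(edge_parts):
        print('rank', rank, 'seed edges', part)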
