Skip to content

Commit

Permalink
[KG] PBG's way of constructing negative edges (dmlc#1159)
Browse files Browse the repository at this point in the history
* attach positive.

* add neg_deg_sample.

* add comment.

* add neg_deg_sample for eval.

* change the edge sampler.

* rename edge sampler in KG.

* allow specifying chunk size and negative sample size separately.

* fix bugs in KG.

* add check in sampler.

* add more checks.

* fix

* add comment.

* add comments.
  • Loading branch information
zheng-da authored Jan 5, 2020
1 parent 1de192f commit 1022d5d
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 83 deletions.
49 changes: 29 additions & 20 deletions apps/kg/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,25 @@ def count_freq(self, triples, start=4):
count[(tail, -rel - 1)] += 1
return count

def create_sampler(self, batch_size, neg_sample_size=2, mode='head', num_workers=5,
def create_sampler(self, batch_size, neg_sample_size=2, neg_chunk_size=None, mode='head', num_workers=5,
shuffle=True, exclude_positive=False, rank=0):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
return EdgeSampler(self.g,
seed_edges=F.tensor(self.edge_parts[rank]),
batch_size=batch_size,
neg_sample_size=neg_sample_size,
chunk_size=neg_chunk_size,
negative_mode=mode,
num_workers=num_workers,
shuffle=shuffle,
exclude_positive=exclude_positive,
return_false_neg=False)


class PBGNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
class ChunkNegEdgeSubgraph(dgl.subgraph.DGLSubGraph):
def __init__(self, subg, num_chunks, chunk_size,
neg_sample_size, neg_head):
super(PBGNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi)
super(ChunkNegEdgeSubgraph, self).__init__(subg._parent, subg.sgi)
self.subg = subg
self.num_chunks = num_chunks
self.chunk_size = chunk_size
Expand All @@ -140,7 +141,11 @@ def tail_nid(self):
return self.subg.tail_nid


def create_neg_subgraph(pos_g, neg_g, is_pbg, neg_head, num_nodes):
# KG models need to know the number of chunks, the chunk size and negative sample size
# of a negative subgraph to perform the computation more efficiently.
# This function tries to infer all of this information about the negative subgraph
# and creates a wrapper class that contains all of it.
def create_neg_subgraph(pos_g, neg_g, chunk_size, is_chunked, neg_head, num_nodes):
assert neg_g.number_of_edges() % pos_g.number_of_edges() == 0
neg_sample_size = int(neg_g.number_of_edges() / pos_g.number_of_edges())
# We use all nodes to create negative edges. Regardless of the sampling algorithm,
Expand All @@ -149,30 +154,32 @@ def create_neg_subgraph(pos_g, neg_g, is_pbg, neg_head, num_nodes):
or (not neg_head and len(neg_g.tail_nid) == num_nodes):
num_chunks = 1
chunk_size = pos_g.number_of_edges()
elif is_pbg:
if pos_g.number_of_edges() < neg_sample_size:
elif is_chunked:
if pos_g.number_of_edges() < chunk_size:
num_chunks = 1
chunk_size = pos_g.number_of_edges()
else:
# This is probably the last batch. Let's ignore it.
if pos_g.number_of_edges() % neg_sample_size > 0:
if pos_g.number_of_edges() % chunk_size > 0:
return None
num_chunks = int(pos_g.number_of_edges()/ neg_sample_size)
chunk_size = neg_sample_size
num_chunks = int(pos_g.number_of_edges()/ chunk_size)
assert num_chunks * chunk_size == pos_g.number_of_edges()
assert num_chunks * neg_sample_size * chunk_size == neg_g.number_of_edges()
else:
num_chunks = pos_g.number_of_edges()
chunk_size = 1
return PBGNegEdgeSubgraph(neg_g, num_chunks, chunk_size,
neg_sample_size, neg_head)
return ChunkNegEdgeSubgraph(neg_g, num_chunks, chunk_size,
neg_sample_size, neg_head)

class EvalSampler(object):
def __init__(self, g, edges, batch_size, neg_sample_size, mode, num_workers,
def __init__(self, g, edges, batch_size, neg_sample_size, neg_chunk_size, mode, num_workers,
filter_false_neg):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
self.sampler = EdgeSampler(g,
batch_size=batch_size,
seed_edges=edges,
neg_sample_size=neg_sample_size,
chunk_size=neg_chunk_size,
negative_mode=mode,
num_workers=num_workers,
shuffle=False,
Expand All @@ -184,6 +191,7 @@ def __init__(self, g, edges, batch_size, neg_sample_size, mode, num_workers,
self.neg_head = 'head' in mode
self.g = g
self.filter_false_neg = filter_false_neg
self.neg_chunk_size = neg_chunk_size

def __iter__(self):
return self
Expand All @@ -193,7 +201,7 @@ def __next__(self):
pos_g, neg_g = next(self.sampler_iter)
if self.filter_false_neg:
neg_positive = neg_g.edata['false_neg']
neg_g = create_neg_subgraph(pos_g, neg_g, 'PBG' in self.mode,
neg_g = create_neg_subgraph(pos_g, neg_g, self.neg_chunk_size, 'chunk' in self.mode,
self.neg_head, self.g.number_of_nodes())
if neg_g is not None:
break
Expand Down Expand Up @@ -280,22 +288,22 @@ def check(self, eval_type):
np.testing.assert_equal(F.asnumpy(dst_id), orig_dst)
np.testing.assert_equal(F.asnumpy(etype), orig_etype)

def create_sampler(self, eval_type, batch_size, neg_sample_size,
def create_sampler(self, eval_type, batch_size, neg_sample_size, neg_chunk_size,
filter_false_neg, mode='head', num_workers=5, rank=0, ranks=1):
edges = self.get_edges(eval_type)
beg = edges.shape[0] * rank // ranks
end = min(edges.shape[0] * (rank + 1) // ranks, edges.shape[0])
edges = edges[beg: end]
return EvalSampler(self.g, edges, batch_size, neg_sample_size,
return EvalSampler(self.g, edges, batch_size, neg_sample_size, neg_chunk_size,
mode, num_workers, filter_false_neg)

class NewBidirectionalOneShotIterator:
def __init__(self, dataloader_head, dataloader_tail, is_pbg, num_nodes):
def __init__(self, dataloader_head, dataloader_tail, neg_chunk_size, is_chunked, num_nodes):
self.sampler_head = dataloader_head
self.sampler_tail = dataloader_tail
self.iterator_head = self.one_shot_iterator(dataloader_head, is_pbg,
self.iterator_head = self.one_shot_iterator(dataloader_head, neg_chunk_size, is_chunked,
True, num_nodes)
self.iterator_tail = self.one_shot_iterator(dataloader_tail, is_pbg,
self.iterator_tail = self.one_shot_iterator(dataloader_tail, neg_chunk_size, is_chunked,
False, num_nodes)
self.step = 0

Expand All @@ -308,10 +316,11 @@ def __next__(self):
return pos_g, neg_g

@staticmethod
def one_shot_iterator(dataloader, is_pbg, neg_head, num_nodes):
def one_shot_iterator(dataloader, neg_chunk_size, is_chunked, neg_head, num_nodes):
while True:
for pos_g, neg_g in dataloader:
neg_g = create_neg_subgraph(pos_g, neg_g, is_pbg, neg_head, num_nodes)
neg_g = create_neg_subgraph(pos_g, neg_g, neg_chunk_size, is_chunked,
neg_head, num_nodes)
if neg_g is None:
continue

Expand Down
26 changes: 21 additions & 5 deletions apps/kg/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __init__(self):
help='batch size used for eval and test')
self.add_argument('--neg_sample_size', type=int, default=-1,
help='negative sampling size for testing')
self.add_argument('--neg_deg_sample', action='store_true',
help='negative sampling proportional to vertex degree for testing')
self.add_argument('--neg_chunk_size', type=int, default=-1,
help='chunk size of the negative edges.')
self.add_argument('--hidden_dim', type=int, default=256,
help='hidden dim used by relation and entity')
self.add_argument('-g', '--gamma', type=float, default=12.0,
Expand Down Expand Up @@ -86,6 +90,10 @@ def get_logger(args):
return logger

def main(args):
args.eval_filter = not args.no_eval_filter
if args.neg_deg_sample:
assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

# load dataset and samplers
dataset = get_dataset(args.data_path, args.dataset, args.format)
args.pickle_graph = False
Expand All @@ -98,10 +106,14 @@ def main(args):
# Here we want to use the regular negative sampler because we need to ensure that
# all positive edges are excluded.
eval_dataset = EvalDataset(dataset, args)

args.neg_sample_size_test = args.neg_sample_size
args.neg_deg_sample_eval = args.neg_deg_sample
if args.neg_sample_size < 0:
args.neg_sample_size_test = args.neg_sample_size = eval_dataset.g.number_of_nodes()
args.eval_filter = not args.no_eval_filter
if args.neg_chunk_size < 0:
args.neg_chunk_size = args.neg_sample_size

num_workers = args.num_worker
# for multiprocessing evaluation, we don't need to sample multiple batches at a time
# in each process.
Expand All @@ -113,29 +125,33 @@ def main(args):
for i in range(args.num_proc):
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
args.eval_filter,
mode='PBG-head',
mode='chunk-head',
num_workers=num_workers,
rank=i, ranks=args.num_proc)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
args.eval_filter,
mode='PBG-tail',
mode='chunk-tail',
num_workers=num_workers,
rank=i, ranks=args.num_proc)
test_sampler_heads.append(test_sampler_head)
test_sampler_tails.append(test_sampler_tail)
else:
test_sampler_head = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
args.eval_filter,
mode='PBG-head',
mode='chunk-head',
num_workers=num_workers,
rank=0, ranks=1)
test_sampler_tail = eval_dataset.create_sampler('test', args.batch_size,
args.neg_sample_size,
args.neg_chunk_size,
args.eval_filter,
mode='PBG-tail',
mode='chunk-tail',
num_workers=num_workers,
rank=0, ranks=1)

Expand Down
48 changes: 41 additions & 7 deletions apps/kg/models/general_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,36 +94,67 @@ def predict_score(self, g):
self.score_func(g)
return g.edata['score']

def predict_neg_score(self, pos_g, neg_g, to_device=None, gpu_id=-1, trace=False):
def predict_neg_score(self, pos_g, neg_g, to_device=None, gpu_id=-1, trace=False,
neg_deg_sample=False):
num_chunks = neg_g.num_chunks
chunk_size = neg_g.chunk_size
neg_sample_size = neg_g.neg_sample_size
mask = F.ones((num_chunks, chunk_size * (neg_sample_size + chunk_size)),
dtype=F.float32, ctx=F.context(pos_g.ndata['emb']))
if neg_g.neg_head:
neg_head_ids = neg_g.ndata['id'][neg_g.head_nid]
neg_head = self.entity_emb(neg_head_ids, gpu_id, trace)
_, tail_ids = pos_g.all_edges(order='eid')
head_ids, tail_ids = pos_g.all_edges(order='eid')
if to_device is not None and gpu_id >= 0:
tail_ids = to_device(tail_ids, gpu_id)
tail = pos_g.ndata['emb'][tail_ids]
rel = pos_g.edata['emb']

# When we train a batch, we could use the head nodes of the positive edges to
# construct negative edges. We construct a negative edge between a positive head
# node and every positive tail node.
# When we construct negative edges like this, we know there is one positive
# edge for a positive head node among the negative edges. We need to mask
# them.
if neg_deg_sample:
head = pos_g.ndata['emb'][head_ids]
head = head.reshape(num_chunks, chunk_size, -1)
neg_head = neg_head.reshape(num_chunks, neg_sample_size, -1)
neg_head = F.cat([head, neg_head], 1)
neg_sample_size = chunk_size + neg_sample_size
mask[:,0::(neg_sample_size + 1)] = 0
neg_head = neg_head.reshape(num_chunks * neg_sample_size, -1)
neg_head, tail = self.head_neg_prepare(pos_g.edata['id'], num_chunks, neg_head, tail, gpu_id, trace)
neg_score = self.head_neg_score(neg_head, rel, tail,
num_chunks, chunk_size, neg_sample_size)
else:
neg_tail_ids = neg_g.ndata['id'][neg_g.tail_nid]
neg_tail = self.entity_emb(neg_tail_ids, gpu_id, trace)
head_ids, _ = pos_g.all_edges(order='eid')
head_ids, tail_ids = pos_g.all_edges(order='eid')
if to_device is not None and gpu_id >= 0:
head_ids = to_device(head_ids, gpu_id)
head = pos_g.ndata['emb'][head_ids]
rel = pos_g.edata['emb']

# This is negative edge construction similar to the above.
if neg_deg_sample:
tail = pos_g.ndata['emb'][tail_ids]
tail = tail.reshape(num_chunks, chunk_size, -1)
neg_tail = neg_tail.reshape(num_chunks, neg_sample_size, -1)
neg_tail = F.cat([tail, neg_tail], 1)
neg_sample_size = chunk_size + neg_sample_size
mask[:,0::(neg_sample_size + 1)] = 0
neg_tail = neg_tail.reshape(num_chunks * neg_sample_size, -1)
head, neg_tail = self.tail_neg_prepare(pos_g.edata['id'], num_chunks, head, neg_tail, gpu_id, trace)
neg_score = self.tail_neg_score(head, rel, neg_tail,
num_chunks, chunk_size, neg_sample_size)

return neg_score
if neg_deg_sample:
neg_g.neg_sample_size = neg_sample_size
mask = mask.reshape(num_chunks, chunk_size, neg_sample_size)
return neg_score * mask
else:
return neg_score

def forward_test(self, pos_g, neg_g, logs, gpu_id=-1):
pos_g.ndata['emb'] = self.entity_emb(pos_g.ndata['id'], gpu_id, False)
Expand All @@ -136,7 +167,8 @@ def forward_test(self, pos_g, neg_g, logs, gpu_id=-1):
pos_scores = reshape(logsigmoid(pos_scores), batch_size, -1)

neg_scores = self.predict_neg_score(pos_g, neg_g, to_device=cuda,
gpu_id=gpu_id, trace=False)
gpu_id=gpu_id, trace=False,
neg_deg_sample=self.args.neg_deg_sample_eval)
neg_scores = reshape(logsigmoid(neg_scores), batch_size, -1)

# We need to filter the positive edges in the negative graph.
Expand Down Expand Up @@ -171,9 +203,11 @@ def forward(self, pos_g, neg_g, gpu_id=-1):
pos_score = logsigmoid(pos_score)
if gpu_id >= 0:
neg_score = self.predict_neg_score(pos_g, neg_g, to_device=cuda,
gpu_id=gpu_id, trace=True)
gpu_id=gpu_id, trace=True,
neg_deg_sample=self.args.neg_deg_sample)
else:
neg_score = self.predict_neg_score(pos_g, neg_g, trace=True)
neg_score = self.predict_neg_score(pos_g, neg_g, trace=True,
neg_deg_sample=self.args.neg_deg_sample)

neg_score = reshape(neg_score, -1, neg_g.neg_sample_size)
# Adversarial sampling
Expand Down
4 changes: 2 additions & 2 deletions apps/kg/tests/test_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def check_score_func(func_name):
EdgeSampler = getattr(dgl.contrib.sampling, 'EdgeSampler')
sampler = EdgeSampler(g, batch_size=batch_size,
neg_sample_size=neg_sample_size,
negative_mode='PBG-head',
negative_mode='chunk-head',
num_workers=1,
shuffle=False,
exclude_positive=False,
return_false_neg=False)

for pos_g, neg_g in sampler:
neg_g = create_neg_subgraph(pos_g, neg_g, True, True, g.number_of_nodes())
neg_g = create_neg_subgraph(pos_g, neg_g, neg_sample_size, True, True, g.number_of_nodes())
pos_g.copy_from_parent()
neg_g.copy_from_parent()
score1 = F.reshape(model.predict_score(neg_g), (batch_size, -1))
Expand Down
Loading

0 comments on commit 1022d5d

Please sign in to comment.