[KG][Optimization] Soft relation partition (dmlc#1252)
* Several optimizations for DGL-KG:
1. Sort positive edges for sampling, which reduces random
   memory access during positive sampling
2. Asynchronous node embedding update
3. Balanced relation partition that gives a balanced number of
   edges in each partition. When there is no cross-partition
   relation, relation embeddings can be pinned in GPU memory
4. Tunable neg_sample_size instead of a fixed neg_sample_size

* Fix test

* Fix test and eval.py

* Now TransR is OK

* Fix single GPU with mix_cpu_gpu

* Add app tests

* Fix test script

* Fix MXNet

* Fix sample

* Add docstrings

* Fix

* Default value for num_workers

* Add soft relation part

* Upd

* Some fix

* upd

* Now work

* Fix TransR

* Fix eval and add some doc string

* Trigger

* upd

* Add some training scripts for freebase multi-gpu

* upd

* upd

* upd
classicsong authored Feb 16, 2020
1 parent 7a80faf commit 49fe5b3
Showing 8 changed files with 209 additions and 13 deletions.
17 changes: 17 additions & 0 deletions apps/kg/config/best_config.sh
@@ -104,3 +104,20 @@ DGLBACKEND=pytorch python3 train.py --model ComplEx --dataset Freebase --batch_s
--neg_sample_size 256 --hidden_dim 400 --gamma 500.0 --lr 0.1 --max_step 50000 \
--batch_size_eval 128 --test -adv --eval_interval 300000 \
--neg_sample_size_test 100000 --eval_percent 0.02 --num_proc 64

# Freebase multi-gpu
# TransE_l2 8 GPU
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset Freebase --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 400 --gamma 10 --lr 0.1 --batch_size_eval 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_test 1000 \
--num_proc 8 --gpu 0 1 2 3 4 5 6 7 --num_worker 4 --regularization_coef 1e-9 \
--no_eval_filter --max_step 400000 --rel_part --eval_interval 100000 --log_interval 10000 \
--async_update --neg_deg_sample --force_sync_interval 1000

# TransE_l2 16 GPU
DGLBACKEND=pytorch python3 train.py --model TransE_l2 --dataset Freebase --batch_size 1024 \
--neg_sample_size 256 --hidden_dim 400 --gamma 10 --lr 0.1 --batch_size_eval 1000 \
--valid --test -adv --mix_cpu_gpu --neg_deg_sample_eval --neg_sample_size_test 1000 \
--num_proc 16 --gpu 0 1 2 3 4 5 6 7 --num_worker 4 --regularization_coef 1e-9 \
--no_eval_filter --max_step 200000 --soft_rel_part --eval_interval 100000 --log_interval 10000 \
--async_update --neg_deg_sample --force_sync_interval 1000
118 changes: 117 additions & 1 deletion apps/kg/dataloader/sampler.py
@@ -8,6 +8,119 @@
import pickle
import time

def SoftRelationPartition(edges, n, threshold=0.05):
    """Partition a list of edges into n partitions according to their
    relation types. Edges of any relation whose share of the edge list is
    larger than the threshold are evenly distributed across all partitions;
    edges of any relation below the threshold are placed in a single partition.

    Algorithm:
    For each relation r, from most to least frequent:
        if r.size() > threshold:
            evenly divide the edges of r into n parts and put one part into
            each partition.
        else:
            find the partition with the fewest edges and put all edges of r
            into that partition.

    Parameters
    ----------
    edges : (heads, rels, tails) triple
        Edge list to partition
    n : int
        Number of partitions
    threshold : float
        The threshold that decides whether a relation is LARGE or SMALL.
        Default: 5%

    Returns
    -------
    List of np.array
        Edges of each partition
    List of np.array
        Edge types of each partition
    bool
        Whether any relation belongs to multiple partitions
    np.array
        Relations that are split across multiple partitions
    """
    heads, rels, tails = edges
    print('relation partition {} edges into {} parts'.format(len(heads), n))
    # sort relations by frequency, most frequent first
    uniq, cnts = np.unique(rels, return_counts=True)
    idx = np.flip(np.argsort(cnts))
    cnts = cnts[idx]
    uniq = uniq[idx]
    assert cnts[0] > cnts[-1]
    edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_dict = {}
    rel_parts = []
    cross_rel_part = []
    for _ in range(n):
        rel_parts.append([])

    large_threshold = int(len(rels) * threshold)
    capacity_per_partition = int(len(rels) / n)
    # ensure any relation larger than the partition capacity will be split
    large_threshold = capacity_per_partition if capacity_per_partition < large_threshold \
                      else large_threshold
    num_cross_part = 0
    for i in range(len(cnts)):
        cnt = cnts[i]
        r = uniq[i]
        r_parts = []
        if cnt > large_threshold:
            # LARGE relation: spread its edges over all n partitions
            avg_part_cnt = (cnt // n) + 1
            num_cross_part += 1
            for j in range(n):
                part_cnt = avg_part_cnt if cnt > avg_part_cnt else cnt
                r_parts.append([j, part_cnt])
                rel_parts[j].append(r)
                edge_cnts[j] += part_cnt
                rel_cnts[j] += 1
                cnt -= part_cnt
            cross_rel_part.append(r)
        else:
            # SMALL relation: put all of its edges into the emptiest partition
            idx = np.argmin(edge_cnts)
            r_parts.append([idx, cnt])
            rel_parts[idx].append(r)
            edge_cnts[idx] += cnt
            rel_cnts[idx] += 1
        rel_dict[r] = r_parts

    for i, edge_cnt in enumerate(edge_cnts):
        print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
    print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts)))

    parts = []
    for i in range(n):
        parts.append([])
        rel_parts[i] = np.array(rel_parts[i])

    # assign every edge index to a partition, consuming the per-relation quotas
    for i, r in enumerate(rels):
        r_part = rel_dict[r][0]
        part_idx = r_part[0]
        cnt = r_part[1]
        parts[part_idx].append(i)
        cnt -= 1
        if cnt == 0:
            rel_dict[r].pop(0)
        else:
            rel_dict[r][0][1] = cnt

    # shuffle the edge list in place so each partition is a contiguous range
    for i, part in enumerate(parts):
        parts[i] = np.array(part, dtype=np.int64)
    shuffle_idx = np.concatenate(parts)
    heads[:] = heads[shuffle_idx]
    rels[:] = rels[shuffle_idx]
    tails[:] = tails[shuffle_idx]

    off = 0
    for i, part in enumerate(parts):
        parts[i] = np.arange(off, off + len(part))
        off += len(part)
    cross_rel_part = np.array(cross_rel_part)

    return parts, rel_parts, num_cross_part > 0, cross_rel_part
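
A toy illustration of how this partitioner behaves (hypothetical data, not part of the commit; it reuses this module's numpy import and passes an explicit threshold so that one relation counts as LARGE):

# Hypothetical example: 10 edges over 3 relations, split into 2 partitions.
heads = np.arange(10, dtype=np.int64)
tails = np.arange(10, dtype=np.int64)[::-1].copy()
rels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2], dtype=np.int64)
parts, rel_parts, cross, cross_rels = \
    SoftRelationPartition((heads, rels, tails), 2, threshold=0.3)
# Relation 0 (6 of 10 edges) exceeds the threshold, so it is split across both
# partitions and reported in cross_rels; relations 1 and 2 each land in a single
# partition. Note that the input triple arrays are shuffled in place so that
# each partition becomes a contiguous index range.
print(cross, cross_rels)        # expected: True [0]
print([len(p) for p in parts])  # expected: [6, 4]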

def BalancedRelationPartition(edges, n):
    """Partition a list of edges based on relations to make sure
    each partition has roughly the same number of edges and relations.
@@ -184,7 +297,10 @@ def __init__(self, dataset, args, ranks=64):
        num_train = len(triples[0])
        print('|Train|:', num_train)

        if ranks > 1 and args.soft_rel_part:
            self.edge_parts, self.rel_parts, self.cross_part, self.cross_rels = \
                SoftRelationPartition(triples, ranks)
        elif ranks > 1 and args.rel_part:
            self.edge_parts, self.rel_parts, self.cross_part = \
                BalancedRelationPartition(triples, ranks)
        elif ranks > 1:
1 change: 1 addition & 0 deletions apps/kg/eval.py
@@ -101,6 +101,7 @@ def main(args):
    args.valid = False
    args.test = True
    args.strict_rel_part = False
    args.soft_rel_part = False
    args.async_update = False
    args.batch_size_eval = args.batch_size

18 changes: 14 additions & 4 deletions apps/kg/models/general_models.py
@@ -82,7 +82,8 @@ def __init__(self, args, model_name, n_entities, n_relations, hidden_dim, gamma,
        self.rel_dim = rel_dim
        self.entity_dim = entity_dim
        self.strict_rel_part = args.strict_rel_part
        self.soft_rel_part = args.soft_rel_part
        if not self.strict_rel_part and not self.soft_rel_part:
            self.relation_emb = ExternalEmbedding(args, n_relations, rel_dim,
                                                  F.cpu() if args.mix_cpu_gpu else device)
        else:
@@ -120,7 +121,7 @@ def share_memory(self):
"""Use torch.tensor.share_memory_() to allow cross process embeddings access.
"""
self.entity_emb.share_memory()
if self.strict_rel_part:
if self.strict_rel_part or self.soft_rel_part:
self.global_relation_emb.share_memory()
else:
self.relation_emb.share_memory()
@@ -139,7 +140,7 @@ def save_emb(self, path, dataset):
            Dataset name as prefix to the saved embeddings.
        """
        self.entity_emb.save(path, dataset+'_'+self.model_name+'_entity')
        if self.strict_rel_part or self.soft_rel_part:
            self.global_relation_emb.save(path, dataset+'_'+self.model_name+'_relation')
        else:
            self.relation_emb.save(path, dataset+'_'+self.model_name+'_relation')
@@ -165,8 +166,10 @@ def reset_parameters(self):
"""
self.entity_emb.init(self.emb_init)
self.score_func.reset_parameters()
if not self.strict_rel_part:
if (not self.strict_rel_part) and (not self.soft_rel_part):
self.relation_emb.init(self.emb_init)
else:
self.global_relation_emb.init(self.emb_init)

def predict_score(self, g):
"""Predict the positive score.
@@ -415,6 +418,11 @@ def prepare_relation(self, device=None):
            self.score_func.prepare_local_emb(local_projection_emb)
            self.score_func.reset_parameters()

    def prepare_cross_rels(self, cross_rels):
        # register relations shared across partitions so their lookups and
        # updates go through the global CPU embedding
        self.relation_emb.setup_cross_rels(cross_rels, self.global_relation_emb)
        if self.model_name == 'TransR':
            self.score_func.prepare_cross_rels(cross_rels)

    def writeback_relation(self, rank=0, rel_parts=None):
        """Write back relation embeddings from a specific process to the global relation embedding.
        Used in multi-process multi-GPU training.
@@ -425,6 +433,8 @@ def writeback_relation(self, rank=0, rel_parts=None):
            List of tensors storing the edge types of each partition.
        """
        idx = rel_parts[rank]
        if self.soft_rel_part:
            # skip relations shared across partitions; they are updated in the global copy directly
            idx = self.relation_emb.get_noncross_idx(idx)
        self.global_relation_emb.emb[idx] = F.copy_to(self.relation_emb.emb, F.cpu())[idx]
        if self.model_name == 'TransR':
            self.score_func.writeback_local_emb(idx)
3 changes: 3 additions & 0 deletions apps/kg/models/pytorch/score_fun.py
@@ -157,6 +157,9 @@ def prepare_local_emb(self, projection_emb):
        self.global_projection_emb = self.projection_emb
        self.projection_emb = projection_emb

    def prepare_cross_rels(self, cross_rels):
        self.projection_emb.setup_cross_rels(cross_rels, self.global_projection_emb)

    def writeback_local_emb(self, idx):
        self.global_projection_emb.emb[idx] = self.projection_emb.emb.cpu()[idx]

39 changes: 39 additions & 0 deletions apps/kg/models/pytorch/tensor_models.py
@@ -117,11 +117,13 @@ class ExternalEmbedding:
    def __init__(self, args, num, dim, device):
        self.gpu = args.gpu
        self.args = args
        self.num = num
        self.trace = []

        self.emb = th.empty(num, dim, dtype=th.float32, device=device)
        self.state_sum = self.emb.new().resize_(self.emb.size(0)).zero_()
        self.state_step = 0
        self.has_cross_rel = False
        # queue used by asynchronous update
        self.async_q = None
        # asynchronous update process
@@ -138,6 +140,19 @@ def init(self, emb_init):
        INIT.uniform_(self.emb, -emb_init, emb_init)
        INIT.zeros_(self.state_sum)

    def setup_cross_rels(self, cross_rels, global_emb):
        # mark relations that are shared across partitions; their master copy
        # lives in the global CPU embedding (global_emb)
        cpu_bitmap = th.zeros((self.num,), dtype=th.bool)
        for i, rel in enumerate(cross_rels):
            cpu_bitmap[rel] = 1
        self.cpu_bitmap = cpu_bitmap
        self.has_cross_rel = True
        self.global_emb = global_emb

    def get_noncross_idx(self, idx):
        # keep only the relations owned exclusively by this partition
        cpu_mask = self.cpu_bitmap[idx]
        gpu_mask = ~cpu_mask
        return idx[gpu_mask]

    def share_memory(self):
        """Use torch.tensor.share_memory_() to allow cross process tensor access
        """
@@ -158,6 +173,14 @@ def __call__(self, idx, gpu_id=-1, trace=True):
            If False, do not trace the computation.
            Default: True
        """
        if self.has_cross_rel:
            # pull the latest rows of cross-partition relations from the global
            # CPU embedding into the local embedding before the lookup
            cpu_idx = idx.cpu()
            cpu_mask = self.cpu_bitmap[cpu_idx]
            cpu_idx = cpu_idx[cpu_mask]
            cpu_idx = th.unique(cpu_idx)
            if cpu_idx.shape[0] != 0:
                cpu_emb = self.global_emb.emb[cpu_idx]
                self.emb[cpu_idx] = cpu_emb.cuda(gpu_id)
        s = self.emb[idx]
        if gpu_id >= 0:
            s = s.cuda(gpu_id)
@@ -202,6 +225,22 @@ def update(self, gpu_id=-1):
                        grad_indices = grad_indices.to(device)
                    if device != grad_sum.device:
                        grad_sum = grad_sum.to(device)

                    if self.has_cross_rel:
                        # route gradients of cross-partition relations to the shared
                        # global CPU embedding using the same sparse Adagrad rule
                        cpu_mask = self.cpu_bitmap[grad_indices]
                        cpu_idx = grad_indices[cpu_mask]
                        if cpu_idx.shape[0] > 0:
                            cpu_grad = grad_values[cpu_mask]
                            cpu_sum = grad_sum[cpu_mask].cpu()
                            cpu_idx = cpu_idx.cpu()
                            self.global_emb.state_sum.index_add_(0, cpu_idx, cpu_sum)
                            std = self.global_emb.state_sum[cpu_idx]
                            if gpu_id >= 0:
                                std = std.cuda(gpu_id)
                            std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
                            tmp = (-clr * cpu_grad / std_values)
                            tmp = tmp.cpu()
                            self.global_emb.emb.index_add_(0, cpu_idx, tmp)
                    self.state_sum.index_add_(0, grad_indices, grad_sum)
                    std = self.state_sum[grad_indices]  # _sparse_mask
                    if gpu_id >= 0:
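Both the local branch and the new cross-relation branch of update() apply the same sparse Adagrad rule, only to different storage (the per-process embedding versus the shared CPU copy in global_emb). A minimal self-contained sketch of that per-row rule, using plain tensors and illustrative names rather than the ExternalEmbedding internals:

import torch as th

def adagrad_row_update(emb, state_sum, idx, grad, lr):
    # accumulate the mean squared gradient of each touched row
    grad_sum = (grad * grad).mean(1)
    state_sum.index_add_(0, idx, grad_sum)
    # scale each row's step by 1/sqrt(accumulated state)
    std = state_sum[idx].sqrt_().add_(1e-10).unsqueeze(1)
    emb.index_add_(0, idx, -lr * grad / std)

emb = th.zeros(5, 4)
state = th.zeros(5)
adagrad_row_update(emb, state, th.tensor([1, 3]), th.ones(2, 4), lr=0.1)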
8 changes: 7 additions & 1 deletion apps/kg/train.py
@@ -112,6 +112,8 @@ def __init__(self):
                          help='number of processes used')
        self.add_argument('--rel_part', action='store_true',
                          help='enable relation partitioning')
        self.add_argument('--soft_rel_part', action='store_true',
                          help='enable soft relation partition')
        self.add_argument('--nomp_thread_per_process', type=int, default=-1,
                          help='num of omp threads used per process in multi-process training')
        self.add_argument('--async_update', action='store_true',
@@ -170,7 +172,9 @@ def run(args, logger):

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross-partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
@@ -322,7 +326,8 @@ def run(args, logger):

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
@@ -334,6 +339,7 @@ def run(args, logger):
                                                     valid_sampler,
                                                     i,
                                                     rel_parts,
                                                     cross_rels,
                                                     barrier))
            procs.append(proc)
            proc.start()
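The per-process training entry point itself is outside this diff, so the following is only a hedged sketch of how the extra cross_rels argument is presumably consumed by each worker (the function and argument names here are assumptions, not shown in the commit): when soft relation partition is active, a worker registers the shared relations on its local model before training and writes back only its non-cross relations afterwards.

# Hypothetical sketch; the real worker body is not part of this diff.
def train_mp(args, model, train_sampler, valid_sampler, rank, rel_parts, cross_rels, barrier):
    if args.soft_rel_part:
        # route lookups/updates of shared relations through the global CPU copy
        model.prepare_cross_rels(cross_rels)
    # ... per-process training loop ...
    if args.strict_rel_part or args.soft_rel_part:
        model.writeback_relation(rank, rel_parts)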