Skip to content

Commit

Permalink
[DGL-KE] Distributed training of DGL-KE (dmlc#1290)
Browse files Browse the repository at this point in the history
* update

* change name

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change worker number

* update

* update

* update

* update

* update

* update

* test

* update

* update

* update

* remove barrier

* max_step

* update

* add complex

* update

* chmod +x

* update

* update

* random partition

* random partition

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* change num_test_proc

* update num_thread

* update
  • Loading branch information
aksnzhy authored Mar 2, 2020
1 parent c3a3340 commit 00ba409
Show file tree
Hide file tree
Showing 13 changed files with 872 additions and 19 deletions.
87 changes: 75 additions & 12 deletions apps/kg/dataloader/KGDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class KGDataset1:
The triples are stored as 'head_name\trelation_name\ttail_name'.
'''
def __init__(self, path, name):
def __init__(self, path, name, read_triple=True, only_train=False):
url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)

if not os.path.exists(os.path.join(path, name)):
Expand Down Expand Up @@ -66,9 +66,11 @@ def __init__(self, path, name):
self.n_entities = len(self.entity2id)
self.n_relations = len(self.relation2id)

self.train = self.read_triple(path, 'train')
self.valid = self.read_triple(path, 'valid')
self.test = self.read_triple(path, 'test')
if read_triple == True:
self.train = self.read_triple(path, 'train')
if only_train == False:
self.valid = self.read_triple(path, 'valid')
self.test = self.read_triple(path, 'test')

def read_triple(self, path, mode):
# mode: train/valid/test
Expand Down Expand Up @@ -102,25 +104,32 @@ class KGDataset2:
The triples are stored as 'head_nid\trelation_id\ttail_nid'.
'''
def __init__(self, path, name):
def __init__(self, path, name, read_triple=True, only_train=False):
url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/{}.zip'.format(name)

if not os.path.exists(os.path.join(path, name)):
print('File not found. Downloading from', url)
_download_and_extract(url, path, '{}.zip'.format(name))
self.path = os.path.join(path, name)

f_ent2id = os.path.join(self.path, 'entity2id.txt')
f_rel2id = os.path.join(self.path, 'relation2id.txt')

with open(f_ent2id) as f_ent:
self.n_entities = int(f_ent.readline()[:-1])
with open(f_rel2id) as f_rel:
self.n_relations = int(f_rel.readline()[:-1])

self.train = self.read_triple(self.path, 'train')
self.valid = self.read_triple(self.path, 'valid')
self.test = self.read_triple(self.path, 'test')
if only_train == True:
f_ent2id = os.path.join(self.path, 'local_to_global.txt')
with open(f_ent2id) as f_ent:
self.n_entities = len(f_ent.readlines())
else:
f_ent2id = os.path.join(self.path, 'entity2id.txt')
with open(f_ent2id) as f_ent:
self.n_entities = int(f_ent.readline()[:-1])

if read_triple == True:
self.train = self.read_triple(self.path, 'train')
if only_train == False:
self.valid = self.read_triple(self.path, 'valid')
self.test = self.read_triple(self.path, 'test')

def read_triple(self, path, mode, skip_first_line=False):
heads = []
Expand Down Expand Up @@ -151,3 +160,57 @@ def get_dataset(data_path, data_name, format_str):
dataset = KGDataset2(data_path, data_name)

return dataset


def get_partition_dataset(data_path, data_name, format_str, part_id):
    """Load the training data of one graph partition plus its ID tables.

    The partition lives in ``<data_path>/<data_name>/part_<part_id>``.
    Only the training triples are read (``only_train=True``).

    Parameters
    ----------
    data_path : str
        Root directory that holds the datasets.
    data_name : str
        Dataset name. ``'Freebase'`` is always loaded with ``KGDataset2``,
        whatever ``format_str`` says.
    format_str : str
        ``'1'`` selects ``KGDataset1``; any other value selects
        ``KGDataset2``.
    part_id : int
        Index of the partition to load.

    Returns
    -------
    tuple
        ``(dataset, partition_book, local_to_global)`` where the last two
        are lists of ints, one entry per line of ``partition_book.txt`` and
        ``local_to_global.txt`` respectively.
    """
    part_name = os.path.join(data_name, 'part_' + str(part_id))

    # Freebase takes precedence over format_str, so KGDataset1 is only used
    # for non-Freebase data in format '1'.
    wants_format_one = data_name != 'Freebase' and format_str == '1'
    dataset_cls = KGDataset1 if wants_format_one else KGDataset2
    dataset = dataset_cls(data_path, part_name, read_triple=True, only_train=True)

    part_dir = os.path.join(data_path, part_name)

    # presumably one partition id per global entity -- confirm with the
    # partitioning script that writes this file
    with open(os.path.join(part_dir, 'partition_book.txt')) as book_file:
        partition_book = [int(row) for row in book_file]

    # Line i holds the global id of local entity i.
    with open(os.path.join(part_dir, 'local_to_global.txt')) as mapping_file:
        local_to_global = [int(row) for row in mapping_file]

    return dataset, partition_book, local_to_global


def get_server_partition_dataset(data_path, data_name, format_str, part_id):
    """Load the server-side view of one graph partition.

    The dataset object is created without reading any triples
    (``read_triple=False``); the server only needs the dataset metadata and
    a table that turns global entity ids into partition-local ids.

    Parameters
    ----------
    data_path : str
        Root directory that holds the datasets.
    data_name : str
        Dataset name. ``'Freebase'`` is always loaded with ``KGDataset2``,
        whatever ``format_str`` says.
    format_str : str
        ``'1'`` selects ``KGDataset1``; any other value selects
        ``KGDataset2``.
    part_id : int
        Index of the partition to load.

    Returns
    -------
    tuple
        ``(global_to_local, dataset)``. ``global_to_local`` has one slot per
        line of ``partition_book.txt``; slot ``g`` holds the local id of
        global entity ``g``. Slots of entities that are not listed in
        ``local_to_global.txt`` keep their initial value ``0`` and are
        therefore indistinguishable from local id 0.
    """
    part_name = os.path.join(data_name, 'part_' + str(part_id))

    if format_str == '1' and data_name != 'Freebase':
        dataset = KGDataset1(data_path, part_name, read_triple=False, only_train=True)
    else:
        dataset = KGDataset2(data_path, part_name, read_triple=False, only_train=True)

    path = os.path.join(data_path, part_name)

    # Count lines while streaming so the handle is closed deterministically
    # and the whole file is never held in memory.
    with open(os.path.join(path, 'partition_book.txt')) as f:
        n_entities = sum(1 for _ in f)

    # Line i holds the global id of local entity i.
    with open(os.path.join(path, 'local_to_global.txt')) as f:
        local_to_global = [int(line) for line in f]

    # Invert the local -> global table.
    global_to_local = [0] * n_entities
    for local_id, global_id in enumerate(local_to_global):
        global_to_local[global_id] = local_id

    return global_to_local, dataset
29 changes: 29 additions & 0 deletions apps/kg/distributed/freebase_complex.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
##################################################################################
# This script runs the ComplEx model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
#
# Arguments:
#   $1  machine_id    index of this machine
#   $2  server_count  number of kvserver processes to start on this machine
##################################################################################
machine_id=$1
server_count=$2

# Delete the temp files; -f keeps rm quiet when none exist yet
rm -f -- *-shape

##################################################################################
# Start kvserver
##################################################################################
# This machine owns the server ids [machine_id*server_count, (machine_id+1)*server_count)
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))

while [ "$SERVER_ID_LOW" -lt "$SERVER_ID_HIGH" ]
do
    # NOTE(review): --total_client 160 looks like total_machine (4) x num_client (40)
    # from the kvclient command below -- confirm and keep them in sync.
    MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model ComplEx --dataset Freebase \
    --hidden_dim 400 --gamma 143.0 --lr 0.1 --total_client 160 --server_id "$SERVER_ID_LOW" &
    # POSIX arithmetic instead of the bash-only `let`, so the loop also
    # terminates when the script is interpreted by a plain sh
    SERVER_ID_LOW=$((SERVER_ID_LOW+1))
done

##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model ComplEx --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
29 changes: 29 additions & 0 deletions apps/kg/distributed/freebase_distmult.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
##################################################################################
# This script runs the DistMult model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
#
# Arguments:
#   $1  machine_id    index of this machine
#   $2  server_count  number of kvserver processes to start on this machine
##################################################################################
machine_id=$1
server_count=$2

# Delete the temp files; -f keeps rm quiet when none exist yet
rm -f -- *-shape

##################################################################################
# Start kvserver
##################################################################################
# This machine owns the server ids [machine_id*server_count, (machine_id+1)*server_count)
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))

while [ "$SERVER_ID_LOW" -lt "$SERVER_ID_HIGH" ]
do
    # NOTE(review): --total_client 160 looks like total_machine (4) x num_client (40)
    # from the kvclient command below -- confirm and keep them in sync.
    MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model DistMult --dataset Freebase \
    --hidden_dim 400 --gamma 143.0 --lr 0.08 --total_client 160 --server_id "$SERVER_ID_LOW" &
    # POSIX arithmetic instead of the bash-only `let`, so the loop also
    # terminates when the script is interpreted by a plain sh
    SERVER_ID_LOW=$((SERVER_ID_LOW+1))
done

##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model DistMult --dataset Freebase \
--batch_size 1024 --neg_sample_size 256 --hidden_dim 400 --gamma 143.0 --lr 0.08 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --total_machine 4 --num_client 40
29 changes: 29 additions & 0 deletions apps/kg/distributed/freebase_transe_l2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
##################################################################################
# This script runs the TransE_l2 model on the Freebase dataset in a distributed setting.
# You can change the hyper-parameters in this file, but DO NOT run this script manually.
#
# Arguments:
#   $1  machine_id    index of this machine
#   $2  server_count  number of kvserver processes to start on this machine
##################################################################################
machine_id=$1
server_count=$2

# Delete the temp files; -f keeps rm quiet when none exist yet
rm -f -- *-shape

##################################################################################
# Start kvserver
##################################################################################
# This machine owns the server ids [machine_id*server_count, (machine_id+1)*server_count)
SERVER_ID_LOW=$((machine_id*server_count))
SERVER_ID_HIGH=$(((machine_id+1)*server_count))

while [ "$SERVER_ID_LOW" -lt "$SERVER_ID_HIGH" ]
do
    # NOTE(review): --total_client 160 looks like total_machine (4) x num_client (40)
    # from the kvclient command below -- confirm and keep them in sync.
    MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvserver.py --model TransE_l2 --dataset Freebase \
    --hidden_dim 400 --gamma 10 --lr 0.1 --total_client 160 --server_id "$SERVER_ID_LOW" &
    # POSIX arithmetic instead of the bash-only `let`, so the loop also
    # terminates when the script is interpreted by a plain sh
    SERVER_ID_LOW=$((SERVER_ID_LOW+1))
done

##################################################################################
# Start kvclient
##################################################################################
MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 DGLBACKEND=pytorch python3 ../kvclient.py --model TransE_l2 --dataset Freebase \
--batch_size 1000 --neg_sample_size 200 --hidden_dim 400 --gamma 10 --lr 0.1 --max_step 12500 --log_interval 100 \
--batch_size_eval 1000 --neg_sample_size_test 1000 --test -adv --regularization_coef 1e-9 --total_machine 4 --num_client 40
4 changes: 4 additions & 0 deletions apps/kg/distributed/ip_config.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
127.0.0.1 30050 8
25 changes: 25 additions & 0 deletions apps/kg/distributed/launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
##################################################################################
# User runs this script to launch distributed jobs on a cluster.
#
# ip_config.txt is read from the current directory. Line 1 is treated as the
# local machine (machine id 0); every following line is a remote machine
# reached over ssh and numbered 1, 2, ... in file order.
##################################################################################
script_path=~/dgl/apps/kg/distributed
script_file=./freebase_transe_l2.sh
user_name=ubuntu
ssh_key=~/mctt.pem

# Third column of the first line; passed to the job script as its server_count
server_count=$(awk 'NR==1 {print $3}' ip_config.txt)

# run command on remote machines (lines 2..N of ip_config.txt)
LINE_LOW=2
LINE_HIGH=$(awk 'END{print NR}' ip_config.txt)
LINE_HIGH=$((LINE_HIGH+1))
s_id=0
while [ "$LINE_LOW" -lt "$LINE_HIGH" ]
do
    # Pass the line number as an awk variable instead of splicing it into the program text
    ip=$(awk -v line="$LINE_LOW" 'NR==line {print $1}' ip_config.txt)
    # POSIX arithmetic instead of the bash-only `let`
    LINE_LOW=$((LINE_LOW+1))
    s_id=$((s_id+1))
    ssh -i "$ssh_key" "$user_name@$ip" "cd $script_path; $script_file $s_id $server_count" &
done

# run command on local machine
$script_file 0 "$server_count"
2 changes: 2 additions & 0 deletions apps/kg/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(self):
help='number of workers used for loading data')
self.add_argument('--num_proc', type=int, default=1,
help='number of process used')
self.add_argument('--num_thread', type=int, default=1,
help='number of thread used')

def parse_args(self):
args = super().parse_args()
Expand Down
Loading

0 comments on commit 00ba409

Please sign in to comment.