[Example] Implement deepwalk by dgl and pytorch (dmlc#1503)

* deepwalk
* add docstr, etc.
* add tested version
* some doc
* some doc
* add sample and some docs, fix a bug
* update speed
* update speed
* docs

Co-authored-by: xiang song(charlie.song) <[email protected]>
1 parent 64f4970 · commit 94c6720

Showing 5 changed files with 989 additions and 0 deletions.

@@ -0,0 +1,63 @@

# DeepWalk

- Paper link: [here](https://arxiv.org/pdf/1403.6652.pdf)
- Other implementations: [gensim](https://github.com/phanein/deepwalk), [deepwalk-c](https://github.com/xgfs/deepwalk-c)

This implementation supports multi-process training on CPU as well as mixed training with CPU and multiple GPUs.

## Dependencies

- PyTorch 1.0.1+

## Tested versions

- PyTorch 1.5.0
- DGL 0.4.3

## How to run the code

Format of the network file (a whitespace-separated edge list; each line is one edge given as a pair of node ids):

```
1(node id) 2(node id)
1 3
...
```
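
For illustration, such an edge list can be loaded with a few lines of NumPy (a minimal sketch, assuming the file holds only numeric id pairs; the example itself reads the file through `DeepwalkDataset`, so this snippet is not part of the pipeline):

```
import numpy as np

# each row of net.txt is one edge: "src dst"
edges = np.loadtxt("net.txt", dtype=np.int64)  # shape: (num_edges, 2)
src, dst = edges[:, 0], edges[:, 1]
```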

To run the code:

```
python3 deepwalk.py --net_file net.txt --emb_file emb.txt --adam --mix --lr 0.2 --num_procs 4 --batch_size 100 --negative 5
```
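
For CPU-only multi-process training, an analogous invocation would be (the flags below come from deepwalk.py's argument parser; this exact command is illustrative and untested here):

```
python3 deepwalk.py --net_file net.txt --emb_file emb.npy --adam --only_cpu --num_procs 4 --batch_size 100 --negative 5
```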

## How to save the embedding

Use either of the following functions:

```
SkipGramModel.save_embedding(dataset, file_name)
SkipGramModel.save_embedding_txt(dataset, file_name)
```
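
The saved artifacts can be read back as sketched below, using the script's default file names; the exact layout of the mapping pickle is an assumption based on the `--map_file` help text:

```
import pickle
import numpy as np

# embedding matrix written by save_embedding (default --emb_file: emb.npy)
emb = np.load("emb.npy")  # assumed shape: (num_nodes, dim)

# mapping from original node ids to embedding rows
# (default --map_file: nodeid_to_index.pickle; assumed to be a dict)
with open("nodeid_to_index.pickle", "rb") as f:
    nodeid_to_index = pickle.load(f)

vec = emb[nodeid_to_index[1]]  # embedding of node id 1
```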

## Evaluation

To evaluate the embeddings on multi-label node classification, please refer to [Evaluate-Embedding](https://github.com/ShawXh/Evaluate-Embedding).

Results on YouTube (1M nodes), where 1%-9% is the fraction of labeled nodes used to train the classifier:

Macro-F1 (%):

| Implementation | 1% | 3% | 5% | 7% | 9% |
|----|----|----|----|----|----|
| gensim.word2vec(hs) | 28.73 | 32.51 | 33.67 | 34.28 | 34.79 |
| gensim.word2vec(ns) | 28.18 | 32.25 | 33.56 | 34.60 | 35.22 |
| ours | 24.58 | 31.23 | 33.97 | 35.41 | 36.48 |

Micro-F1 (%):

| Implementation | 1% | 3% | 5% | 7% | 9% |
|----|----|----|----|----|----|
| gensim.word2vec(hs) | 35.73 | 38.34 | 39.37 | 40.08 | 40.77 |
| gensim.word2vec(ns) | 35.35 | 37.69 | 38.08 | 40.24 | 41.09 |
| ours | 38.93 | 43.17 | 44.73 | 45.42 | 45.92 |

The running-time comparison is shown below; the numbers in brackets denote the time spent on random walks.

| Implementation | gensim.word2vec(hs) | gensim.word2vec(ns) | Ours |
|----|----|----|----|
| Time (s) | 27119.6 (1759.8) | 10580.3 (1704.3) | 428.89 |

Parameters:
- walk_length = 80, number_walks = 10, window_size = 5
- Ours: 4 GPUs (Tesla V100), lr = 0.2, batch_size = 128, neg_weight = 5, negative = 1, num_threads = 4
- Others: workers = 8, negative = 5

Speedup with mixed CPU & multi-GPU training, using the same parameters as above:

| #GPUs | 1 | 2 | 4 |
|----------|-------|-------|-------|
| Time (s) | 1419.64 | 952.04 | 428.89 |

@@ -0,0 +1,256 @@

import torch
import argparse
import dgl
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
import time
import numpy as np

from reading_data import DeepwalkDataset
from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks

class DeepwalkTrainer:
    def __init__(self, args):
        """ Initialize the trainer with the input arguments """
        self.args = args
        self.dataset = DeepwalkDataset(
            net_file=args.net_file,
            map_file=args.map_file,
            walk_length=args.walk_length,
            window_size=args.window_size,
            num_walks=args.num_walks,
            batch_size=args.batch_size,
            negative=args.negative,
            num_procs=args.num_procs,
            fast_neg=args.fast_neg,
            )
        self.emb_size = len(self.dataset.net)
        self.emb_model = None

    def init_device_emb(self):
        """ Set the device before training;
        will be called once in fast_train_mp / fast_train
        """
        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
        assert choices == 1, "Must choose exactly *one* training mode in [only_cpu, only_gpu, mix]"
        assert self.args.num_procs >= 1, "The number of processes must be at least 1"
        choices = sum([self.args.sgd, self.args.adam, self.args.avg_sgd])
        assert choices == 1, "Must choose exactly *one* gradient descent strategy in [sgd, avg_sgd, adam]"

        # initializing embedding on CPU
        self.emb_model = SkipGramModel(
            emb_size=self.emb_size,
            emb_dimension=self.args.dim,
            walk_length=self.args.walk_length,
            window_size=self.args.window_size,
            batch_size=self.args.batch_size,
            only_cpu=self.args.only_cpu,
            only_gpu=self.args.only_gpu,
            mix=self.args.mix,
            neg_weight=self.args.neg_weight,
            negative=self.args.negative,
            lr=self.args.lr,
            lap_norm=self.args.lap_norm,
            adam=self.args.adam,
            sgd=self.args.sgd,
            avg_sgd=self.args.avg_sgd,
            fast_neg=self.args.fast_neg,
            )
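
        # cap the number of intra-op (OpenMP) threads per process so that
        # multiple workers do not oversubscribe the CPU cores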
        torch.set_num_threads(self.args.num_threads)
        if self.args.only_gpu:
            print("Running on 1 GPU")
            self.emb_model.all_to_device(0)
        elif self.args.mix:
            print("Mixing CPU with %d GPUs" % self.args.num_procs)
            if self.args.num_procs == 1:
                self.emb_model.set_device(0)
        else:
            print("Running in %d CPU processes" % self.args.num_procs)

    def train(self):
        """ Train the embedding """
        if self.args.num_procs > 1:
            self.fast_train_mp()
        else:
            self.fast_train()

    def fast_train_mp(self):
        """ Multi-CPU-core or mixed CPU & multi-GPU training """
        self.init_device_emb()
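        # move the embedding tensors into shared memory so that every worker
        # process reads and updates the same parameters (Hogwild-style,
        # lock-free updates)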
        self.emb_model.share_memory()

        start_all = time.time()
        ps = []

        for i in range(self.args.num_procs):
            p = mp.Process(target=self.fast_train_sp, args=(i,))
            ps.append(p)
            p.start()

        for p in ps:
            p.join()

        print("Used time: %.2fs" % (time.time() - start_all))
        self.emb_model.save_embedding(self.dataset, self.args.emb_file)
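
    # thread_wrapped_func (from utils) runs the target in a fresh thread inside
    # the spawned process; this is a common workaround for hangs when PyTorch
    # multiprocessing is combined with OpenMP-threaded ops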
    @thread_wrapped_func
    def fast_train_sp(self, gpu_id):
        """ A subprocess for fast_train_mp """
        if self.args.mix:
            self.emb_model.set_device(gpu_id)
        torch.set_num_threads(self.args.num_threads)

        sampler = self.dataset.create_sampler(gpu_id)

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            )
        num_batches = len(dataloader)
        print("number of batches: %d in subprocess [%d]" % (num_batches, gpu_id))
        # number of positive node pairs in a sequence
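        # each of the walk_length positions has up to 2*window_size context
        # nodes; the window_size nodes at each end of the walk lose
        # window_size*(window_size+1) pairs in total, hence the formula below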
        num_pos = int(2 * self.args.walk_length * self.args.window_size
                      - self.args.window_size * (self.args.window_size + 1))

        start = time.time()
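        # no_grad: fast_learn is expected to compute and apply its updates
        # manually, so autograd tracking is unnecessary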
        with torch.no_grad():
            max_i = self.args.iterations * num_batches

            for i, walks in enumerate(dataloader):
                # linearly decay the learning rate for SGD, with a floor of 1e-5
                lr = self.args.lr * (max_i - i) / max_i
                if lr < 0.00001:
                    lr = 0.00001

                if self.args.fast_neg:
                    self.emb_model.fast_learn(walks, lr)
                else:
                    # sample negatives from the dataset's negative table
                    bs = len(walks)
                    neg_nodes = torch.LongTensor(
                        np.random.choice(self.dataset.neg_table,
                                         bs * num_pos * self.args.negative,
                                         replace=True))
                    self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    print("Solver [%d] batch %d, time: %.2fs" % (gpu_id, i, time.time() - start))
                    start = time.time()

    def fast_train(self):
        """ Fast training with a dataloader in a single process """
        # the number of positive node pairs of a node sequence
        num_pos = int(2 * self.args.walk_length * self.args.window_size
                      - self.args.window_size * (self.args.window_size + 1))

        self.init_device_emb()

        sampler = self.dataset.create_sampler(0)

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            )

        num_batches = len(dataloader)
        print("number of batches: %d" % num_batches)

        start_all = time.time()
        start = time.time()
        with torch.no_grad():
            max_i = self.args.iterations * num_batches
            for iteration in range(self.args.iterations):
                print("\nIteration: " + str(iteration + 1))

                for i, walks in enumerate(dataloader):
                    # linearly decay the learning rate across all iterations,
                    # with a floor of 1e-5
                    step = iteration * num_batches + i
                    lr = self.args.lr * (max_i - step) / max_i
                    if lr < 0.00001:
                        lr = 0.00001

                    if self.args.fast_neg:
                        self.emb_model.fast_learn(walks, lr)
                    else:
                        # sample negatives from the dataset's negative table
                        bs = len(walks)
                        neg_nodes = torch.LongTensor(
                            np.random.choice(self.dataset.neg_table,
                                             bs * num_pos * self.args.negative,
                                             replace=True))
                        self.emb_model.fast_learn(walks, lr, neg_nodes=neg_nodes)

                    if i > 0 and i % self.args.print_interval == 0:
                        print("Batch %d, training time: %.2fs" % (i, time.time() - start))
                        start = time.time()

        print("Training used time: %.2fs" % (time.time() - start_all))
        self.emb_model.save_embedding(self.dataset, self.args.emb_file)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="DeepWalk")
    parser.add_argument('--net_file', type=str,
                        help="path of the txt network file")
    parser.add_argument('--emb_file', type=str, default="emb.npy",
                        help='path of the npy embedding file')
    parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle",
                        help='path of the mapping dict that maps node ids to embedding indices')
    parser.add_argument('--dim', default=128, type=int,
                        help="embedding dimensions")
    parser.add_argument('--window_size', default=5, type=int,
                        help="context window size")
    parser.add_argument('--num_walks', default=10, type=int,
                        help="number of walks for each node")
    parser.add_argument('--negative', default=5, type=int,
                        help="negative samples for each positive node pair")
    parser.add_argument('--iterations', default=1, type=int,
                        help="number of training iterations")
    parser.add_argument('--batch_size', default=10, type=int,
                        help="number of node sequences in each batch")
    parser.add_argument('--print_interval', default=1000, type=int,
                        help="number of batches between progress prints")
    parser.add_argument('--walk_length', default=80, type=int,
                        help="number of nodes in a sequence")
    parser.add_argument('--lr', default=0.2, type=float,
                        help="learning rate")
    parser.add_argument('--neg_weight', default=1., type=float,
                        help="negative weight")
    parser.add_argument('--lap_norm', default=0.01, type=float,
                        help="weight of laplacian normalization")
    parser.add_argument('--mix', default=False, action="store_true",
                        help="mixed training with CPU and GPU")
    parser.add_argument('--only_cpu', default=False, action="store_true",
                        help="training with CPU")
    parser.add_argument('--only_gpu', default=False, action="store_true",
                        help="training with GPU")
    parser.add_argument('--fast_neg', dest='fast_neg', default=True, action="store_true",
                        help="do negative sampling inside a batch (default)")
    parser.add_argument('--no_fast_neg', dest='fast_neg', action="store_false",
                        help="sample negatives from the dataset's negative table instead")
    parser.add_argument('--adam', default=False, action="store_true",
                        help="use adam for embedding updates (recommended)")
    parser.add_argument('--sgd', default=False, action="store_true",
                        help="use sgd for embedding updates")
    parser.add_argument('--avg_sgd', default=False, action="store_true",
                        help="average gradients of sgd for embedding updates")
    parser.add_argument('--num_threads', default=2, type=int,
                        help="number of threads used for each CPU-core/GPU")
    parser.add_argument('--num_procs', default=1, type=int,
                        help="number of GPUs/CPUs used in mixed training")
    args = parser.parse_args()

    start_time = time.time()
    trainer = DeepwalkTrainer(args)
    trainer.train()
    print("Total used time: %.2fs" % (time.time() - start_time))