[Example] Implement LINE with dgl and pytorch (dmlc#2195)
* line

* two lines

* update readme

* readme

* update readme

* update

* Implement LINE

* readme

* readme

* typos

* update readme

Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: xiang song(charlie.song) <[email protected]>
Co-authored-by: Jinjing Zhou <[email protected]>
4 people authored Sep 20, 2020
1 parent 90d86fc commit eef4c05
Showing 6 changed files with 1,161 additions and 0 deletions.
77 changes: 77 additions & 0 deletions examples/pytorch/ogb/line/README.md
@@ -0,0 +1,77 @@
# LINE Example
- Paper link: [here](https://arxiv.org/pdf/1503.03578)
- Official implementation: [here](https://github.com/tangjianpku/LINE)

This implementation includes both LINE-1st and LINE-2nd (first-order and second-order proximity). Detailed usage is documented in the argument list of line.py.
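
For reference, the objectives the two modes optimize, as given in the LINE paper (σ is the sigmoid function; the second-order objective is shown in its negative-sampling form with K negative samples per edge drawn from the noise distribution P_n):
```
% First-order proximity: minimize
O_1 = - \sum_{(i,j) \in E} w_{ij} \log \sigma\!\left(\vec{u}_i^{\top} \vec{u}_j\right)

% Second-order proximity: maximize, per edge (i, j),
\log \sigma\!\left({\vec{u}'_j}^{\top} \vec{u}_i\right)
  + \sum_{n=1}^{K} \mathbb{E}_{v_n \sim P_n(v)}\!\left[\log \sigma\!\left(-{\vec{u}'_n}^{\top} \vec{u}_i\right)\right]
```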

## How to load OGB data
To load an OGB dataset, run the following command, which saves the graph to a local file (e.g. `ogbn-proteins-graph.bin` for the command below):
```
python3 load_dataset.py --name ogbn-proteins
```
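
For reference, here is a minimal sketch of what such a loading script does, assuming the standard `ogb.nodeproppred` and DGL APIs (`DglNodePropPredDataset`, `dgl.save_graphs`); the actual `load_dataset.py` may differ in details:
```
# Sketch only: fetch an OGB node-property dataset and save its graph to a
# local DGL binary file that line.py can later load via --data_file.
import argparse

import dgl
from ogb.nodeproppred import DglNodePropPredDataset

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='ogbn-proteins')
args = parser.parse_args()

data = DglNodePropPredDataset(name=args.name)
graph, _labels = data[0]  # the dataset holds a single (graph, labels) pair
dgl.save_graphs(args.name + '-graph.bin', [graph])
```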
Or you can run the code directly with:
```
python3 line.py --ogbn_name xxx --load_from_ogbn
```
However, `ogb.nodeproppred` may not be compatible with mixed multi-GPU training. If you want mixed training, use at most one GPU with the command above; the commands for multi-GPU runs are given in the Notes section at the end.

## Evaluation
For evaluation we follow the `mlp.py` script provided by OGB [here](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/).
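
Before running `mlp.py`, you can sanity-check the saved embedding. The sketch below assumes the embedding was saved with `--save_in_pt`; note that the OGB `mlp.py` scripts load a file named `embedding.pt` from their working directory, so the copied file may need to be renamed accordingly:
```
# Sketch only: load and inspect a trained embedding saved with --save_in_pt.
import torch

emb = torch.load('arxiv-embedding.pt')  # path given to --output_emb_file
print(emb.shape)  # expected (num_nodes, dim), e.g. (169343, 128) for ogbn-arxiv
```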

## Configurations used
ogbn-arxiv
```
python3 line.py --save_in_pt --dim 128 --lap_norm 0.1 --mix --gpus 0 --batch_size 1024 --output_emb_file arxiv-embedding.pt --num_samples 1000 --print_interval 1000 --negative 5 --fast_neg --load_from_ogbn --ogbn_name ogbn-arxiv
cd ogb/examples/nodeproppred/arxiv
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --use_node_embedding
```

ogbn-proteins
```
python3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 1 --batch_size 1024 --output_emb_file protein-embedding.pt --num_samples 600 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-proteins --print_loss
cd ogb/examples/nodeproppred/proteins
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --use_node_embedding
```

ogbn-products
```
python3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 0 --batch_size 4096 --output_emb_file products-embedding.pt --num_samples 3000 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-products --print_loss
cd ogb/examples/nodeproppred/products
cp embedding_pt_file_path ./
python3 mlp.py --device 0 --use_node_embedding
```

## Results
ogbn-arxiv
- #params: 33023343 (model) + 142888 (mlp) = 33166231
- Highest Train: 82.94 ± 0.11
- Highest Valid: 71.76 ± 0.08
- Final Train: 80.74 ± 1.30
- Final Test: 70.47 ± 0.19

ogbn-proteins
- #params: 25853524 (model) + 129648 (mlp) = 25983172
- Highest Train: 93.11 ± 0.04
- Highest Valid: 70.50 ± 1.29
- Final Train: 77.66 ± 10.27
- Final Test: 62.07 ± 1.25

ogbn-products
- #params: 477570049 (model) + 136495 (mlp) = 477706544
- Highest Train: 98.01 ± 0.32
- Highest Valid: 89.57 ± 0.09
- Final Train: 94.96 ± 0.43
- Final Test: 72.52 ± 0.29

## Notes
To utilize multi-GPU training, the dataset first needs to be saved as a local file with the following command:
```
python3 load_dataset.py --name dataset_name
```
where `dataset_name` can be `ogbn-arxiv`, `ogbn-proteins`, or `ogbn-products`. This generates a local file `$dataset_name$-graph.bin`. Then run:
```
python3 line.py --data_file $dataset_name$-graph.bin
```
where the other parameters are the same as in the configurations above, except that `--load_from_ogbn` and `--ogbn_name` are omitted (`--data_file` replaces them). An illustrative multi-GPU command is given below.
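
For example, a mixed multi-GPU run on ogbn-arxiv might look like the following (an illustrative command assembled from the ogbn-arxiv config above; adjust the GPU ids to your machine):
```
python3 load_dataset.py --name ogbn-arxiv
python3 line.py --data_file ogbn-arxiv-graph.bin --save_in_pt --dim 128 --lap_norm 0.1 --mix --gpus 0 1 2 3 --batch_size 1024 --output_emb_file arxiv-embedding.pt --num_samples 1000 --print_interval 1000 --negative 5 --fast_neg
```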
306 changes: 306 additions & 0 deletions examples/pytorch/ogb/line/line.py
@@ -0,0 +1,306 @@
import torch
import argparse
import dgl
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
import time
import numpy as np

from reading_data import LineDataset
from model import SkipGramModel
from utils import thread_wrapped_func, sum_up_params, check_args

class LineTrainer:
    def __init__(self, args):
        """ Initializing the trainer with the input arguments """
        self.args = args
        self.dataset = LineDataset(
            net_file=args.data_file,
            batch_size=args.batch_size,
            negative=args.negative,
            gpus=args.gpus,
            fast_neg=args.fast_neg,
            ogbl_name=args.ogbl_name,
            load_from_ogbl=args.load_from_ogbl,
            ogbn_name=args.ogbn_name,
            load_from_ogbn=args.load_from_ogbn,
            num_samples=args.num_samples * 1000000,
        )
        self.emb_size = self.dataset.G.number_of_nodes()
        self.emb_model = None

    def init_device_emb(self):
        """ Set the device before training;
        will be called once in fast_train_mp / fast_train.
        """
        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
        assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"

        # initialize the embedding on CPU
        self.emb_model = SkipGramModel(
            emb_size=self.emb_size,
            emb_dimension=self.args.dim,
            batch_size=self.args.batch_size,
            only_cpu=self.args.only_cpu,
            only_gpu=self.args.only_gpu,
            only_fst=self.args.only_fst,
            only_snd=self.args.only_snd,
            mix=self.args.mix,
            neg_weight=self.args.neg_weight,
            negative=self.args.negative,
            lr=self.args.lr,
            lap_norm=self.args.lap_norm,
            fast_neg=self.args.fast_neg,
            record_loss=self.args.print_loss,
            async_update=self.args.async_update,
            num_threads=self.args.num_threads,
        )

        torch.set_num_threads(self.args.num_threads)
        if self.args.only_gpu:
            print("Run on 1 GPU")
            assert self.args.gpus[0] >= 0
            self.emb_model.all_to_device(self.args.gpus[0])
        elif self.args.mix:
            print("Mix CPU with %d GPU(s)" % len(self.args.gpus))
            if len(self.args.gpus) == 1:
                assert self.args.gpus[0] >= 0, 'mixing CPU with GPU requires an available GPU'
                self.emb_model.set_device(self.args.gpus[0])
        else:
            print("Run on CPU")

    def train(self):
        """ train the embedding """
        if len(self.args.gpus) > 1:
            self.fast_train_mp()
        else:
            self.fast_train()

    def fast_train_mp(self):
        """ multi-cpu-core or mix cpu & multi-gpu """
        self.init_device_emb()
        self.emb_model.share_memory()

        sum_up_params(self.emb_model)

        start_all = time.time()
        ps = []

        for i in range(len(self.args.gpus)):
            p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
            ps.append(p)
            p.start()

        for p in ps:
            p.join()

        print("Used time: %.2fs" % (time.time() - start_all))
        if self.args.save_in_pt:
            self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ A subprocess for fast_train_mp. """
        if self.args.mix:
            self.emb_model.set_device(gpu_id)

        torch.set_num_threads(self.args.num_threads)
        if self.args.async_update:
            self.emb_model.create_async_update()

        sampler = self.dataset.create_sampler(rank)

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
        )
        num_batches = len(dataloader)
        print("num batches: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))

        start = time.time()
        with torch.no_grad():
            for i, edges in enumerate(dataloader):
                if self.args.fast_neg:
                    self.emb_model.fast_learn(edges)
                else:
                    # draw negative nodes from the pre-built negative table
                    bs = edges.size()[0]
                    neg_nodes = torch.LongTensor(
                        np.random.choice(self.dataset.neg_table,
                                         bs * self.args.negative,
                                         replace=True))
                    self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        if self.args.only_fst:
                            print("GPU-[%d] batch %d time: %.2fs fst-loss: %.4f"
                                  % (gpu_id, i, time.time() - start,
                                     -sum(self.emb_model.loss_fst) / self.args.print_interval))
                        elif self.args.only_snd:
                            print("GPU-[%d] batch %d time: %.2fs snd-loss: %.4f"
                                  % (gpu_id, i, time.time() - start,
                                     -sum(self.emb_model.loss_snd) / self.args.print_interval))
                        else:
                            print("GPU-[%d] batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f"
                                  % (gpu_id, i, time.time() - start,
                                     -sum(self.emb_model.loss_fst) / self.args.print_interval,
                                     -sum(self.emb_model.loss_snd) / self.args.print_interval))
                        self.emb_model.loss_fst = []
                        self.emb_model.loss_snd = []
                    else:
                        print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time() - start))
                    start = time.time()

        if self.args.async_update:
            self.emb_model.finish_async_update()

    def fast_train(self):
        """ Fast train with dataloader, using only GPU or only CPU. """
        self.init_device_emb()

        if self.args.async_update:
            self.emb_model.share_memory()
            self.emb_model.create_async_update()

        sum_up_params(self.emb_model)

        sampler = self.dataset.create_sampler(0)

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
        )

        num_batches = len(dataloader)
        print("num batches: %d\n" % num_batches)

        start_all = time.time()
        start = time.time()
        with torch.no_grad():
            for i, edges in enumerate(dataloader):
                if self.args.fast_neg:
                    self.emb_model.fast_learn(edges)
                else:
                    # draw negative nodes from the pre-built negative table
                    bs = edges.size()[0]
                    neg_nodes = torch.LongTensor(
                        np.random.choice(self.dataset.neg_table,
                                         bs * self.args.negative,
                                         replace=True))
                    self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        if self.args.only_fst:
                            print("Batch %d time: %.2fs fst-loss: %.4f"
                                  % (i, time.time() - start,
                                     -sum(self.emb_model.loss_fst) / self.args.print_interval))
                        elif self.args.only_snd:
                            print("Batch %d time: %.2fs snd-loss: %.4f"
                                  % (i, time.time() - start,
                                     -sum(self.emb_model.loss_snd) / self.args.print_interval))
                        else:
                            print("Batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f"
                                  % (i, time.time() - start,
                                     -sum(self.emb_model.loss_fst) / self.args.print_interval,
                                     -sum(self.emb_model.loss_snd) / self.args.print_interval))
                        self.emb_model.loss_fst = []
                        self.emb_model.loss_snd = []
                    else:
                        print("Batch %d, training time: %.2fs" % (i, time.time() - start))
                    start = time.time()

        if self.args.async_update:
            self.emb_model.finish_async_update()

        print("Training used time: %.2fs" % (time.time() - start_all))
        if self.args.save_in_pt:
            self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Implementation of LINE.")
    # input files
    ## personal datasets
    parser.add_argument('--data_file', type=str,
                        help="path of the local dgl graph file")
    ## ogb datasets
    parser.add_argument('--ogbl_name', type=str,
                        help="name of an ogbl dataset, e.g. ogbl-ddi")
    parser.add_argument('--load_from_ogbl', default=False, action="store_true",
                        help="whether to load the dataset from ogbl")
    parser.add_argument('--ogbn_name', type=str,
                        help="name of an ogbn dataset, e.g. ogbn-proteins")
    parser.add_argument('--load_from_ogbn', default=False, action="store_true",
                        help="whether to load the dataset from ogbn")

    # output files
    parser.add_argument('--save_in_pt', default=False, action="store_true",
                        help="whether to save the embedding in .pt format instead of .npy")
    parser.add_argument('--output_emb_file', type=str, default="emb.npy",
                        help="path of the output embedding file")

    # model parameters
    parser.add_argument('--dim', default=128, type=int,
                        help="embedding dimension")
    parser.add_argument('--num_samples', default=1, type=int,
                        help="number of samples during training (in millions)")
    parser.add_argument('--negative', default=1, type=int,
                        help="number of negative samples for each positive node pair")
    parser.add_argument('--batch_size', default=128, type=int,
                        help="number of edges in each batch")
    parser.add_argument('--neg_weight', default=1., type=float,
                        help="weight of negative samples in the loss")
    parser.add_argument('--lap_norm', default=0.01, type=float,
                        help="weight of laplacian normalization")

    # training parameters
    parser.add_argument('--only_fst', default=False, action="store_true",
                        help="only train the first-order proximity embedding")
    parser.add_argument('--only_snd', default=False, action="store_true",
                        help="only train the second-order proximity embedding")
    parser.add_argument('--print_interval', default=100, type=int,
                        help="number of batches between prints")
    parser.add_argument('--print_loss', default=False, action="store_true",
                        help="whether to print the loss during training")
    parser.add_argument('--lr', default=0.2, type=float,
                        help="learning rate")

    # optimization settings
    parser.add_argument('--mix', default=False, action="store_true",
                        help="mixed training with CPU and GPU")
    parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
                        help="a list of active gpu ids, e.g. 0, used with --mix")
    parser.add_argument('--only_cpu', default=False, action="store_true",
                        help="training with CPU only")
    parser.add_argument('--only_gpu', default=False, action="store_true",
                        help="training with a single GPU (all parameters are moved to the GPU)")
    parser.add_argument('--async_update', default=False, action="store_true",
                        help="asynchronous embedding update for mixed training; not recommended")

    parser.add_argument('--fast_neg', default=False, action="store_true",
                        help="do negative sampling inside a batch")
    parser.add_argument('--num_threads', default=2, type=int,
                        help="number of threads used per CPU core / GPU")
    parser.add_argument('--num_sampler_threads', default=2, type=int,
                        help="number of threads used for sampling")

    args = parser.parse_args()

    if args.async_update:
        assert args.mix, "--async_update can only be used together with --mix"

    start_time = time.time()
    trainer = LineTrainer(args)
    trainer.train()
    print("Total used time: %.2fs" % (time.time() - start_time))