diff --git a/examples/pytorch/ogb/line/README.md b/examples/pytorch/ogb/line/README.md
new file mode 100644
index 000000000000..091cad170bbe
--- /dev/null
+++ b/examples/pytorch/ogb/line/README.md
@@ -0,0 +1,77 @@
+# LINE Example
+- Paper link: [here](https://arxiv.org/pdf/1503.03578)
+- Official implementation: [here](https://github.com/tangjianpku/LINE)
+
+This implementation includes both LINE-1st and LINE-2nd. Detailed usage is described by the arguments of line.py.
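+
+For example, training only the first-order embedding on one GPU looks roughly like this (an illustrative command assembled from line.py's arguments; the sections below explain how to prepare the data and list the exact configs we used):
+```
+python3 line.py --load_from_ogbn --ogbn_name ogbn-arxiv --only_fst --dim 128 --mix --gpus 0
+```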
+
+## How to load ogb data
+To load an ogb dataset and save it as a local graph file (needed for multi-GPU training, see the Notes below), run the following command, which outputs a file such as ogbn-proteins-graph.bin:
+```
+python3 load_dataset.py --name ogbn-proteins
+```
+Or you can skip this step and let line.py load the dataset directly:
+```
+python3 line.py --ogbn_name xxx --load_from_ogbn
+```
+However, ogb.nodeproppred may not be compatible with mixed multi-GPU training. If you want to do mixed training with the command above, please use at most one GPU. The commands for multi-GPU training are given in the Notes at the end of this README.
+
+## Evaluation
+For evaluation we follow mlp.py provided by ogb [here](https://github.com/snap-stanford/ogb/blob/master/examples/nodeproppred/).
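+
+As a rough sketch of how the saved embedding is consumed (this is an assumption about mlp.py, not a copy of it; file names such as embedding.pt are placeholders, check the script you actually run):
+```
+# hypothetical sketch; the authoritative code is the ogb mlp.py linked above
+import torch
+from ogb.nodeproppred import DglNodePropPredDataset
+
+graph, labels = DglNodePropPredDataset("ogbn-arxiv")[0]
+x = graph.ndata["feat"]                                      # raw node features
+embedding = torch.load("embedding.pt", map_location="cpu")   # produced with --save_in_pt
+x = torch.cat([x, embedding], dim=-1)                        # input to the downstream MLP
+```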
+
+## Used configs
+ogbn-arxiv
+```
+python3 line.py --save_in_pt --dim 128 --lap_norm 0.1 --mix --gpus 0 --batch_size 1024 --output_emb_file arxiv-embedding.pt --num_samples 1000 --print_interval 1000 --negative 5 --fast_neg --load_from_ogbn --ogbn_name ogbn-arxiv
+cd path_to_ogb_repo/examples/nodeproppred/arxiv
+cp embedding_pt_file_path ./
+python3 mlp.py --device 0 --use_node_embedding
+```
+
+ogbn-proteins
+```
+python3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 1 --batch_size 1024 --output_emb_file protein-embedding.pt --num_samples 600 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-proteins --print_loss
+cd path_to_ogb_repo/examples/nodeproppred/proteins
+cp embedding_pt_file_path ./
+python3 mlp.py --device 0 --use_node_embedding
+```
+
+ogbn-products
+```
+python3 line.py --save_in_pt --dim 128 --lap_norm 0.01 --mix --gpus 0 --batch_size 4096 --output_emb_file products-embedding.pt --num_samples 3000 --print_interval 1000 --negative 1 --fast_neg --load_from_ogbn --ogbn_name ogbn-products --print_loss
+cd path_to_ogb_repo/examples/nodeproppred/products
+cp embedding_pt_file_path ./
+python3 mlp.py --device 0 --use_node_embedding
+```
+
+## Results
+ogbn-arxiv
+- #params: 33023343 (model) + 142888 (mlp) = 33166231
+- Highest Train: 82.94 ± 0.11
+- Highest Valid: 71.76 ± 0.08
+- Final Train: 80.74 ± 1.30
+- Final Test: 70.47 ± 0.19
+
+ogbn-proteins
+- #params: 25853524 (model) + 129648 (mlp) = 25983172
+- Highest Train: 93.11 ± 0.04
+- Highest Valid: 70.50 ± 1.29
+- Final Train: 77.66 ± 10.27
+- Final Test: 62.07 ± 1.25
+
+ogbn-products
+- #params: 477570049 (model) + 136495 (mlp) = 477706544
+- Highest Train: 98.01 ± 0.32
+- Highest Valid: 89.57 ± 0.09
+- Final Train: 94.96 ± 0.43
+- Final Test: 72.52 ± 0.29
+
+## Notes
+To utilize multi-GPU training, we need to save the dataset as a local file before training with the following command:
+```
+python3 load_dataset.py --name dataset_name
+```
+where `dataset_name` can be `ogbn-arxiv`, `ogbn-proteins`, or `ogbn-products`. After that, a local file `$dataset_name$-graph.bin` will be generated. Then run:
+```
+python3 line.py --data_file $dataset_name$-graph.bin
+```
+where the other parameters are kept the same as in the configs above, except that `--load_from_ogbn` and `--ogbn_name` are dropped.
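+
+For example (an illustrative multi-GPU command assembled from the ogbn-arxiv config above, assuming four GPUs; adjust the gpu ids and other arguments to your setup):
+```
+python3 load_dataset.py --name ogbn-arxiv
+python3 line.py --save_in_pt --dim 128 --lap_norm 0.1 --mix --gpus 0 1 2 3 --batch_size 1024 --output_emb_file arxiv-embedding.pt --num_samples 1000 --print_interval 1000 --negative 5 --fast_neg --data_file ogbn-arxiv-graph.bin
+```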
\ No newline at end of file
diff --git a/examples/pytorch/ogb/line/line.py b/examples/pytorch/ogb/line/line.py
new file mode 100644
index 000000000000..ef36199531d1
--- /dev/null
+++ b/examples/pytorch/ogb/line/line.py
@@ -0,0 +1,306 @@
+import torch
+import argparse
+import dgl
+import torch.multiprocessing as mp
+from torch.utils.data import DataLoader
+import os
+import random
+import time
+import numpy as np
+
+from reading_data import LineDataset
+from model import SkipGramModel
+from utils import thread_wrapped_func, sum_up_params, check_args
+
+class LineTrainer:
+ def __init__(self, args):
+ """ Initializing the trainer with the input arguments """
+ self.args = args
+ self.dataset = LineDataset(
+ net_file=args.data_file,
+ batch_size=args.batch_size,
+ negative=args.negative,
+ gpus=args.gpus,
+ fast_neg=args.fast_neg,
+ ogbl_name=args.ogbl_name,
+ load_from_ogbl=args.load_from_ogbl,
+ ogbn_name=args.ogbn_name,
+ load_from_ogbn=args.load_from_ogbn,
+ num_samples=args.num_samples * 1000000,
+ )
+ self.emb_size = self.dataset.G.number_of_nodes()
+ self.emb_model = None
+
+ def init_device_emb(self):
+ """ set the device before training
+ will be called once in fast_train_mp / fast_train
+ """
+ choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
+ assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
+
+ # initializing embedding on CPU
+ self.emb_model = SkipGramModel(
+ emb_size=self.emb_size,
+ emb_dimension=self.args.dim,
+ batch_size=self.args.batch_size,
+ only_cpu=self.args.only_cpu,
+ only_gpu=self.args.only_gpu,
+ only_fst=self.args.only_fst,
+ only_snd=self.args.only_snd,
+ mix=self.args.mix,
+ neg_weight=self.args.neg_weight,
+ negative=self.args.negative,
+ lr=self.args.lr,
+ lap_norm=self.args.lap_norm,
+ fast_neg=self.args.fast_neg,
+ record_loss=self.args.print_loss,
+ async_update=self.args.async_update,
+ num_threads=self.args.num_threads,
+ )
+
+ torch.set_num_threads(self.args.num_threads)
+ if self.args.only_gpu:
+ print("Run in 1 GPU")
+ assert self.args.gpus[0] >= 0
+ self.emb_model.all_to_device(self.args.gpus[0])
+ elif self.args.mix:
+ print("Mix CPU with %d GPU" % len(self.args.gpus))
+ if len(self.args.gpus) == 1:
+                assert self.args.gpus[0] >= 0, 'mix CPU with GPU requires an available GPU'
+ self.emb_model.set_device(self.args.gpus[0])
+ else:
+ print("Run in CPU process")
+
+ def train(self):
+ """ train the embedding """
+ if len(self.args.gpus) > 1:
+ self.fast_train_mp()
+ else:
+ self.fast_train()
+
+ def fast_train_mp(self):
+ """ multi-cpu-core or mix cpu & multi-gpu """
+ self.init_device_emb()
+ self.emb_model.share_memory()
+
+ sum_up_params(self.emb_model)
+
+ start_all = time.time()
+ ps = []
+
+ for i in range(len(self.args.gpus)):
+ p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
+ ps.append(p)
+ p.start()
+
+ for p in ps:
+ p.join()
+
+ print("Used time: %.2fs" % (time.time()-start_all))
+ if self.args.save_in_pt:
+ self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
+ else:
+ self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
+
+ @thread_wrapped_func
+ def fast_train_sp(self, rank, gpu_id):
+ """ a subprocess for fast_train_mp """
+ if self.args.mix:
+ self.emb_model.set_device(gpu_id)
+
+ torch.set_num_threads(self.args.num_threads)
+ if self.args.async_update:
+ self.emb_model.create_async_update()
+
+ sampler = self.dataset.create_sampler(rank)
+
+ dataloader = DataLoader(
+ dataset=sampler.seeds,
+ batch_size=self.args.batch_size,
+ collate_fn=sampler.sample,
+ shuffle=False,
+ drop_last=False,
+ num_workers=self.args.num_sampler_threads,
+ )
+ num_batches = len(dataloader)
+ print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
+
+ start = time.time()
+ with torch.no_grad():
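+            # gradients are computed manually inside fast_learn, so autograd is disabled for the whole loop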
+ for i, edges in enumerate(dataloader):
+ if self.args.fast_neg:
+ self.emb_model.fast_learn(edges)
+ else:
+ # do negative sampling
+ bs = edges.size()[0]
+ neg_nodes = torch.LongTensor(
+ np.random.choice(self.dataset.neg_table,
+ bs * self.args.negative,
+ replace=True))
+ self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)
+
+ if i > 0 and i % self.args.print_interval == 0:
+ if self.args.print_loss:
+ if self.args.only_fst:
+ print("GPU-[%d] batch %d time: %.2fs fst-loss: %.4f" \
+ % (gpu_id, i, time.time()-start, -sum(self.emb_model.loss_fst)/self.args.print_interval))
+ elif self.args.only_snd:
+ print("GPU-[%d] batch %d time: %.2fs snd-loss: %.4f" \
+ % (gpu_id, i, time.time()-start, -sum(self.emb_model.loss_snd)/self.args.print_interval))
+ else:
+ print("GPU-[%d] batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f" \
+ % (gpu_id, i, time.time()-start, \
+ -sum(self.emb_model.loss_fst)/self.args.print_interval, \
+ -sum(self.emb_model.loss_snd)/self.args.print_interval))
+ self.emb_model.loss_fst = []
+ self.emb_model.loss_snd = []
+ else:
+ print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
+ start = time.time()
+
+ if self.args.async_update:
+ self.emb_model.finish_async_update()
+
+ def fast_train(self):
+ """ fast train with dataloader with only gpu / only cpu"""
+ self.init_device_emb()
+
+ if self.args.async_update:
+ self.emb_model.share_memory()
+ self.emb_model.create_async_update()
+
+ sum_up_params(self.emb_model)
+
+ sampler = self.dataset.create_sampler(0)
+
+ dataloader = DataLoader(
+ dataset=sampler.seeds,
+ batch_size=self.args.batch_size,
+ collate_fn=sampler.sample,
+ shuffle=False,
+ drop_last=False,
+ num_workers=self.args.num_sampler_threads,
+ )
+
+ num_batches = len(dataloader)
+ print("num batchs: %d\n" % num_batches)
+
+ start_all = time.time()
+ start = time.time()
+ with torch.no_grad():
+ for i, edges in enumerate(dataloader):
+ if self.args.fast_neg:
+ self.emb_model.fast_learn(edges)
+ else:
+ # do negative sampling
+ bs = edges.size()[0]
+ neg_nodes = torch.LongTensor(
+ np.random.choice(self.dataset.neg_table,
+ bs * self.args.negative,
+ replace=True))
+ self.emb_model.fast_learn(edges, neg_nodes=neg_nodes)
+
+ if i > 0 and i % self.args.print_interval == 0:
+ if self.args.print_loss:
+ if self.args.only_fst:
+ print("Batch %d time: %.2fs fst-loss: %.4f" \
+ % (i, time.time()-start, -sum(self.emb_model.loss_fst)/self.args.print_interval))
+ elif self.args.only_snd:
+ print("Batch %d time: %.2fs snd-loss: %.4f" \
+ % (i, time.time()-start, -sum(self.emb_model.loss_snd)/self.args.print_interval))
+ else:
+ print("Batch %d time: %.2fs fst-loss: %.4f snd-loss: %.4f" \
+ % (i, time.time()-start, \
+ -sum(self.emb_model.loss_fst)/self.args.print_interval, \
+ -sum(self.emb_model.loss_snd)/self.args.print_interval))
+ self.emb_model.loss_fst = []
+ self.emb_model.loss_snd = []
+ else:
+ print("Batch %d, training time: %.2fs" % (i, time.time()-start))
+ start = time.time()
+
+ if self.args.async_update:
+ self.emb_model.finish_async_update()
+
+ print("Training used time: %.2fs" % (time.time()-start_all))
+ if self.args.save_in_pt:
+ self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
+ else:
+ self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="Implementation of LINE.")
+ # input files
+ ## personal datasets
+ parser.add_argument('--data_file', type=str,
+ help="path of dgl graphs")
+ ## ogbl datasets
+ parser.add_argument('--ogbl_name', type=str,
+ help="name of ogbl dataset, e.g. ogbl-ddi")
+ parser.add_argument('--load_from_ogbl', default=False, action="store_true",
+ help="whether load dataset from ogbl")
+ parser.add_argument('--ogbn_name', type=str,
+ help="name of ogbn dataset, e.g. ogbn-proteins")
+ parser.add_argument('--load_from_ogbn', default=False, action="store_true",
+ help="whether load dataset from ogbn")
+
+ # output files
+ parser.add_argument('--save_in_pt', default=False, action="store_true",
+                        help='whether to save the embedding in .pt format rather than .npy')
+ parser.add_argument('--output_emb_file', type=str, default="emb.npy",
+ help='path of the output npy embedding file')
+
+ # model parameters
+ parser.add_argument('--dim', default=128, type=int,
+ help="embedding dimensions")
+ parser.add_argument('--num_samples', default=1, type=int,
+ help="number of samples during training (million)")
+ parser.add_argument('--negative', default=1, type=int,
+                        help="negative samples for each positive node pair")
+ parser.add_argument('--batch_size', default=128, type=int,
+ help="number of edges in each batch")
+ parser.add_argument('--neg_weight', default=1., type=float,
+ help="negative weight")
+ parser.add_argument('--lap_norm', default=0.01, type=float,
+ help="weight of laplacian normalization")
+
+ # training parameters
+ parser.add_argument('--only_fst', default=False, action="store_true",
+ help="only do first-order proximity embedding")
+ parser.add_argument('--only_snd', default=False, action="store_true",
+ help="only do second-order proximity embedding")
+ parser.add_argument('--print_interval', default=100, type=int,
+ help="number of batches between printing")
+ parser.add_argument('--print_loss', default=False, action="store_true",
+ help="whether print loss during training")
+ parser.add_argument('--lr', default=0.2, type=float,
+ help="learning rate")
+
+ # optimization settings
+ parser.add_argument('--mix', default=False, action="store_true",
+ help="mixed training with CPU and GPU")
+ parser.add_argument('--gpus', type=int, default=[-1], nargs='+',
+ help='a list of active gpu ids, e.g. 0, used with --mix')
+ parser.add_argument('--only_cpu', default=False, action="store_true",
+ help="training with CPU")
+ parser.add_argument('--only_gpu', default=False, action="store_true",
+ help="training with a single GPU (all of the parameters are moved on the GPU)")
+ parser.add_argument('--async_update', default=False, action="store_true",
+                        help="update the embeddings asynchronously during mixed training; not recommended")
+
+ parser.add_argument('--fast_neg', default=False, action="store_true",
+ help="do negative sampling inside a batch")
+ parser.add_argument('--num_threads', default=2, type=int,
+ help="number of threads used for each CPU-core/GPU")
+ parser.add_argument('--num_sampler_threads', default=2, type=int,
+ help="number of threads used for sampling")
+
+ args = parser.parse_args()
+
+ if args.async_update:
+ assert args.mix, "--async_update only with --mix"
+
+ start_time = time.time()
+ trainer = LineTrainer(args)
+ trainer.train()
+ print("Total used time: %.2f" % (time.time() - start_time))
diff --git a/examples/pytorch/ogb/line/load_dataset.py b/examples/pytorch/ogb/line/load_dataset.py
new file mode 100644
index 000000000000..77bbd8382193
--- /dev/null
+++ b/examples/pytorch/ogb/line/load_dataset.py
@@ -0,0 +1,36 @@
+""" load dataset from ogb """
+
+import argparse
+from ogb.linkproppred import DglLinkPropPredDataset
+from ogb.nodeproppred import DglNodePropPredDataset
+import dgl
+
+def load_from_ogbl_with_name(name):
+ choices = ['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation']
+ assert name in choices, "name must be selected from " + str(choices)
+ dataset = DglLinkPropPredDataset(name)
+ return dataset[0]
+
+def load_from_ogbn_with_name(name):
+ choices = ['ogbn-products', 'ogbn-proteins', 'ogbn-arxiv', 'ogbn-papers100M']
+ assert name in choices, "name must be selected from " + str(choices)
+    graph, label = DglNodePropPredDataset(name)[0]
+    return graph
+
+if __name__ == "__main__":
+ """ load datasets as net.txt format """
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--name', type=str,
+ choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation',
+ 'ogbn-products', 'ogbn-proteins', 'ogbn-arxiv', 'ogbn-papers100M'],
+ default='ogbl-collab',
+ help="name of datasets by ogb")
+ args = parser.parse_args()
+
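+    # e.g. running "python3 load_dataset.py --name ogbn-arxiv" downloads the dataset (if needed) and writes ogbn-arxiv-graph.bin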
+ name = args.name
+ if name.startswith("ogbl"):
+ g = load_from_ogbl_with_name(name=name)
+ else:
+ g = load_from_ogbn_with_name(name=name)
+
+ dgl.save_graphs(name + "-graph.bin", g)
\ No newline at end of file
diff --git a/examples/pytorch/ogb/line/model.py b/examples/pytorch/ogb/line/model.py
new file mode 100644
index 000000000000..bb074eb22ff0
--- /dev/null
+++ b/examples/pytorch/ogb/line/model.py
@@ -0,0 +1,469 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import random
+import numpy as np
+import torch.multiprocessing as mp
+from torch.multiprocessing import Queue
+
+from utils import thread_wrapped_func
+
+def init_emb2neg_index(negative, batch_size):
+ '''select embedding of negative nodes from a batch of node embeddings
+ for fast negative sampling
+
+ Return
+ ------
+ index_emb_negu torch.LongTensor : the indices of u_embeddings
+ index_emb_negv torch.LongTensor : the indices of v_embeddings
+
+ Usage
+ -----
+ # emb_u.shape: [batch_size, dim]
+ batch_emb2negu = torch.index_select(emb_u, 0, index_emb_negu)
+ '''
+ idx_list_u = list(range(batch_size)) * negative
+ idx_list_v = list(range(batch_size)) * negative
+ random.shuffle(idx_list_v)
+
+ index_emb_negu = torch.LongTensor(idx_list_u)
+ index_emb_negv = torch.LongTensor(idx_list_v)
+
+ return index_emb_negu, index_emb_negv
+
+def adam(grad, state_sum, nodes, lr, device, only_gpu):
+ """ calculate gradients according to adam """
+ grad_sum = (grad * grad).mean(1)
+ if not only_gpu:
+ grad_sum = grad_sum.cpu()
+ state_sum.index_add_(0, nodes, grad_sum) # cpu
+ std = state_sum[nodes].to(device) # gpu
+ std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
+ grad = (lr * grad / std_values) # gpu
+
+ return grad
+
+@thread_wrapped_func
+def async_update(num_threads, model, queue):
+ """ Asynchronous embedding update for entity embeddings.
+ """
+ torch.set_num_threads(num_threads)
+ print("async start")
+ while True:
+ (grad_u, grad_v, grad_v_neg, nodes, neg_nodes, first_flag) = queue.get()
+ if grad_u is None:
+ return
+ with torch.no_grad():
+ if first_flag:
+ model.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u)
+ model.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v)
+ if neg_nodes is not None:
+ model.fst_u_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg)
+ else:
+ model.snd_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u)
+ model.snd_v_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v)
+ if neg_nodes is not None:
+ model.snd_v_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg)
+
+class SkipGramModel(nn.Module):
+ """ Negative sampling based skip-gram """
+ def __init__(self,
+ emb_size,
+ emb_dimension,
+ batch_size,
+ only_cpu,
+ only_gpu,
+ only_fst,
+ only_snd,
+ mix,
+ neg_weight,
+ negative,
+ lr,
+ lap_norm,
+ fast_neg,
+ record_loss,
+ async_update,
+ num_threads,
+ ):
+ """ initialize embedding on CPU
+
+        Parameters
+ ----------
+ emb_size int : number of nodes
+ emb_dimension int : embedding dimension
+ batch_size int : number of node sequences in each batch
+ only_cpu bool : training with CPU
+ only_gpu bool : training with GPU
+ only_fst bool : only embedding for first-order proximity
+ only_snd bool : only embedding for second-order proximity
+ mix bool : mixed training with CPU and GPU
+        negative int : negative samples for each positive node pair
+ neg_weight float : negative weight
+ lr float : initial learning rate
+ lap_norm float : weight of laplacian normalization
+ fast_neg bool : do negative sampling inside a batch
+ record_loss bool : print the loss during training
+        async_update bool : whether to update the embeddings asynchronously
+        num_threads int : number of threads used per CPU core / GPU
+ """
+ super(SkipGramModel, self).__init__()
+ self.emb_size = emb_size
+ self.batch_size = batch_size
+ self.only_cpu = only_cpu
+ self.only_gpu = only_gpu
+ if only_fst:
+ self.fst = True
+ self.snd = False
+ self.emb_dimension = emb_dimension
+ elif only_snd:
+ self.fst = False
+ self.snd = True
+ self.emb_dimension = emb_dimension
+ else:
+ self.fst = True
+ self.snd = True
+ self.emb_dimension = int(emb_dimension / 2)
+ self.mixed_train = mix
+ self.neg_weight = neg_weight
+ self.negative = negative
+ self.lr = lr
+ self.lap_norm = lap_norm
+ self.fast_neg = fast_neg
+ self.record_loss = record_loss
+ self.async_update = async_update
+ self.num_threads = num_threads
+
+ # initialize the device as cpu
+ self.device = torch.device("cpu")
+
+ # embedding
+ initrange = 1.0 / self.emb_dimension
+ if self.fst:
+ self.fst_u_embeddings = nn.Embedding(
+ self.emb_size, self.emb_dimension, sparse=True)
+ init.uniform_(self.fst_u_embeddings.weight.data, -initrange, initrange)
+ if self.snd:
+ self.snd_u_embeddings = nn.Embedding(
+ self.emb_size, self.emb_dimension, sparse=True)
+ init.uniform_(self.snd_u_embeddings.weight.data, -initrange, initrange)
+ self.snd_v_embeddings = nn.Embedding(
+ self.emb_size, self.emb_dimension, sparse=True)
+ init.constant_(self.snd_v_embeddings.weight.data, 0)
+
+ # lookup_table is used for fast sigmoid computing
+ self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
+ self.lookup_table[0] = 0.
+ self.lookup_table[-1] = 1.
+ if self.record_loss:
+ self.logsigmoid_table = torch.log(torch.sigmoid(torch.arange(-6.01, 6.01, 0.01)))
+ self.loss_fst = []
+ self.loss_snd = []
+
+ # indexes to select positive/negative node pairs from batch_walks
+ self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(self.negative, self.batch_size)
+
+ # adam
+ if self.fst:
+ self.fst_state_sum_u = torch.zeros(self.emb_size)
+ if self.snd:
+ self.snd_state_sum_u = torch.zeros(self.emb_size)
+ self.snd_state_sum_v = torch.zeros(self.emb_size)
+
+ def create_async_update(self):
+ """ Set up the async update subprocess.
+ """
+ self.async_q = Queue(1)
+ self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q))
+ self.async_p.start()
+
+ def finish_async_update(self):
+ """ Notify the async update subprocess to quit.
+ """
+        self.async_q.put((None, None, None, None, None, None))
+ self.async_p.join()
+
+ def share_memory(self):
+ """ share the parameters across subprocesses """
+ if self.fst:
+ self.fst_u_embeddings.weight.share_memory_()
+ self.fst_state_sum_u.share_memory_()
+ if self.snd:
+ self.snd_u_embeddings.weight.share_memory_()
+ self.snd_v_embeddings.weight.share_memory_()
+ self.snd_state_sum_u.share_memory_()
+ self.snd_state_sum_v.share_memory_()
+
+ def set_device(self, gpu_id):
+ """ set gpu device """
+ self.device = torch.device("cuda:%d" % gpu_id)
+ print("The device is", self.device)
+ self.lookup_table = self.lookup_table.to(self.device)
+ if self.record_loss:
+ self.logsigmoid_table = self.logsigmoid_table.to(self.device)
+ self.index_emb_negu = self.index_emb_negu.to(self.device)
+ self.index_emb_negv = self.index_emb_negv.to(self.device)
+
+ def all_to_device(self, gpu_id):
+ """ move all of the parameters to a single GPU """
+ self.device = torch.device("cuda:%d" % gpu_id)
+ self.set_device(gpu_id)
+ if self.fst:
+ self.fst_u_embeddings = self.fst_u_embeddings.cuda(gpu_id)
+ self.fst_state_sum_u = self.fst_state_sum_u.to(self.device)
+ if self.snd:
+ self.snd_u_embeddings = self.snd_u_embeddings.cuda(gpu_id)
+ self.snd_v_embeddings = self.snd_v_embeddings.cuda(gpu_id)
+ self.snd_state_sum_u = self.snd_state_sum_u.to(self.device)
+ self.snd_state_sum_v = self.snd_state_sum_v.to(self.device)
+
+ def fast_sigmoid(self, score):
+ """ do fast sigmoid by looking up in a pre-defined table """
+ idx = torch.floor((score + 6.01) / 0.01).long()
+ return self.lookup_table[idx]
+
+ def fast_logsigmoid(self, score):
+ """ do fast logsigmoid by looking up in a pre-defined table """
+ idx = torch.floor((score + 6.01) / 0.01).long()
+ return self.logsigmoid_table[idx]
+
+ def fast_pos_bp(self, emb_pos_u, emb_pos_v, first_flag):
+ """ get grad for positve samples """
+ pos_score = torch.sum(torch.mul(emb_pos_u, emb_pos_v), dim=1)
+ pos_score = torch.clamp(pos_score, max=6, min=-6)
+ # [batch_size, 1]
+ score = (1 - self.fast_sigmoid(pos_score)).unsqueeze(1)
+ if self.record_loss:
+ if first_flag:
+ self.loss_fst.append(torch.mean(self.fast_logsigmoid(pos_score)).item())
+ else:
+ self.loss_snd.append(torch.mean(self.fast_logsigmoid(pos_score)).item())
+
+ # [batch_size, dim]
+ if self.lap_norm > 0:
+ grad_u_pos = score * emb_pos_v + self.lap_norm * (emb_pos_v - emb_pos_u)
+ grad_v_pos = score * emb_pos_u + self.lap_norm * (emb_pos_u - emb_pos_v)
+ else:
+ grad_u_pos = score * emb_pos_v
+ grad_v_pos = score * emb_pos_u
+
+ return grad_u_pos, grad_v_pos
+
+ def fast_neg_bp(self, emb_neg_u, emb_neg_v, first_flag):
+ """ get grad for negative samples """
+ neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1)
+ neg_score = torch.clamp(neg_score, max=6, min=-6)
+ # [batch_size * negative, 1]
+ score = - self.fast_sigmoid(neg_score).unsqueeze(1)
+ if self.record_loss:
+ if first_flag:
+ self.loss_fst.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item())
+ else:
+ self.loss_snd.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item())
+
+ grad_u_neg = self.neg_weight * score * emb_neg_v
+ grad_v_neg = self.neg_weight * score * emb_neg_u
+
+ return grad_u_neg, grad_v_neg
+
+ def fast_learn(self, batch_edges, neg_nodes=None):
+ """ Learn a batch of edges in a fast way. It has the following features:
+        1. It calculates the gradients directly without a forward pass.
+        2. It computes sigmoid via a pre-built lookup table.
+
+        Specifically, for each positive/negative node pair (i,j), the update procedure is as follows:
+        score = self.fast_sigmoid(u_embedding[i].dot(v_embedding[j]))
+        # label = 1 for positive samples; label = 0 for negative samples.
+        u_embedding[i] += (label - score) * v_embedding[j]
+        v_embedding[j] += (label - score) * u_embedding[i]
+
+        Parameters
+        ----------
+        batch_edges torch.LongTensor : a [batch_size, 2] tensor of node pairs (edges)
+        neg_nodes torch.LongTensor : a long tensor of sampled true negative nodes. If neg_nodes is None,
+            then negative nodes are sampled from the nodes in the batch as an alternative.
+
+ Usage example
+ -------------
+        batch_edges = torch.LongTensor([[1,2], [3,4], [5,6]])
+ neg_nodes = None
+ """
+ lr = self.lr
+
+ # [batch_size, 2]
+ nodes = batch_edges
+ if self.only_gpu:
+ nodes = nodes.to(self.device)
+ if neg_nodes is not None:
+ neg_nodes = neg_nodes.to(self.device)
+ bs = len(nodes)
+
+ if self.fst:
+ emb_u = self.fst_u_embeddings(nodes[:, 0]).view(-1, self.emb_dimension).to(self.device)
+ emb_v = self.fst_u_embeddings(nodes[:, 1]).view(-1, self.emb_dimension).to(self.device)
+
+            ## Positive
+ emb_pos_u, emb_pos_v = emb_u, emb_v
+ grad_u_pos, grad_v_pos = self.fast_pos_bp(emb_pos_u, emb_pos_v, True)
+
+ ## Negative
+ emb_neg_u = emb_pos_u.repeat((self.negative, 1))
+
+ if bs < self.batch_size:
+ index_emb_negu, index_emb_negv = init_emb2neg_index(self.negative, bs)
+ index_emb_negu = index_emb_negu.to(self.device)
+ index_emb_negv = index_emb_negv.to(self.device)
+ else:
+ index_emb_negu = self.index_emb_negu
+ index_emb_negv = self.index_emb_negv
+
+ if neg_nodes is None:
+ emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)
+ else:
+ emb_neg_v = self.fst_u_embeddings.weight[neg_nodes].to(self.device)
+
+ grad_u_neg, grad_v_neg = self.fast_neg_bp(emb_neg_u, emb_neg_v, True)
+
+ ## Update
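+            # fold the [bs * negative, dim] negative gradients back into the [bs, dim]
+            # positive gradient rows of the corresponding batch edges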
+ grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)
+ grad_u = grad_u_pos
+ if neg_nodes is None:
+ grad_v_pos.index_add_(0, index_emb_negv, grad_v_neg)
+ grad_v = grad_v_pos
+ else:
+ grad_v = grad_v_pos
+
+ # use adam optimizer
+ grad_u = adam(grad_u, self.fst_state_sum_u, nodes[:, 0], lr, self.device, self.only_gpu)
+ grad_v = adam(grad_v, self.fst_state_sum_u, nodes[:, 1], lr, self.device, self.only_gpu)
+ if neg_nodes is not None:
+ grad_v_neg = adam(grad_v_neg, self.fst_state_sum_u, neg_nodes, lr, self.device, self.only_gpu)
+
+ if self.mixed_train:
+ grad_u = grad_u.cpu()
+ grad_v = grad_v.cpu()
+ if neg_nodes is not None:
+ grad_v_neg = grad_v_neg.cpu()
+ else:
+ grad_v_neg = None
+
+ if self.async_update:
+ grad_u.share_memory_()
+ grad_v.share_memory_()
+ nodes.share_memory_()
+ if neg_nodes is not None:
+ neg_nodes.share_memory_()
+ grad_v_neg.share_memory_()
+ self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes, True))
+
+ if not self.async_update:
+ self.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u)
+ self.fst_u_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v)
+ if neg_nodes is not None:
+ self.fst_u_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg)
+
+ if self.snd:
+ emb_u = self.snd_u_embeddings(nodes[:, 0]).view(-1, self.emb_dimension).to(self.device)
+ emb_v = self.snd_v_embeddings(nodes[:, 1]).view(-1, self.emb_dimension).to(self.device)
+
+            ## Positive
+ emb_pos_u, emb_pos_v = emb_u, emb_v
+ grad_u_pos, grad_v_pos = self.fast_pos_bp(emb_pos_u, emb_pos_v, False)
+
+ ## Negative
+ emb_neg_u = emb_pos_u.repeat((self.negative, 1))
+
+ if bs < self.batch_size:
+ index_emb_negu, index_emb_negv = init_emb2neg_index(self.negative, bs)
+ index_emb_negu = index_emb_negu.to(self.device)
+ index_emb_negv = index_emb_negv.to(self.device)
+ else:
+ index_emb_negu = self.index_emb_negu
+ index_emb_negv = self.index_emb_negv
+
+ if neg_nodes is None:
+ emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)
+ else:
+ emb_neg_v = self.snd_v_embeddings.weight[neg_nodes].to(self.device)
+
+ grad_u_neg, grad_v_neg = self.fast_neg_bp(emb_neg_u, emb_neg_v, False)
+
+ ## Update
+ grad_u_pos.index_add_(0, index_emb_negu, grad_u_neg)
+ grad_u = grad_u_pos
+ if neg_nodes is None:
+ grad_v_pos.index_add_(0, index_emb_negv, grad_v_neg)
+ grad_v = grad_v_pos
+ else:
+ grad_v = grad_v_pos
+
+ # use adam optimizer
+ grad_u = adam(grad_u, self.snd_state_sum_u, nodes[:, 0], lr, self.device, self.only_gpu)
+ grad_v = adam(grad_v, self.snd_state_sum_v, nodes[:, 1], lr, self.device, self.only_gpu)
+ if neg_nodes is not None:
+ grad_v_neg = adam(grad_v_neg, self.snd_state_sum_v, neg_nodes, lr, self.device, self.only_gpu)
+
+ if self.mixed_train:
+ grad_u = grad_u.cpu()
+ grad_v = grad_v.cpu()
+ if neg_nodes is not None:
+ grad_v_neg = grad_v_neg.cpu()
+ else:
+ grad_v_neg = None
+
+ if self.async_update:
+ grad_u.share_memory_()
+ grad_v.share_memory_()
+ nodes.share_memory_()
+ if neg_nodes is not None:
+ neg_nodes.share_memory_()
+ grad_v_neg.share_memory_()
+ self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes, False))
+
+ if not self.async_update:
+ self.snd_u_embeddings.weight.data.index_add_(0, nodes[:, 0], grad_u)
+ self.snd_v_embeddings.weight.data.index_add_(0, nodes[:, 1], grad_v)
+ if neg_nodes is not None:
+ self.snd_v_embeddings.weight.data.index_add_(0, neg_nodes, grad_v_neg)
+
+ return
+
+ def get_embedding(self):
+ if self.fst:
+ embedding_fst = self.fst_u_embeddings.weight.cpu().data.numpy()
+ embedding_fst /= np.sqrt(np.sum(embedding_fst * embedding_fst, 1)).reshape(-1, 1)
+ if self.snd:
+ embedding_snd = self.snd_u_embeddings.weight.cpu().data.numpy()
+ embedding_snd /= np.sqrt(np.sum(embedding_snd * embedding_snd, 1)).reshape(-1, 1)
+ if self.fst and self.snd:
+ embedding = np.concatenate((embedding_fst, embedding_snd), 1)
+ embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1)
+ elif self.fst and not self.snd:
+ embedding = embedding_fst
+ elif self.snd and not self.fst:
+ embedding = embedding_snd
+ else:
+ pass
+
+ return embedding
+
+ def save_embedding(self, dataset, file_name):
+ """ Write embedding to local file. Only used when node ids are numbers.
+
+        Parameters
+        ----------
+        dataset LineDataset : the dataset
+ file_name str : the file name
+ """
+ embedding = self.get_embedding()
+ np.save(file_name, embedding)
+
+ def save_embedding_pt(self, dataset, file_name):
+ """ For ogb leaderboard. """
+ embedding = torch.Tensor(self.get_embedding()).cpu()
+ embedding_empty = torch.zeros_like(embedding.data)
+ valid_nodes = torch.LongTensor(dataset.valid_nodes)
+ valid_embedding = embedding.data.index_select(0, valid_nodes)
+ embedding_empty.index_add_(0, valid_nodes, valid_embedding)
+
+ torch.save(embedding_empty, file_name)
\ No newline at end of file
diff --git a/examples/pytorch/ogb/line/reading_data.py b/examples/pytorch/ogb/line/reading_data.py
new file mode 100644
index 000000000000..7901a83dc985
--- /dev/null
+++ b/examples/pytorch/ogb/line/reading_data.py
@@ -0,0 +1,212 @@
+import os
+import numpy as np
+import scipy.sparse as sp
+import pickle
+import torch
+from torch.utils.data import DataLoader
+from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_archive
+import random
+import time
+import dgl
+
+def ReadTxtNet(file_path="", undirected=True):
+ """ Read the txt network file.
+    Note: if an edge in the file carries no weight, its weight defaults to 1.
+
+ Parameters
+ ----------
+ file_path str : path of network file
+ undirected bool : whether the edges are undirected
+
+ Return
+ ------
+ net dict : a dict recording the connections in the graph
+ node2id dict : a dict mapping the nodes to their embedding indices
+    id2node dict : a dict mapping embedding indices back to the nodes
+    sm scipy.sparse.coo_matrix : the sparse adjacency matrix of the network
+    """
+ if file_path == 'youtube' or file_path == 'blog':
+ name = file_path
+ dir = get_download_dir()
+ zip_file_path='{}/{}.zip'.format(dir, name)
+ download(_get_dgl_url(os.path.join('dataset/DeepWalk/', '{}.zip'.format(file_path))), path=zip_file_path)
+ extract_archive(zip_file_path,
+ '{}/{}'.format(dir, name))
+ file_path = "{}/{}/{}-net.txt".format(dir, name, name)
+
+ node2id = {}
+ id2node = {}
+ cid = 0
+
+ src = []
+ dst = []
+ weight = []
+ net = {}
+ with open(file_path, "r") as f:
+ for line in f.readlines():
+ tup = list(map(int, line.strip().split(" ")))
+ assert len(tup) in [2, 3], "The format of network file is unrecognizable."
+ if len(tup) == 3:
+ n1, n2, w = tup
+ elif len(tup) == 2:
+ n1, n2 = tup
+ w = 1
+ if n1 not in node2id:
+ node2id[n1] = cid
+ id2node[cid] = n1
+ cid += 1
+ if n2 not in node2id:
+ node2id[n2] = cid
+ id2node[cid] = n2
+ cid += 1
+
+ n1 = node2id[n1]
+ n2 = node2id[n2]
+ if n1 not in net:
+ net[n1] = {n2: w}
+ src.append(n1)
+ dst.append(n2)
+ weight.append(w)
+ elif n2 not in net[n1]:
+ net[n1][n2] = w
+ src.append(n1)
+ dst.append(n2)
+ weight.append(w)
+
+ if undirected:
+ if n2 not in net:
+ net[n2] = {n1: w}
+ src.append(n2)
+ dst.append(n1)
+ weight.append(w)
+ elif n1 not in net[n2]:
+ net[n2][n1] = w
+ src.append(n2)
+ dst.append(n1)
+ weight.append(w)
+
+ print("node num: %d" % len(net))
+ print("edge num: %d" % len(src))
+ assert max(net.keys()) == len(net) - 1, "error reading net, quit"
+
+ sm = sp.coo_matrix(
+ (np.array(weight), (src, dst)),
+ dtype=np.float32)
+
+ return net, node2id, id2node, sm
+
+def net2graph(net_sm):
+ """ Transform the network to DGL graph
+
+ Return
+ ------
+ G DGLGraph : graph by DGL
+ """
+ start = time.time()
+ G = dgl.DGLGraph(net_sm)
+ end = time.time()
+ t = end - start
+ print("Building DGLGraph in %.2fs" % t)
+ return G
+
+def make_undirected(G):
+ #G.readonly(False)
+ G.add_edges(G.edges()[1], G.edges()[0])
+ return G
+
+def find_connected_nodes(G):
+ nodes = torch.nonzero(G.out_degrees()).squeeze(-1)
+ return nodes
+
+class LineDataset:
+ def __init__(self,
+ net_file,
+ batch_size,
+ num_samples,
+ negative=5,
+ gpus=[0],
+ fast_neg=True,
+ ogbl_name="",
+ load_from_ogbl=False,
+ ogbn_name="",
+ load_from_ogbn=False,
+ ):
+ """ This class has the following functions:
+ 1. Transform the txt network file into DGL graph;
+ 2. Generate random walk sequences for the trainer;
+ 3. Provide the negative table if the user hopes to sample negative
+ nodes according to nodes' degrees;
+
+ Parameter
+ ---------
+ net_file str : path of the dgl network file
+ walk_length int : number of nodes in a sequence
+ window_size int : context window size
+ num_walks int : number of walks for each node
+ batch_size int : number of node sequences in each batch
+ negative int : negative samples for each positve node pair
+ fast_neg bool : whether do negative sampling inside a batch
+ """
+ self.batch_size = batch_size
+ self.negative = negative
+ self.num_samples = num_samples
+ self.num_procs = len(gpus)
+ self.fast_neg = fast_neg
+
+ if load_from_ogbl:
+ assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training."
+ from load_dataset import load_from_ogbl_with_name
+ self.G = load_from_ogbl_with_name(ogbl_name)
+ elif load_from_ogbn:
+            assert len(gpus) == 1, "ogb.nodeproppred is not compatible with multi-gpu training."
+ from load_dataset import load_from_ogbn_with_name
+ self.G = load_from_ogbn_with_name(ogbn_name)
+ else:
+ self.G = dgl.load_graphs(net_file)[0][0]
+ self.G = make_undirected(self.G)
+ print("Finish reading graph")
+
+ self.num_nodes = self.G.number_of_nodes()
+
+ start = time.time()
+ seeds = np.random.choice(np.arange(self.G.number_of_edges()),
+ self.num_samples,
+ replace=True) # edge index
+ self.seeds = torch.split(torch.LongTensor(seeds),
+ int(np.ceil(self.num_samples / self.num_procs)),
+ 0)
+ end = time.time()
+ t = end - start
+ print("generate %d samples in %.2fs" % (len(seeds), t))
+
+ # negative table for true negative sampling
+ self.valid_nodes = find_connected_nodes(self.G)
+ if not fast_neg:
+ node_degree = self.G.out_degrees(self.valid_nodes).numpy()
+ node_degree = np.power(node_degree, 0.75)
+ node_degree /= np.sum(node_degree)
+            node_degree = np.array(node_degree * 1e8, dtype=np.int64)
+ self.neg_table = []
+
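+            # each node is repeated in proportion to degree^0.75, the noise
+            # distribution LINE uses for negative sampling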
+ for idx, node in enumerate(self.valid_nodes):
+ self.neg_table += [node] * node_degree[idx]
+ self.neg_table_size = len(self.neg_table)
+            self.neg_table = np.array(self.neg_table, dtype=np.int64)
+ del node_degree
+
+ def create_sampler(self, i):
+ """ create random walk sampler """
+ return EdgeSampler(self.G, self.seeds[i])
+
+ def save_mapping(self, map_file):
+ with open(map_file, "wb") as f:
+ pickle.dump(self.node2id, f)
+
+class EdgeSampler(object):
+ def __init__(self, G, seeds):
+ self.G = G
+ self.seeds = seeds
+ self.edges = torch.cat((self.G.edges()[0].unsqueeze(0), self.G.edges()[1].unsqueeze(0)), 0).t()
+
+ def sample(self, seeds):
+ """ seeds torch.LongTensor : a batch of indices of edges """
+ return self.edges[torch.LongTensor(seeds)]
\ No newline at end of file
diff --git a/examples/pytorch/ogb/line/utils.py b/examples/pytorch/ogb/line/utils.py
new file mode 100644
index 000000000000..615378cedb26
--- /dev/null
+++ b/examples/pytorch/ogb/line/utils.py
@@ -0,0 +1,61 @@
+import traceback
+import torch
+from functools import wraps
+from _thread import start_new_thread
+import torch.multiprocessing as mp
+
+def thread_wrapped_func(func):
+ """Wrapped func for torch.multiprocessing.Process.
+    With this wrapper we can use OMP threads in subprocesses;
+    otherwise, OMP_NUM_THREADS=1 is mandatory.
+ How to use:
+ @thread_wrapped_func
+ def func_to_wrap(args ...):
+ """
+ @wraps(func)
+ def decorated_function(*args, **kwargs):
+ queue = mp.Queue()
+ def _queue_result():
+ exception, trace, res = None, None, None
+ try:
+ res = func(*args, **kwargs)
+ except Exception as e:
+ exception = e
+ trace = traceback.format_exc()
+ queue.put((res, exception, trace))
+
+ start_new_thread(_queue_result, ())
+ result, exception, trace = queue.get()
+ if exception is None:
+ return result
+ else:
+ assert isinstance(exception, Exception)
+ raise exception.__class__(trace)
+ return decorated_function
+
+def check_args(args):
+    flag = sum([args.only_fst, args.only_snd])
+    assert flag <= 1, "no more than one selection from --only_fst and --only_snd"
+ if flag == 0:
+ assert args.dim % 2 == 0, "embedding dimension must be an even number"
+ if args.async_update:
+ assert args.mix, "please use --async_update with --mix"
+
+def sum_up_params(model):
+ """ Count the model parameters """
+ n = []
+ if model.fst:
+ p = model.fst_u_embeddings.weight.cpu().data.numel()
+ n.append(p)
+ p = model.fst_state_sum_u.cpu().data.numel()
+ n.append(p)
+ if model.snd:
+ p = model.snd_u_embeddings.weight.cpu().data.numel() * 2
+ n.append(p)
+ p = model.snd_state_sum_u.cpu().data.numel() * 2
+ n.append(p)
+ n.append(model.lookup_table.cpu().numel())
+ try:
+ n.append(model.index_emb_negu.cpu().numel() * 2)
+ except:
+ pass
+ print("#params " + str(sum(n)))
\ No newline at end of file