[Example] Pytorch Seal example (dmlc#2638)
* add seal example

* 1. add paper information in examples/README
2. adjust codes
3. option test

* use latest `to_simple` to replace coalesce graph function

* remove outdated codes

* remove useless comment
Smilexuhc authored Feb 25, 2021
1 parent 0526b88 commit 583aa76
Showing 7 changed files with 932 additions and 0 deletions.
6 changes: 6 additions & 0 deletions examples/README.md
@@ -78,6 +78,7 @@ The folder contains example implementations of selected research papers related
| [Dynamic Graph CNN for Learning on Point Clouds](#dgcnnpoint) | | | | | |
| [Supervised Community Detection with Line Graph Neural Networks](#lgnn) | | | | | |
| [Text Generation from Knowledge Graphs with Graph Transformers](#graphwriter) | | | | | |
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |


## 2020
@@ -239,6 +240,11 @@ The folder contains example implementations of selected research papers related
- Pooling module: [PyTorch](https://docs.dgl.ai/api/python/nn.pytorch.html#sortpooling), [TensorFlow](https://docs.dgl.ai/api/python/nn.tensorflow.html#sortpooling), [MXNet](https://docs.dgl.ai/api/python/nn.mxnet.html#sortpooling)
- Tags: graph classification

- <a name="seal"></a> Zhang et al. Link Prediction Based on Graph Neural Networks. [Paper link](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf).
- Example code: [pytorch](../examples/pytorch/seal)
- Tags: link prediction, sampling


## 2017

- <a name="gcn"></a> Kipf and Welling. Semi-Supervised Classification with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1609.02907).
56 changes: 56 additions & 0 deletions examples/pytorch/seal/README.md
@@ -0,0 +1,56 @@
# DGL Implementation of the SEAL Paper
This DGL example implements the link prediction model proposed in the papers
[Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf)
and [Revisiting Graph Neural Networks for Link Prediction](https://arxiv.org/pdf/2010.16103.pdf).
The authors' implementations are available at [SEAL](https://github.com/muhanzhang/SEAL) (PyTorch)
and [SEAL_OGB](https://github.com/facebookresearch/SEAL_OGB) (PyTorch Geometric).
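
SEAL tags every node of an enclosing subgraph with a structural label before passing the subgraph to the GNN. The sketch below illustrates the Double-Radius Node Labeling (DRNL) hash described in the SEAL paper on a toy adjacency dict. It is an illustration only: the helper names `bfs_distances` and `drnl_labels` and the plain-Python graph representation are assumptions, and this is not the code used by this example's sampler.

```python
from collections import deque


def bfs_distances(adj, source, blocked):
    """Hop counts from `source`, never passing through the node `blocked`."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v != blocked and v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def drnl_labels(adj, x, y):
    """Label each node of a subgraph relative to the target link (x, y)."""
    dist_x = bfs_distances(adj, x, blocked=y)  # distance to x with y masked out
    dist_y = bfs_distances(adj, y, blocked=x)  # distance to y with x masked out
    labels = {}
    for node in adj:
        if node in (x, y):
            labels[node] = 1  # both target nodes always receive label 1
        elif node not in dist_x or node not in dist_y:
            labels[node] = 0  # unreachable from at least one target node
        else:
            dx, dy = dist_x[node], dist_y[node]
            d = dx + dy
            labels[node] = 1 + min(dx, dy) + (d // 2) * (d // 2 + d % 2 - 1)
    return labels


# Toy subgraph: target link (0, 1); node 2 neighbours both targets,
# node 3 hangs off node 2, and node 4 is isolated.
adj = {0: [2], 1: [2], 2: [0, 1, 3], 3: [2], 4: []}
labels = drnl_labels(adj, x=0, y=1)
```

Nodes with the same pair of distances to the two target nodes receive the same label, so the label can serve as a node feature that tells the GNN where each node sits relative to the candidate link.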

Example implementor
----------------------
This example was implemented by [Smile](https://github.com/Smilexuhc) during his internship at the AWS Shanghai AI Lab.

The graph dataset used in this example
---------------------------------------

ogbl-collab
- NumNodes: 235868
- NumEdges: 2358104
- NumNodeFeats: 128
- NumEdgeWeights: 1
- NumValidEdges: 160084
- NumTestEdges: 146329

Dependencies
--------------------------------

- Python 3.6+
- PyTorch 1.5.0+
- dgl 0.6.0+
- ogb
- pandas
- tqdm
- scipy


How to run example files
--------------------------------
Run the following commands in the example folder (`examples/pytorch/seal`).

Run on CPU:
```shell
python main.py --gpu_id=-1 --subsample_ratio=0.1
```
Run on GPU:
```shell
python main.py --gpu_id=0 --subsample_ratio=0.1
```

Performance
-------------------------
Experiment on `ogbl-collab`:

| method | valid-hits@50 | test-hits@50 |
| ------ | ------------- | ------------ |
| paper | 63.89(0.49) | 53.71(0.47) |
| ours | 63.56(0.71) | 53.61(0.78) |

Note: We only performed 5 trials in the experiment.
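
For readers unfamiliar with the hits@K metric in the table, the following is a minimal sketch of how it is commonly computed, assuming the OGB convention: a positive edge counts as a hit when its score is strictly greater than the K-th highest score among the negative edges. The function `hits_at_k` is a hypothetical NumPy helper for illustration, not the `evaluate_hits` function from this example's `utils.py`.

```python
import numpy as np


def hits_at_k(pos_scores, neg_scores, k):
    """Fraction of positive scores above the k-th highest negative score."""
    pos_scores = np.asarray(pos_scores, dtype=float)
    neg_scores = np.asarray(neg_scores, dtype=float)
    if len(neg_scores) < k:
        # Fewer than k negatives: every positive trivially ranks in the top k.
        return 1.0
    kth_neg = np.sort(neg_scores)[-k]  # k-th largest negative score
    return float(np.mean(pos_scores > kth_neg))


pos = [0.9, 0.8, 0.4, 0.2]
neg = [0.95, 0.5, 0.3, 0.1]
score = hits_at_k(pos, neg, k=2)  # threshold is 0.5, so two of four positives hit
```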
69 changes: 69 additions & 0 deletions examples/pytorch/seal/logger.py
@@ -0,0 +1,69 @@
import logging
import time
import os


def _transform_log_level(str_level):
    # Map a lowercase level name to the corresponding logging constant.
    if str_level == 'info':
        return logging.INFO
    elif str_level == 'warning':
        return logging.WARNING
    elif str_level == 'critical':
        return logging.CRITICAL
    elif str_level == 'debug':
        return logging.DEBUG
    elif str_level == 'error':
        return logging.ERROR
    else:
        raise KeyError('Log level error')


class LightLogging(object):
    def __init__(self, log_path=None, log_name='lightlog', log_level='debug'):

        log_level = _transform_log_level(log_level)

        if log_path:
            if not log_path.endswith('/'):
                log_path += '/'
            if not os.path.exists(log_path):
                # makedirs also creates missing parent directories
                os.makedirs(log_path)

            # Use '-' instead of ':' in the timestamp so that both branches
            # produce the same, portable file name format.
            timestamp = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))
            if log_name.endswith('-') or log_name.endswith('_'):
                log_name = log_path + log_name + timestamp + '.log'
            else:
                log_name = log_path + log_name + '_' + timestamp + '.log'

            logging.basicConfig(level=log_level,
                                format="%(asctime)s %(levelname)s: %(message)s",
                                datefmt='%Y-%m-%d-%H:%M',
                                handlers=[
                                    logging.FileHandler(log_name, mode='w'),
                                    logging.StreamHandler()
                                ])
            logging.info('Start Logging')
            logging.info('Log file path: {}'.format(log_name))

        else:
            logging.basicConfig(level=log_level,
                                format="%(asctime)s %(levelname)s: %(message)s",
                                datefmt='%Y-%m-%d-%H:%M',
                                handlers=[
                                    logging.StreamHandler()
                                ])
            logging.info('Start Logging')

    def debug(self, msg):
        logging.debug(msg)

    def info(self, msg):
        logging.info(msg)

    def critical(self, msg):
        logging.critical(msg)

    def warning(self, msg):
        logging.warning(msg)

    def error(self, msg):
        logging.error(msg)
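
The if/elif chain in `_transform_log_level` could also be expressed as a dictionary lookup. The following is a minimal sketch of that alternative and is not part of the commit; the names `_LOG_LEVELS` and `transform_log_level` are hypothetical.

```python
import logging

# Hypothetical lookup table covering the same five level names.
_LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL,
}


def transform_log_level(str_level):
    """Return the logging constant for a lowercase level name."""
    try:
        return _LOG_LEVELS[str_level]
    except KeyError:
        # Same exception type as the original helper, with the bad value named.
        raise KeyError('Log level error: {!r}'.format(str_level)) from None


level = transform_log_level('info')
```

A table keeps the accepted names in one place, which makes it easier to list them in an error message or a command-line help string.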
174 changes: 174 additions & 0 deletions examples/pytorch/seal/main.py
@@ -0,0 +1,174 @@
import time
from tqdm import tqdm
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss
from dgl import NID, EID
from dgl.dataloading import GraphDataLoader
from utils import parse_arguments
from utils import load_ogb_dataset, evaluate_hits
from sampler import SEALData
from model import GCN, DGCNN
from logger import LightLogging

'''
Parts of the code are adapted from
https://github.com/facebookresearch/SEAL_OGB
'''


def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_graphs=None):
    model.train()

    total_loss = 0
    for g, labels in tqdm(dataloader, ncols=100):
        g = g.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        # Weighted by the nominal batch size, so a smaller final batch is
        # slightly over-weighted in the reported average.
        total_loss += loss.item() * num_graphs

    return total_loss / total_graphs


@torch.no_grad()
def evaluate(model, dataloader, device):
    model.eval()

    y_pred, y_true = [], []
    for g, labels in tqdm(dataloader, ncols=100):
        g = g.to(device)
        logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
        y_pred.append(logits.view(-1).cpu())
        y_true.append(labels.view(-1).cpu().to(torch.float))

    y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
    pos_pred = y_pred[y_true == 1]
    neg_pred = y_pred[y_true == 0]

    return pos_pred, neg_pred


def main(args, print_fn=print):
    print_fn("Experiment arguments: {}".format(args))

    if args.random_seed:
        torch.manual_seed(args.random_seed)
    else:
        torch.manual_seed(123)
    # Load dataset
    if args.dataset.startswith('ogbl'):
        graph, split_edge = load_ogb_dataset(args.dataset)
    else:
        raise NotImplementedError

    num_nodes = graph.num_nodes()

    # set gpu
    if args.gpu_id >= 0 and torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu_id)
    else:
        device = 'cpu'

    if args.dataset == 'ogbl-collab':
        # ogbl-collab dataset is multi-edge graph
        use_coalesce = True
    else:
        use_coalesce = False

    # Generate positive and negative edges and corresponding labels
    # Sampling subgraphs and generate node labeling features
    seal_data = SEALData(g=graph, split_edge=split_edge, hop=args.hop, neg_samples=args.neg_samples,
                         subsample_ratio=args.subsample_ratio, use_coalesce=use_coalesce, prefix=args.dataset,
                         save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print_fn)
    node_attribute = seal_data.ndata['feat']
    edge_weight = seal_data.edata['edge_weight'].float()

    train_data = seal_data('train')
    val_data = seal_data('valid')
    test_data = seal_data('test')

    train_graphs = len(train_data.graph_list)

    # Set data loader

    train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, num_workers=args.num_workers)
    val_loader = GraphDataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)
    test_loader = GraphDataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)

    # set model
    if args.model == 'gcn':
        model = GCN(num_layers=args.num_layers,
                    hidden_units=args.hidden_units,
                    gcn_type=args.gcn_type,
                    pooling_type=args.pooling,
                    node_attributes=node_attribute,
                    edge_weights=edge_weight,
                    node_embedding=None,
                    use_embedding=True,
                    num_nodes=num_nodes,
                    dropout=args.dropout)
    elif args.model == 'dgcnn':
        model = DGCNN(num_layers=args.num_layers,
                      hidden_units=args.hidden_units,
                      k=args.sort_k,
                      gcn_type=args.gcn_type,
                      node_attributes=node_attribute,
                      edge_weights=edge_weight,
                      node_embedding=None,
                      use_embedding=True,
                      num_nodes=num_nodes,
                      dropout=args.dropout)
    else:
        raise ValueError('Model error')

    model = model.to(device)
    parameters = model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=args.lr)
    loss_fn = BCEWithLogitsLoss()
    print_fn("Total parameters: {}".format(sum([p.numel() for p in model.parameters()])))

    # train and evaluate loop
    summary_val = []
    summary_test = []
    for epoch in range(args.epochs):
        start_time = time.time()
        loss = train(model=model,
                     dataloader=train_loader,
                     loss_fn=loss_fn,
                     optimizer=optimizer,
                     device=device,
                     num_graphs=args.batch_size,
                     total_graphs=train_graphs)
        train_time = time.time()
        if epoch % args.eval_steps == 0:
            val_pos_pred, val_neg_pred = evaluate(model=model,
                                                  dataloader=val_loader,
                                                  device=device)
            test_pos_pred, test_neg_pred = evaluate(model=model,
                                                    dataloader=test_loader,
                                                    device=device)

            val_metric = evaluate_hits(args.dataset, val_pos_pred, val_neg_pred, args.hits_k)
            test_metric = evaluate_hits(args.dataset, test_pos_pred, test_neg_pred, args.hits_k)
            evaluate_time = time.time()
            print_fn("Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
                     "cost time: train-{:.1f}s, total-{:.1f}s".format(epoch, loss, args.hits_k, val_metric, test_metric,
                                                                      train_time - start_time,
                                                                      evaluate_time - start_time))
            summary_val.append(val_metric)
            summary_test.append(test_metric)

    summary_test = np.array(summary_test)

    print_fn("Experiment Results:")
    print_fn("Best hits@{}: {:.4f}, epoch: {}".format(args.hits_k, np.max(summary_test), np.argmax(summary_test)))


if __name__ == '__main__':
    args = parse_arguments()
    logger = LightLogging(log_name='SEAL', log_path='./logs')
    main(args, logger.info)
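
The `train` function in `main.py` multiplies every batch loss by the nominal batch size, which slightly over-weights a final batch that holds fewer graphs. The sketch below shows exact per-sample averaging on made-up numbers. `average_loss` and its `(mean_loss, batch_size)` input are hypothetical names for illustration; in a training loop the actual size could, for instance, be read from the labels tensor.

```python
def average_loss(batch_stats):
    """Per-sample mean loss from (mean_batch_loss, actual_batch_size) pairs."""
    total_loss, total_samples = 0.0, 0
    for mean_loss, batch_size in batch_stats:
        total_loss += mean_loss * batch_size  # undo the per-batch mean
        total_samples += batch_size
    return total_loss / total_samples


# Made-up example: 80 graphs split into batches of 32, 32 and 16.
stats = [(0.5, 32), (0.5, 32), (1.0, 16)]
exact = average_loss(stats)

# Weighting every batch by the nominal size of 32 inflates the last term.
nominal = sum(loss * 32 for loss, _ in stats) / 80
```

The difference only affects the logged training loss, not the gradients, and it vanishes when the dataset size is a multiple of the batch size.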