-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] Pytorch Seal example (dmlc#2638)
* add seal example * 1. add paper infomation in examples/README 2. adjust codes 3. option test * use latest `to_simple` to replace coalesce graph function * remove outdated codes * remove useless comment
- Loading branch information
Showing
7 changed files
with
932 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# DGL Implementation of the SEAL Paper | ||
This DGL example implements the link prediction model proposed in the paper | ||
[Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf) | ||
and [REVISITING GRAPH NEURAL NETWORKS FOR LINK PREDICTION](https://arxiv.org/pdf/2010.16103.pdf) | ||
The authors' reference implementations are available in [SEAL](https://github.com/muhanzhang/SEAL) (PyTorch) | ||
and [SEAL_ogb](https://github.com/facebookresearch/SEAL_OGB) (torch_geometric) | ||
|
||
Example implementor | ||
---------------------- | ||
This example was implemented by [Smile](https://github.com/Smilexuhc) during his internship at the AWS Shanghai AI Lab. | ||
|
||
The graph dataset used in this example | ||
--------------------------------------- | ||
|
||
ogbl-collab | ||
- NumNodes: 235868 | ||
- NumEdges: 2358104 | ||
- NumNodeFeats: 128 | ||
- NumEdgeWeights: 1 | ||
- NumValidEdges: 160084 | ||
- NumTestEdges: 146329 | ||
|
||
Dependencies | ||
-------------------------------- | ||
|
||
- python 3.6+ | ||
- PyTorch 1.5.0+ | ||
- dgl 0.6.0+ | ||
- ogb | ||
- pandas | ||
- tqdm | ||
- scipy | ||
|
||
|
||
How to run example files | ||
-------------------------------- | ||
In the seal_dgl folder | ||
run on cpu: | ||
```shell script | ||
python main.py --gpu_id=-1 --subsample_ratio=0.1 | ||
``` | ||
run on gpu: | ||
```shell script | ||
python main.py --gpu_id=0 --subsample_ratio=0.1 | ||
``` | ||
|
||
Performance | ||
------------------------- | ||
experiment on `ogbl-collab` | ||
|
||
| method | valid-hits@50 | test-hits@50 | | ||
| ------ | ------------- | ------------ | | ||
| paper | 63.89(0.49) | 53.71(0.47) | | ||
| ours | 63.56(0.71) | 53.61(0.78) | | ||
|
||
Note: We only performed 5 trials in the experiment. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import logging | ||
import time | ||
import os | ||
|
||
|
||
def _transform_log_level(str_level):
    """Translate a lowercase level name into the matching ``logging`` constant.

    Recognised names are 'info', 'warning', 'critical', 'debug' and 'error'.
    The comparison is exact, so e.g. 'INFO' is not accepted.

    Raises:
        KeyError: if ``str_level`` is not one of the recognised names.
    """
    name_to_level = (
        ('info', logging.INFO),
        ('warning', logging.WARNING),
        ('critical', logging.CRITICAL),
        ('debug', logging.DEBUG),
        ('error', logging.ERROR),
    )
    for level_name, level_value in name_to_level:
        if str_level == level_name:
            return level_value
    raise KeyError('Log level error')
|
||
|
||
class LightLogging(object):
    """Small helper that configures the root logger and forwards messages to it.

    Constructing an instance calls ``logging.basicConfig`` with a console
    handler and, when ``log_path`` is given, an additional file handler that
    writes to ``<log_path>/<log_name>_<timestamp>.log``.

    Args:
        log_path: directory for the log file. Falsy (``None`` or ``''``)
            disables file logging. Missing directories, including nested
            ones, are created.
        log_name: prefix of the log file name. If it already ends with '-'
            or '_' no extra separator is inserted before the timestamp.
        log_level: one of 'debug', 'info', 'warning', 'error', 'critical'.

    Raises:
        KeyError: if ``log_level`` is not a recognised level name.
    """

    def __init__(self, log_path=None, log_name='lightlog', log_level='debug'):
        log_level = _transform_log_level(log_level)

        log_file = None
        handlers = []
        if log_path:
            # makedirs (not mkdir) so that nested paths such as 'runs/exp1/logs'
            # work, and exist_ok avoids a race between the check and the create.
            os.makedirs(log_path, exist_ok=True)

            # The timestamp deliberately avoids ':' in both cases: a colon is
            # not a valid file-name character on Windows.
            timestamp = time.strftime('%Y-%m-%d-%H-%M', time.localtime())
            separator = '' if log_name.endswith(('-', '_')) else '_'
            log_file = os.path.join(log_path, log_name + separator + timestamp + '.log')
            handlers.append(logging.FileHandler(log_file, mode='w'))
        handlers.append(logging.StreamHandler())

        logging.basicConfig(level=log_level,
                            format="%(asctime)s %(levelname)s: %(message)s",
                            datefmt='%Y-%m-%d-%H:%M',
                            handlers=handlers)
        logging.info('Start Logging')
        if log_file is not None:
            logging.info('Log file path: {}'.format(log_file))

    def debug(self, msg):
        """Log ``msg`` at DEBUG level on the root logger."""
        logging.debug(msg)

    def info(self, msg):
        """Log ``msg`` at INFO level on the root logger."""
        logging.info(msg)

    def critical(self, msg):
        """Log ``msg`` at CRITICAL level on the root logger."""
        logging.critical(msg)

    def warning(self, msg):
        """Log ``msg`` at WARNING level on the root logger."""
        logging.warning(msg)

    def error(self, msg):
        """Log ``msg`` at ERROR level on the root logger."""
        logging.error(msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
import time | ||
from tqdm import tqdm | ||
import numpy as np | ||
import torch | ||
from torch.nn import BCEWithLogitsLoss | ||
from dgl import NID, EID | ||
from dgl.dataloading import GraphDataLoader | ||
from utils import parse_arguments | ||
from utils import load_ogb_dataset, evaluate_hits | ||
from sampler import SEALData | ||
from model import GCN, DGCNN | ||
from logger import LightLogging | ||
|
||
''' | ||
Parts of the code are adapted from | ||
https://github.com/facebookresearch/SEAL_OGB | ||
''' | ||
|
||
|
||
def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_graphs=None):
    """Run one training epoch and return the average loss per graph.

    Args:
        model: network invoked as ``model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])``.
        dataloader: iterable yielding ``(batched_graph, labels)`` pairs.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
            The weighting below assumes it is a per-batch mean -- confirm if a
            different reduction is used.
        optimizer: optimizer stepping the model parameters.
        device: device the graph and labels are moved to.
        num_graphs: nominal batch size. Only used as the batch weight when the
            batch object does not expose ``batch_size``.
        total_graphs: number of graphs to average over. When ``None`` the
            number of graphs actually processed is used.

    Returns:
        The accumulated loss divided by ``total_graphs`` (``0.0`` if there is
        nothing to average over).
    """
    model.train()

    total_loss = 0
    seen_graphs = 0
    for g, labels in tqdm(dataloader, ncols=100):
        g = g.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight by the real number of graphs in this batch: the last batch of
        # an epoch is usually smaller than the nominal batch size, and using
        # the nominal size would over-weight it.
        graphs_in_batch = getattr(g, 'batch_size', num_graphs)
        total_loss += loss.item() * graphs_in_batch
        seen_graphs += graphs_in_batch

    if total_graphs is None:
        total_graphs = seen_graphs
    if not total_graphs:
        return 0.0
    return total_loss / total_graphs
|
||
|
||
@torch.no_grad()
def evaluate(model, dataloader, device):
    """Score every batch in ``dataloader`` and split the scores by label.

    The model is switched to eval mode and run without gradient tracking.
    Predictions and labels are flattened and gathered on the CPU.

    Returns:
        A ``(pos_pred, neg_pred)`` tuple: the predicted scores whose label
        equals 1 and the predicted scores whose label equals 0.
    """
    model.eval()

    score_chunks = []
    label_chunks = []
    for batch_graph, batch_labels in tqdm(dataloader, ncols=100):
        batch_graph = batch_graph.to(device)
        scores = model(batch_graph,
                       batch_graph.ndata['z'],
                       batch_graph.ndata[NID],
                       batch_graph.edata[EID])
        score_chunks.append(scores.view(-1).cpu())
        label_chunks.append(batch_labels.view(-1).cpu().to(torch.float))

    all_scores = torch.cat(score_chunks)
    all_labels = torch.cat(label_chunks)

    return all_scores[all_labels == 1], all_scores[all_labels == 0]
|
||
|
||
def main(args, print_fn=print):
    """Train and evaluate a SEAL link-prediction model end to end.

    Steps: seed torch, load the OGB dataset, build the SEAL train/valid/test
    datasets, create the model selected by ``args.model``, then train for
    ``args.epochs`` epochs, evaluating every ``args.eval_steps`` epochs.

    Args:
        args: parsed experiment arguments. Fields read here: ``random_seed``,
            ``dataset``, ``gpu_id``, ``hop``, ``neg_samples``,
            ``subsample_ratio``, ``save_dir``, ``num_workers``, ``batch_size``,
            ``model``, ``num_layers``, ``hidden_units``, ``gcn_type``,
            ``pooling`` (gcn only), ``sort_k`` (dgcnn only), ``dropout``,
            ``lr``, ``epochs``, ``eval_steps`` and ``hits_k``.
        print_fn: callable used for all progress output.

    Raises:
        NotImplementedError: if ``args.dataset`` does not start with 'ogbl'.
        ValueError: if ``args.model`` is neither 'gcn' nor 'dgcnn'.
    """
    print_fn("Experiment arguments: {}".format(args))

    # Compare against None so that an explicit seed of 0 is honoured instead
    # of silently falling back to the default seed.
    seed = args.random_seed if args.random_seed is not None else 123
    torch.manual_seed(seed)

    # Load dataset
    if args.dataset.startswith('ogbl'):
        graph, split_edge = load_ogb_dataset(args.dataset)
    else:
        raise NotImplementedError

    num_nodes = graph.num_nodes()

    # set gpu
    if args.gpu_id >= 0 and torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu_id)
    else:
        device = 'cpu'

    # ogbl-collab dataset is multi-edge graph
    use_coalesce = args.dataset == 'ogbl-collab'

    # Generate positive and negative edges and corresponding labels
    # Sampling subgraphs and generate node labeling features
    seal_data = SEALData(g=graph, split_edge=split_edge, hop=args.hop, neg_samples=args.neg_samples,
                         subsample_ratio=args.subsample_ratio, use_coalesce=use_coalesce, prefix=args.dataset,
                         save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print_fn)
    node_attribute = seal_data.ndata['feat']
    edge_weight = seal_data.edata['edge_weight'].float()

    train_data = seal_data('train')
    val_data = seal_data('valid')
    test_data = seal_data('test')

    train_graphs = len(train_data.graph_list)

    # Set data loader
    # NOTE(review): the training loader is not shuffled here; confirm that
    # SEALData already randomises the sample order.
    train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, num_workers=args.num_workers)
    val_loader = GraphDataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)
    test_loader = GraphDataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)

    # set model
    if args.model not in ('gcn', 'dgcnn'):
        raise ValueError('Model error')
    shared_model_kwargs = dict(num_layers=args.num_layers,
                               hidden_units=args.hidden_units,
                               gcn_type=args.gcn_type,
                               node_attributes=node_attribute,
                               edge_weights=edge_weight,
                               node_embedding=None,
                               use_embedding=True,
                               num_nodes=num_nodes,
                               dropout=args.dropout)
    if args.model == 'gcn':
        model = GCN(pooling_type=args.pooling, **shared_model_kwargs)
    else:
        model = DGCNN(k=args.sort_k, **shared_model_kwargs)

    model = model.to(device)
    parameters = model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=args.lr)
    loss_fn = BCEWithLogitsLoss()
    print_fn("Total parameters: {}".format(sum([p.numel() for p in model.parameters()])))

    # train and evaluate loop
    summary_val = []
    summary_test = []
    eval_epochs = []  # epoch number of each entry in summary_val / summary_test
    for epoch in range(args.epochs):
        start_time = time.time()
        loss = train(model=model,
                     dataloader=train_loader,
                     loss_fn=loss_fn,
                     optimizer=optimizer,
                     device=device,
                     num_graphs=args.batch_size,
                     total_graphs=train_graphs)
        train_time = time.time()
        if epoch % args.eval_steps == 0:
            val_pos_pred, val_neg_pred = evaluate(model=model,
                                                  dataloader=val_loader,
                                                  device=device)
            test_pos_pred, test_neg_pred = evaluate(model=model,
                                                    dataloader=test_loader,
                                                    device=device)

            val_metric = evaluate_hits(args.dataset, val_pos_pred, val_neg_pred, args.hits_k)
            test_metric = evaluate_hits(args.dataset, test_pos_pred, test_neg_pred, args.hits_k)
            evaluate_time = time.time()
            print_fn("Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
                     "cost time: train-{:.1f}s, total-{:.1f}s".format(epoch, loss, args.hits_k, val_metric, test_metric,
                                                                      train_time - start_time,
                                                                      evaluate_time - start_time))
            summary_val.append(val_metric)
            summary_test.append(test_metric)
            eval_epochs.append(epoch)

    if not summary_test:
        # Happens when args.epochs <= 0; np.max on an empty array would raise.
        print_fn("No evaluation was run; nothing to summarize.")
        return

    summary_test = np.array(summary_test)
    # NOTE(review): the best result is picked on the test metric; the
    # validation scores are collected but not used for model selection.
    best_idx = int(np.argmax(summary_test))

    print_fn("Experiment Results:")
    # Report the real epoch number, not the position in the evaluation list
    # (the two differ whenever eval_steps > 1).
    print_fn("Best hits@{}: {:.4f}, epoch: {}".format(args.hits_k, summary_test[best_idx], eval_epochs[best_idx]))
|
||
|
||
if __name__ == '__main__':
    # Script entry point: parse the command line, then send all experiment
    # output through the logger (console plus a log file under ./logs).
    cli_args = parse_arguments()
    run_logger = LightLogging(log_path='./logs', log_name='SEAL')
    main(cli_args, print_fn=run_logger.info)
Oops, something went wrong.