* add gin model
* convert dataset.py to a data-on-the-fly style and put it into the dgl.data module
* convert dataset.py to a data-on-the-fly style and put it into the dgl.data module; python code checked
* modified document and referenced TUDataset; checked python part and bypassed cpp part due to error
* change tensor to numpy in dataset and transform in collate@Dataloader
* change minor format issue
* moved logging; adjusted tqdm etc.
1 parent fb6af16 · commit a3febc0 · 10 changed files with 911 additions and 0 deletions.
@@ -1,3 +1,6 @@
# IDE
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -0,0 +1,45 @@
Graph Isomorphism Network (GIN)
===============================

- Paper link: [arXiv](https://arxiv.org/abs/1810.00826), [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km)
- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns)

Dependencies
------------
- PyTorch 1.0.1+
- sklearn
- tqdm

```bash
pip install torch sklearn tqdm
```

How to run
----------

An experiment on the GIN in default settings can be run with

```bash
python main.py
```

An experiment on the GIN in customized settings can be run with

```bash
python main.py [--device 0 | --disable-cuda] --dataset COLLAB \
    --graph_pooling_type max --neighbor_pooling_type sum
```

Results
-------

Run the following for the double SUM pooling configuration
(tested datasets: "MUTAG" (default), "COLLAB", "IMDBBINARY", "IMDBMULTI"):

```bash
python main.py --dataset MUTAG --device 0 \
    --graph_pooling_type sum --neighbor_pooling_type sum
```

* MUTAG: 0.85 (paper: ~0.89)
* COLLAB: 0.89 (paper: ~0.80)
* IMDBBINARY: 0.76 (paper: ~0.75)
* IMDBMULTI: 0.51 (paper: ~0.52)
@@ -0,0 +1,105 @@
""" | ||
PyTorch compatible dataloader | ||
""" | ||
|
||
|
||
import math | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.sampler import SubsetRandomSampler | ||
from sklearn.model_selection import StratifiedKFold | ||
import dgl | ||
|
||
|
||
# default collate function | ||
def collate(samples): | ||
# The input `samples` is a list of pairs (graph, label). | ||
graphs, labels = map(list, zip(*samples)) | ||
for g in graphs: | ||
# deal with node feats | ||
for feat in g.node_attr_schemes().keys(): | ||
# TODO torch.Tensor is not recommended | ||
# torch.DoubleTensor and torch.tensor | ||
# will meet error in executor.py@runtime line 472, tensor.py@backend line 147 | ||
# RuntimeError: expected type torch.cuda.DoubleTensor but got torch.cuda.FloatTensor | ||
g.ndata[feat] = torch.Tensor(g.ndata[feat]) | ||
# no edge feats | ||
batched_graph = dgl.batch(graphs) | ||
labels = torch.tensor(labels) | ||
return batched_graph, labels | ||


class GraphDataLoader():
    def __init__(self,
                 dataset,
                 batch_size,
                 device,
                 collate_fn=collate,
                 seed=0,
                 shuffle=True,
                 split_name='fold10',
                 fold_idx=0,
                 split_ratio=0.7):

        self.shuffle = shuffle
        self.seed = seed
        self.kwargs = {'pin_memory': True} if 'cuda' in device.type else {}

        labels = [l for _, l in dataset]

        if split_name == 'fold10':
            train_idx, valid_idx = self._split_fold10(
                labels, fold_idx, seed, shuffle)
        elif split_name == 'rand':
            train_idx, valid_idx = self._split_rand(
                labels, split_ratio, seed, shuffle)
        else:
            raise NotImplementedError()

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # use the collate_fn argument rather than hard-coding the module-level default
        self.train_loader = DataLoader(
            dataset, sampler=train_sampler,
            batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
        self.valid_loader = DataLoader(
            dataset, sampler=valid_sampler,
            batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)

    def train_valid_loader(self):
        return self.train_loader, self.valid_loader

    def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
        ''' 10-fold stratified split; returns train/valid indices for `fold_idx` '''
        assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."

        skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
        idx_list = []
        for idx in skf.split(np.zeros(len(labels)), labels):  # split(x, y)
            idx_list.append(idx)
        train_idx, valid_idx = idx_list[fold_idx]

        print("train_set : test_set = %d : %d" %
              (len(train_idx), len(valid_idx)))

        return train_idx, valid_idx

    def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
        ''' random split with `split_ratio` of the data used for training '''
        num_entries = len(labels)
        indices = list(range(num_entries))
        np.random.seed(seed)
        np.random.shuffle(indices)
        split = int(math.floor(split_ratio * num_entries))
        train_idx, valid_idx = indices[:split], indices[split:]

        print("train_set : test_set = %d : %d" %
              (len(train_idx), len(valid_idx)))

        return train_idx, valid_idx
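
To make the intended use of `GraphDataLoader` and the default `collate` function concrete, here is a minimal usage sketch. The toy graphs, the `'attr'` feature name, and the assumption that the file above is saved as `dataloader.py` are illustrative only and not part of this commit; the real example feeds the dataset added to the `dgl.data` module (see the commit message) into the loader.

```python
import dgl
import torch

from dataloader import GraphDataLoader, collate  # assuming the module above is dataloader.py

# Toy dataset: a list of (DGLGraph, label) pairs, made up for illustration only.
def make_graph(num_nodes):
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)
    # simple path graph
    g.add_edges(list(range(num_nodes - 1)), list(range(1, num_nodes)))
    g.ndata['attr'] = torch.ones(num_nodes, 3)  # dummy node features
    return g

# 10 samples per class so a 10-fold stratified split is valid
dataset = [(make_graph(n), n % 2) for n in range(3, 23)]

device = torch.device('cpu')
loader = GraphDataLoader(
    dataset, batch_size=4, device=device,
    collate_fn=collate, split_name='fold10', fold_idx=0)
train_loader, valid_loader = loader.train_valid_loader()

for batched_graph, labels in train_loader:
    # batched_graph is a single DGLGraph containing the whole mini-batch
    print(batched_graph.number_of_nodes(), labels)
    break
```

Any indexable collection of `(graph, label)` pairs works here, since the loader only relies on `__getitem__`/`__len__` and the supplied `collate_fn`.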