Skip to content

Commit

Permalink
[Model] Add GIN Model (dmlc#471)
Browse files Browse the repository at this point in the history
* add gin model

* convert dataset.py to data_ont_the_fly way and put it into dgl.data module

* convert dataset.py to data_ont_the_fly way and put it into dgl.data module
python code checked

* modified document and reference TUDataset; checked python part and bypass cpp part due to error

* change tensor to numpy in dataset and transform in collate@Dataloader

* Change minor format issue

Change minor format issue

* moved logging; adjusted tqdm etc
  • Loading branch information
kitaev-chen authored and VoVAllen committed Apr 17, 2019
1 parent fb6af16 commit a3febc0
Show file tree
Hide file tree
Showing 10 changed files with 911 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# IDE
.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Contributors
* [Yizhi Liu](https://github.com/yzhliu): RGCN in MXNet
* [@hbsun2113](https://github.com/hbsun2113): GraphSAGE in Pytorch
* [Tianyi Zhang](https://github.com/Tiiiger): SGC in Pytorch
* [Jun Chen](https://github.com/kitaev-chen): GIN in Pytorch

Other improvement
* [Brett Koonce](https://github.com/brettkoonce)
Expand Down
11 changes: 11 additions & 0 deletions docs/source/api/python/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Mini graph classification dataset
.. autoclass:: MiniGCDataset
:members: __getitem__, __len__, num_classes


Graph kernel dataset
````````````````````

Expand All @@ -41,6 +42,16 @@ For more information about the dataset, see `Benchmark Data Sets for Graph Kerne
.. autoclass:: TUDataset
:members: __getitem__, __len__


Graph isomorphism network dataset
```````````````````````````````````

A compact subset of graph kernel dataset

.. autoclass:: GINDataset
:members: __getitem__, __len__


Protein-Protein Interaction dataset
```````````````````````````````````

Expand Down
45 changes: 45 additions & 0 deletions examples/pytorch/gin/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
Graph Isomorphism Network (GIN)
============

- Paper link: [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km)
- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns).

Dependencies
------------
- PyTorch 1.0.1+
- sklearn
- tqdm

``bash
pip install torch sklearn tqdm
``

How to run
----------

An experiment on the GIN in default settings can be run with

```bash
python main.py
```

An experiment on the GIN in customized settings can be run with
```bash
python main.py [--device 0 | --disable-cuda] --dataset COLLAB \
--graph_pooling_type max --neighbor_pooling_type sum
```

Results
-------

Run with following with the double SUM pooling way:
(tested dataset: "MUTAG"(default), "COLLAB", "IMDBBINARY", "IMDBMULTI")
```bash
python train.py --dataset MUTAB --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum
```

* MUTAG: 0.85 (paper: ~0.89)
* COLLAB: 0.89 (paper: ~0.80)
* IMDBBINARY: 0.76 (paper: ~0.75)
* IMDBMULTI: 0.51 (paper: ~0.52)
105 changes: 105 additions & 0 deletions examples/pytorch/gin/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
PyTorch compatible dataloader
"""


import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
import dgl


# default collate function
def collate(samples):
# The input `samples` is a list of pairs (graph, label).
graphs, labels = map(list, zip(*samples))
for g in graphs:
# deal with node feats
for feat in g.node_attr_schemes().keys():
# TODO torch.Tensor is not recommended
# torch.DoubleTensor and torch.tensor
# will meet error in executor.py@runtime line 472, tensor.py@backend line 147
# RuntimeError: expected type torch.cuda.DoubleTensor but got torch.cuda.FloatTensor
g.ndata[feat] = torch.Tensor(g.ndata[feat])
# no edge feats
batched_graph = dgl.batch(graphs)
labels = torch.tensor(labels)
return batched_graph, labels


class GraphDataLoader():
def __init__(self,
dataset,
batch_size,
device,
collate_fn=collate,
seed=0,
shuffle=True,
split_name='fold10',
fold_idx=0,
split_ratio=0.7):

self.shuffle = shuffle
self.seed = seed
self.kwargs = {'pin_memory': True} if 'cuda' in device.type else {}

labels = [l for _, l in dataset]

if split_name == 'fold10':
train_idx, valid_idx = self._split_fold10(
labels, fold_idx, seed, shuffle)
elif split_name == 'rand':
train_idx, valid_idx = self._split_rand(
labels, split_ratio, seed, shuffle)
else:
raise NotImplementedError()

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

self.train_loader = DataLoader(
dataset, sampler=train_sampler,
batch_size=batch_size, collate_fn=collate, **self.kwargs)
self.valid_loader = DataLoader(
dataset, sampler=valid_sampler,
batch_size=batch_size, collate_fn=collate, **self.kwargs)

def train_valid_loader(self):
return self.train_loader, self.valid_loader

def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
''' 10 flod '''
assert 0 <= fold_idx and fold_idx < 10, print(
"fold_idx must be from 0 to 9.")

idx_list = []
skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
idx_list = []
for idx in skf.split(np.zeros(len(labels)), labels): # split(x, y)
idx_list.append(idx)
train_idx, valid_idx = idx_list[fold_idx]

print(
"train_set : test_set = %d : %d",
len(train_idx), len(valid_idx))

return train_idx, valid_idx

def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
num_entries = len(labels)
indices = list(range(num_entries))
np.random.seed(seed)
np.random.shuffle(indices)
split = int(math.floor(split_ratio * num_entries))
train_idx, valid_idx = indices[:split], indices[split:]

print(
"train_set : test_set = %d : %d",
len(train_idx), len(valid_idx))

return train_idx, valid_idx

Loading

0 comments on commit a3febc0

Please sign in to comment.