[Model] Fix + batched DGMG (dmlc#175)
* DGMG with batch size 1

* Fix

* Adjustment

* Fix

* Fix

* Fix

* Fix

* Fix has_node and __contains__

* Batched implementation for DGMG

* Remove redundant dependency

* Adjustment

* Fix

* Add comments
mufeili authored Nov 28, 2018
1 parent 5cda368 commit a0d0b1e
Showing 6 changed files with 726 additions and 19 deletions.
9 changes: 7 additions & 2 deletions examples/pytorch/dgmg/README.md
@@ -3,11 +3,16 @@
This is an implementation of [Learning Deep Generative Models of Graphs](https://arxiv.org/pdf/1803.03324.pdf) by
Yujia Li, Oriol Vinyals, Chris Dyer, Razvan Pascanu, Peter Battaglia.

# Dependency
## Dependency
- Python 3.5.2
- [Pytorch 0.4.1](https://pytorch.org/)
- [Matplotlib 2.2.2](https://matplotlib.org/)

# Usage
## Usage

- Train with batch size 1: `python main.py`
- Train with batch size larger than 1: `python main_batch.py`.

## Acknowledgement

We would like to thank Yujia Li for providing details on the implementation.
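
Below is a minimal usage sketch, not part of this commit, of how the model saved by the batched script could be loaded and sampled from, mirroring the checks in `cycles.py`; the `model_batch` import, the `./model_batched.pth` path, and the list-of-`DGLGraph` return value are taken from the diffs that follow, while everything else is an assumption.

```python
# Illustrative sketch (not part of this commit): load the batched DGMG model
# saved by main_batch.py and sample one graph, as cycles.py does during evaluation.
import torch

from model_batch import DGMG  # noqa: F401  (class must be importable to unpickle the model)

model = torch.load('./model_batched.pth')  # path used by torch.save in main_batch.py
model.eval()

sampled_graph = model()
if isinstance(sampled_graph, list):
    # The batched implementation returns a list of DGLGraph objects;
    # keep the first one, as rollout_and_examine does in cycles.py.
    sampled_graph = sampled_graph[0]

print('Generated a graph with {} nodes.'.format(sampled_graph.number_of_nodes()))
```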
30 changes: 16 additions & 14 deletions examples/pytorch/dgmg/cycles.py
@@ -89,10 +89,13 @@ def __len__(self):
    def __getitem__(self, index):
        return self.dataset[index]

    def collate(self, batch):
    def collate_single(self, batch):
        assert len(batch) == 1, 'Currently we do not support batched training'
        return batch[0]

    def collate_batch(self, batch):
        return batch


def dglGraph_to_adj_list(g):
    adj_list = {}
@@ -112,14 +115,6 @@ def __init__(self, v_min, v_max, dir):

        self.dir = dir

    def _initialize(self):
        self.num_samples_examined = 0

        self.average_size = 0
        self.valid_size_ratio = 0
        self.cycle_ratio = 0
        self.valid_ratio = 0

    def rollout_and_examine(self, model, num_samples):
        assert not model.training, 'You need to call model.eval().'

@@ -132,14 +127,22 @@ def rollout_and_examine(self, model, num_samples):

        for i in range(num_samples):
            sampled_graph = model()
            if isinstance(sampled_graph, list):
                # When the model is a batched implementation, a list of
                # DGLGraph objects is returned. Note that with model(),
                # we generate a single graph as with the non-batched
                # implementation. Batched generation is also supported
                # during inference, so feel free to modify the code.
                sampled_graph = sampled_graph[0]

            sampled_adj_list = dglGraph_to_adj_list(sampled_graph)
            adj_lists_to_plot.append(sampled_adj_list)

            generated_graph_size = sampled_graph.number_of_nodes()
            valid_size = (self.v_min <= generated_graph_size <= self.v_max)
            graph_size = sampled_graph.number_of_nodes()
            valid_size = (self.v_min <= graph_size <= self.v_max)
            cycle = is_cycle(sampled_graph)

            num_total_size += generated_graph_size
            num_total_size += graph_size

            if valid_size:
                num_valid_size += 1
@@ -150,7 +153,7 @@ def rollout_and_examine(self, model, num_samples):
            if valid_size and cycle:
                num_valid += 1

            if len(adj_lists_to_plot) == 4:
            if len(adj_lists_to_plot) >= 4:
                plot_times += 1
                fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2)
                axes = {0: ax0, 1: ax1, 2: ax2, 3: ax3}
@@ -197,7 +200,6 @@ def _format_value(v):
            f.write(msg)

        print('Saved model evaluation statistics to {}'.format(model_eval_path))
        self._initialize()


class CyclePrinting(object):
6 changes: 5 additions & 1 deletion examples/pytorch/dgmg/main.py
@@ -7,6 +7,7 @@
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
@@ -31,7 +32,7 @@ def main(opts):
        raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0,
                             collate_fn=dataset.collate)
                             collate_fn=dataset.collate_single)

    # Initialize_model
    model = DGMG(v_max=opts['max_size'],
@@ -96,6 +97,9 @@ def main(opts):
    print('On average, an epoch takes {}.'.format(datetime.timedelta(
        seconds=(t3-t2) / opts['nepochs'])))

    del model.g
    torch.save(model, './model.pth')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DGMG')
118 changes: 118 additions & 0 deletions examples/pytorch/dgmg/main_batch.py
@@ -0,0 +1,118 @@
"""
Learning Deep Generative Models of Graphs
Paper: https://arxiv.org/pdf/1803.03324.pdf
This implementation works with a minibatch of size larger than 1 for training and 1 for inference.
"""
import argparse
import datetime
import time
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from model_batch import DGMG


def main(opts):
t1 = time.time()

# Setup dataset and data loader
if opts['dataset'] == 'cycles':
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting

dataset = CycleDataset(fname=opts['path_to_dataset'])
evaluator = CycleModelEvaluation(v_min=opts['min_size'],
v_max=opts['max_size'],
dir = opts['log_dir'])
printer = CyclePrinting(num_epochs=opts['nepochs'],
num_batches=len(dataset) // opts['batch_size'])
else:
raise ValueError('Unsupported dataset: {}'.format(opts['dataset']))

data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0,
collate_fn=dataset.collate_batch)

# Initialize_model
model = DGMG(v_max=opts['max_size'],
node_hidden_size=opts['node_hidden_size'],
num_prop_rounds=opts['num_propagation_rounds'])

# Initialize optimizer
if opts['optimizer'] == 'Adam':
optimizer = Adam(model.parameters(), lr=opts['lr'])
else:
raise ValueError('Unsupported argument for the optimizer')

t2 = time.time()

# Training
model.train()
for epoch in range(opts['nepochs']):
for batch, data in enumerate(data_loader):

log_prob = model(batch_size=opts['batch_size'], actions=data)

loss = - log_prob / opts['batch_size']
batch_avg_prob = (log_prob / opts['batch_size']).detach().exp()
batch_avg_loss = loss.item()

optimizer.zero_grad()
loss.backward()
if opts['clip_grad']:
clip_grad_norm_(model.parameters(), opts['clip_bound'])
optimizer.step()

printer.update(epoch + 1, {'averaged loss': batch_avg_loss,
'averaged prob': batch_avg_prob})

t3 = time.time()

model.eval()
evaluator.rollout_and_examine(model, opts['num_generated_samples'])
evaluator.write_summary()

t4 = time.time()

print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1)))
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2)))
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3)))
print('--------------------------------------------------------------------------')
print('On average, an epoch takes {}.'.format(datetime.timedelta(
seconds=(t3-t2) / opts['nepochs'])))

del model.g_list
torch.save(model, './model_batched.pth')


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='batched DGMG')

# configure
parser.add_argument('--seed', type=int, default=9284, help='random seed')

# dataset
parser.add_argument('--dataset', choices=['cycles'], default='cycles',
help='dataset to use')
parser.add_argument('--path-to-dataset', type=str, default='cycles.p',
help='load the dataset if it exists, '
'generate it and save to the path otherwise')

# log
parser.add_argument('--log-dir', default='./results',
help='folder to save info like experiment configuration '
'or model evaluation results')

# optimization
parser.add_argument('--batch-size', type=int, default=10,
help='batch size to use for training')
parser.add_argument('--clip-grad', action='store_true', default=True,
help='gradient clipping is required to prevent gradient explosion')
parser.add_argument('--clip-bound', type=float, default=0.25,
help='constraint of gradient norm for gradient clipping')

args = parser.parse_args()
from utils import setup
opts = setup(args)

main(opts)