-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Fix + batched DGMG (dmlc#175)
* DGMG with batch size 1 * Fix * Adjustment * Fix * Fix * Fix * Fix * Fix has_node and __contains__ * Batched implementation for DGMG * Remove redundant dependency * Adjustment * Fix * Add comments
- Loading branch information
Showing
6 changed files
with
726 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
""" | ||
Learning Deep Generative Models of Graphs | ||
Paper: https://arxiv.org/pdf/1803.03324.pdf | ||
This implementation works with a minibatch of size larger than 1 for training and 1 for inference. | ||
""" | ||
import argparse | ||
import datetime | ||
import time | ||
import torch | ||
from torch.optim import Adam | ||
from torch.utils.data import DataLoader | ||
from torch.nn.utils import clip_grad_norm_ | ||
|
||
from model_batch import DGMG | ||
|
||
|
||
def main(opts): | ||
t1 = time.time() | ||
|
||
# Setup dataset and data loader | ||
if opts['dataset'] == 'cycles': | ||
from cycles import CycleDataset, CycleModelEvaluation, CyclePrinting | ||
|
||
dataset = CycleDataset(fname=opts['path_to_dataset']) | ||
evaluator = CycleModelEvaluation(v_min=opts['min_size'], | ||
v_max=opts['max_size'], | ||
dir = opts['log_dir']) | ||
printer = CyclePrinting(num_epochs=opts['nepochs'], | ||
num_batches=len(dataset) // opts['batch_size']) | ||
else: | ||
raise ValueError('Unsupported dataset: {}'.format(opts['dataset'])) | ||
|
||
data_loader = DataLoader(dataset, batch_size=opts['batch_size'], shuffle=True, num_workers=0, | ||
collate_fn=dataset.collate_batch) | ||
|
||
# Initialize_model | ||
model = DGMG(v_max=opts['max_size'], | ||
node_hidden_size=opts['node_hidden_size'], | ||
num_prop_rounds=opts['num_propagation_rounds']) | ||
|
||
# Initialize optimizer | ||
if opts['optimizer'] == 'Adam': | ||
optimizer = Adam(model.parameters(), lr=opts['lr']) | ||
else: | ||
raise ValueError('Unsupported argument for the optimizer') | ||
|
||
t2 = time.time() | ||
|
||
# Training | ||
model.train() | ||
for epoch in range(opts['nepochs']): | ||
for batch, data in enumerate(data_loader): | ||
|
||
log_prob = model(batch_size=opts['batch_size'], actions=data) | ||
|
||
loss = - log_prob / opts['batch_size'] | ||
batch_avg_prob = (log_prob / opts['batch_size']).detach().exp() | ||
batch_avg_loss = loss.item() | ||
|
||
optimizer.zero_grad() | ||
loss.backward() | ||
if opts['clip_grad']: | ||
clip_grad_norm_(model.parameters(), opts['clip_bound']) | ||
optimizer.step() | ||
|
||
printer.update(epoch + 1, {'averaged loss': batch_avg_loss, | ||
'averaged prob': batch_avg_prob}) | ||
|
||
t3 = time.time() | ||
|
||
model.eval() | ||
evaluator.rollout_and_examine(model, opts['num_generated_samples']) | ||
evaluator.write_summary() | ||
|
||
t4 = time.time() | ||
|
||
print('It took {} to setup.'.format(datetime.timedelta(seconds=t2-t1))) | ||
print('It took {} to finish training.'.format(datetime.timedelta(seconds=t3-t2))) | ||
print('It took {} to finish evaluation.'.format(datetime.timedelta(seconds=t4-t3))) | ||
print('--------------------------------------------------------------------------') | ||
print('On average, an epoch takes {}.'.format(datetime.timedelta( | ||
seconds=(t3-t2) / opts['nepochs']))) | ||
|
||
del model.g_list | ||
torch.save(model, './model_batched.pth') | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='batched DGMG') | ||
|
||
# configure | ||
parser.add_argument('--seed', type=int, default=9284, help='random seed') | ||
|
||
# dataset | ||
parser.add_argument('--dataset', choices=['cycles'], default='cycles', | ||
help='dataset to use') | ||
parser.add_argument('--path-to-dataset', type=str, default='cycles.p', | ||
help='load the dataset if it exists, ' | ||
'generate it and save to the path otherwise') | ||
|
||
# log | ||
parser.add_argument('--log-dir', default='./results', | ||
help='folder to save info like experiment configuration ' | ||
'or model evaluation results') | ||
|
||
# optimization | ||
parser.add_argument('--batch-size', type=int, default=10, | ||
help='batch size to use for training') | ||
parser.add_argument('--clip-grad', action='store_true', default=True, | ||
help='gradient clipping is required to prevent gradient explosion') | ||
parser.add_argument('--clip-bound', type=float, default=0.25, | ||
help='constraint of gradient norm for gradient clipping') | ||
|
||
args = parser.parse_args() | ||
from utils import setup | ||
opts = setup(args) | ||
|
||
main(opts) |
Oops, something went wrong.