Major refactoring of code, added VRP model, added Pointer Network
wouterkool committed Jun 22, 2018
1 parent 4e4ca00 commit 1935bea
Showing 48 changed files with 2,057 additions and 182 deletions.
19 changes: 14 additions & 5 deletions README.md
@@ -1,9 +1,9 @@
# Attention Solves Your TSP
# Attention Solves Your TSP, Approximately

Attention based model for learning to solve the Travelling Salesman Problem. Training with REINFORCE with greedy rollout baseline.
Attention-based model for learning to solve the Travelling Salesman Problem (TSP) and the Vehicle Routing Problem (VRP), trained using REINFORCE with a greedy rollout baseline.

## Paper
Please see our paper [Attention Solves Your TSP](https://arxiv.org/abs/1803.08475).
Please see our paper [Attention Solves Your TSP, Approximately](https://arxiv.org/abs/1803.08475).

## Dependencies

@@ -13,6 +13,7 @@ Please see our paper [Attention Solves Your TSP](https://arxiv.org/abs/1803.08475).
* [PyTorch](http://pytorch.org/)=0.3
* tqdm
* [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
* Matplotlib (optional, only for plotting)

## Usage

@@ -33,13 +34,21 @@ To evaluate a model, use the `--load_path` option to specify the model to load a
python run.py --graph_size 20 --eval_only --load_path 'outputs/tsp_20/tsp20_rollout_{datetime}/epoch-0.pt'
```

To load a pretrained model (single GPU only since it cannot load into `DataParallel`):
To load a pretrained model:
```bash
CUDA_VISIBLE_DEVICES=0 python run.py --graph_size 100 --eval_only --load_path pretrained/tsp100.pt
CUDA_VISIBLE_DEVICES=0 python run.py --graph_size 100 --eval_only --load_path pretrained/tsp_100/epoch-99.pt
```
Note that the results may differ slightly from those reported in the paper, since the paper used a test set different from this validation set (which depends on the random seed).

For other options and help:
```bash
python run.py -h
```
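
To train the newly added VRP model (a hedged example; `--problem`, `--graph_size` and `--baseline` are all defined in `options.py` in this commit):
```bash
python run.py --problem vrp --graph_size 20 --baseline rollout
```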

## Example CVRP solution
See `plot_vrp.ipynb` for an example of loading a pretrained model and plotting the result for Capacitated VRP with 100 nodes.

![CVRP100](images/cvrp_0.png)

## Acknowledgements
Thanks to [pemami4911/neural-combinatorial-rl-pytorch](https://github.com/pemami4911/neural-combinatorial-rl-pytorch) for getting me started with the code for the Pointer Network.
444 changes: 347 additions & 97 deletions attention_model.py

Large diffs are not rendered by default.

65 changes: 60 additions & 5 deletions baselines.py
@@ -3,7 +3,7 @@
from torch.autograd import Variable
from scipy.stats import ttest_rel
import copy
from train import rollout
from train import rollout, get_inner_model

class Baseline(object):

@@ -29,6 +29,54 @@ def load_state_dict(self, state_dict):
pass


class WarmupBaseline(Baseline):

def __init__(self, baseline, n_epochs=1, warmup_exp_beta=0.8):
super(WarmupBaseline, self).__init__()

self.baseline = baseline
assert n_epochs > 0, "n_epochs to warmup must be positive"
self.warmup_baseline = ExponentialBaseline(warmup_exp_beta)
self.alpha = 0
self.n_epochs = n_epochs

def wrap_dataset(self, dataset):
if self.alpha > 0:
return self.baseline.wrap_dataset(dataset)
return self.warmup_baseline.wrap_dataset(dataset)

def unwrap_batch(self, batch):
if self.alpha > 0:
return self.baseline.unwrap_batch(batch)
return self.warmup_baseline.unwrap_batch(batch)

def eval(self, x, c):

if self.alpha == 1:
return self.baseline.eval(x, c)
if self.alpha == 0:
return self.warmup_baseline.eval(x, c)
v, l = self.baseline.eval(x, c)
vw, lw = self.warmup_baseline.eval(x, c)
# Return convex combinations of the inner and warmup baseline values and losses
return self.alpha * v + (1 - self.alpha) * vw, self.alpha * l + (1 - self.alpha) * lw

def epoch_callback(self, model, epoch):
# Always call the epoch callback of the inner baseline (even during warmup, when it has not been used yet)
self.baseline.epoch_callback(model, epoch)
self.alpha = (epoch + 1) / float(self.n_epochs)
if epoch < self.n_epochs:
print("Set warmup alpha = {}".format(self.alpha))

def state_dict(self):
# Checkpointing within the warmup stage makes no sense; only save the inner baseline
return self.baseline.state_dict()

def load_state_dict(self, state_dict):
# Checkpointing within the warmup stage makes no sense; only load the inner baseline
self.baseline.load_state_dict(state_dict)
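
For reference, a minimal standalone sketch of the interpolation in `eval` above, with purely illustrative numbers (no repository code is assumed):
```python
# Illustrative numbers only: how WarmupBaseline.eval blends the two baselines
alpha = 0.5                          # e.g. after epoch 0 with n_epochs=2
v, l = 7.8, 0.10                     # inner baseline value and loss
vw, lw = 8.2, 0.0                    # warmup (exponential) baseline value and loss
bl_value = alpha * v + (1 - alpha) * vw   # 8.0
bl_loss = alpha * l + (1 - alpha) * lw    # 0.05
```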


class NoBaseline(Baseline):

def eval(self, x, c):
@@ -82,11 +130,14 @@ def epoch_callback(self, model, epoch):

def state_dict(self):
return {
'critic': self.critic
'critic': self.critic.state_dict()
}

def load_state_dict(self, state_dict):
self.critic.load_state_dict({**self.critic.state_dict(), **state_dict.get('critic', {})})
critic_state_dict = state_dict.get('critic', {})
if not isinstance(critic_state_dict, dict): # backwards compatibility
critic_state_dict = critic_state_dict.state_dict()
self.critic.load_state_dict({**self.critic.state_dict(), **critic_state_dict})


class RolloutBaseline(Baseline):
@@ -122,7 +173,7 @@ def unwrap_batch(self, batch):

def eval(self, x, c):
# Use volatile mode for efficient inference (single batch so we do not use rollout function)
v, _, _, _ = self.model(Variable(x.data, volatile=True))
v, _ = self.model(Variable(x.data, volatile=True))

v.volatile = False # The returned value should not be volatile

@@ -161,7 +212,11 @@ def state_dict(self):
}

def load_state_dict(self, state_dict):
self._update_model(state_dict['model'], state_dict['epoch'], state_dict['dataset'])
# TODO: change the code so that the model parameters, not the model object, are saved
# Make loading work whether the model was saved wrapped in DataParallel or not
load_model = copy.deepcopy(self.model)
get_inner_model(load_model).load_state_dict(get_inner_model(state_dict['model']).state_dict())
self._update_model(load_model, state_dict['epoch'], state_dict['dataset'])
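
Note: `get_inner_model` is imported from `train.py` (see the import change at the top of this file). A plausible sketch, assuming it merely unwraps `DataParallel`:
```python
import torch.nn as nn

def get_inner_model(model):
    # Unwrap DataParallel so state_dict keys are consistent (no 'module.' prefix)
    return model.module if isinstance(model, nn.DataParallel) else model
```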


class BaselineDataset(Dataset):
25 changes: 25 additions & 0 deletions data_utils.py
@@ -0,0 +1,25 @@
import os
import pickle


def check_extension(filename):
if os.path.splitext(filename)[1] != ".pkl":
return filename + ".pkl"
return filename


def save_dataset(dataset, filename):

filedir = os.path.split(filename)[0]

if filedir and not os.path.isdir(filedir):  # filedir is empty for a bare filename
os.makedirs(filedir)

with open(check_extension(filename), 'wb') as f:
pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)


def load_dataset(filename):

with open(check_extension(filename), 'rb') as f:
return pickle.load(f)
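
A hedged usage sketch of these helpers (the path is illustrative; `check_extension` appends `.pkl` automatically):
```python
from data_utils import save_dataset, load_dataset

data = [[[0.1, 0.2], [0.3, 0.4]]]        # one tiny 2-node TSP instance
save_dataset(data, 'data/tsp/example')   # writes data/tsp/example.pkl
assert load_dataset('data/tsp/example') == data
```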
51 changes: 51 additions & 0 deletions generate_data.py
@@ -0,0 +1,51 @@
import argparse
import os
import numpy as np
from data_utils import check_extension, save_dataset


def generate_tsp_data(dataset_size, tsp_size):
return np.random.uniform(size=(dataset_size, tsp_size, 2)).tolist()


def generate_vrp_data(dataset_size, vrp_size):
CAPACITIES = {
10: 20.,
20: 30.,
50: 40.,
100: 50.
}
return list(zip(
np.random.uniform(size=(dataset_size, 2)).tolist(), # Depot location
np.random.uniform(size=(dataset_size, vrp_size, 2)).tolist(), # Node locations
np.random.randint(1, 10, size=(dataset_size, vrp_size)).tolist(),  # Demand, uniform integer 1 ... 9
(np.ones(dataset_size) * CAPACITIES[vrp_size]).tolist()  # Capacity, same for whole dataset
))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("dataset", help="Filename of the dataset to create")
parser.add_argument("--problem", type=str, default='tsp', help="Problem, 'tsp' or 'vrp'")
parser.add_argument("--dataset_size", type=int, default=10000, help="Size of the dataset")
parser.add_argument('--graph_size', type=int, default=20, help="Size of problem instances")
parser.add_argument("-f", action='store_true', help="Set true to overwrite")
parser.add_argument('--seed', type=int, default=None, help="Random seed")

opts = parser.parse_args()

assert opts.f or not os.path.isfile(check_extension(opts.dataset)), \
"File already exists! Try running with -f option to overwrite."

np.random.seed(opts.seed)
if opts.problem == 'tsp':
dataset = generate_tsp_data(opts.dataset_size, opts.graph_size)
elif opts.problem == 'vrp':
dataset = generate_vrp_data(opts.dataset_size, opts.graph_size)
else:
assert False, "Unknown problem: {}".format(opts.problem)

print(dataset[0])
filename = check_extension(opts.dataset)

save_dataset(dataset, filename)
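
For example, a hedged invocation generating a VRP validation set (the output path is illustrative; all flags are defined above):
```bash
python generate_data.py data/vrp/vrp20_validation --problem vrp --graph_size 20 --dataset_size 10000 --seed 4321
```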
2 changes: 1 addition & 1 deletion graph_encoder.py
@@ -126,7 +126,7 @@ def __init__(self, embed_dim, normalization='batch'):
self.normalizer = normalizer_class(embed_dim, affine=True)

# Normalization by default initializes affine parameters with bias 0 and weight unif(0,1) which is too large!
self.init_parameters()
# self.init_parameters()

def init_parameters(self):

Binary file added images/cvrp_0.png
Binary file added images/cvrp_1.png
Binary file added images/cvrp_2.png
Binary file added images/cvrp_3.png
Binary file added images/cvrp_4.png
Binary file added images/cvrp_5.png
Binary file added images/cvrp_6.png
Binary file added images/cvrp_7.png
Binary file added images/cvrp_8.png
Binary file added images/cvrp_9.png
23 changes: 9 additions & 14 deletions log_utils.py
@@ -1,29 +1,24 @@
from tensorboard_logger import log_value


def log_values(cost, grad_norms, epoch, batch_id, step,
log_likelihood, reinforce_loss, bl_loss, opts):
log_likelihood, reinforce_loss, bl_loss, tb_logger, opts):
avg_cost = cost.mean().data[0]
grad_norms, grad_norms_clipped = grad_norms

# Log values to screen
print('epoch: {}, train_batch_id: {}, avg_cost: {}'.format(epoch, batch_id, avg_cost))

print('grad_norm: {}, clipped: {}'.format(grad_norms[0], grad_norms_clipped[0]))
if opts.baseline == 'critic':
print('grad_norm_critic: {}, clipped: {}'.format(grad_norms[1], grad_norms_clipped[1]))

# Log values to tensorboard
if not opts.no_tensorboard:
log_value('avg_cost', avg_cost, step)
tb_logger.log_value('avg_cost', avg_cost, step)

log_value('actor_loss', reinforce_loss.data[0], step)
log_value('nll', -log_likelihood.mean().data[0], step)
tb_logger.log_value('actor_loss', reinforce_loss.data[0], step)
tb_logger.log_value('nll', -log_likelihood.mean().data[0], step)

log_value('grad_norm', grad_norms[0], step)
log_value('grad_norm_clipped', grad_norms_clipped[0], step)
tb_logger.log_value('grad_norm', grad_norms[0], step)
tb_logger.log_value('grad_norm_clipped', grad_norms_clipped[0], step)

if opts.baseline == 'critic':
log_value('critic_loss', bl_loss.data[0], step)
log_value('critic_grad_norm', grad_norms[1], step)
log_value('critic_grad_norm_clipped', grad_norms_clipped[1], step)
tb_logger.log_value('critic_loss', bl_loss.data[0], step)
tb_logger.log_value('critic_grad_norm', grad_norms[1], step)
tb_logger.log_value('critic_grad_norm_clipped', grad_norms_clipped[1], step)
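
A hedged sketch of constructing the `tb_logger` that is now passed into `log_values` (the log directory is illustrative):
```python
from tensorboard_logger import Logger as TbLogger

# One Logger instance per run; log_values calls tb_logger.log_value(...)
tb_logger = TbLogger('logs/vrp_20/my_run')
```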
22 changes: 17 additions & 5 deletions options.py
@@ -9,26 +9,29 @@ def get_options(args=None):
description="Attention based model for solving the Travelling Salesman Problem with Reinforcement Learning")

# Data
parser.add_argument('--problem', default='tsp', help="The problem to solve, default 'tsp'")
parser.add_argument('--graph_size', type=int, default=20, help="The size of the problem graph")
parser.add_argument('--batch_size', type=int, default=512, help='Number of instances per batch during training')
parser.add_argument('--epoch_size', type=int, default=1280000, help='Number of instances per epoch during training')
parser.add_argument('--val_size', type=int, default=10000,
help='Number of instances used for reporting validation performance')
parser.add_argument('--val_dataset', type=str, default=None, help='Dataset file to use for validation')

# Model
parser.add_argument('--model', default='attention', help="Model, 'attention' (default) or 'pointer'")
parser.add_argument('--embedding_dim', type=int, default=128, help='Dimension of input embedding')
parser.add_argument('--hidden_dim', type=int, default=128, help='Dimension of hidden layers in Enc/Dec')
parser.add_argument('--n_encode_layers', type=int, default=3,
help='Number of process block iters to run in the Critic network')
help='Number of layers in the encoder/critic network')
parser.add_argument('--tanh_clipping', type=float, default=10.,
help='Clip the parameters to within +- this value using tanh. '
'Set to 0 to not perform any clipping.')
parser.add_argument('--normalization', default='batch', help="Normalization type, 'batch' (default) or 'instance'")

# Training
parser.add_argument('--lr_model', default=1e-3, help="Set the learning rate for the actor network")
parser.add_argument('--lr_critic', default=1e-3, help="Set the learning rate for the critic network")
parser.add_argument('--lr_decay', default=0.96, help='Learning rate decay per epoch')
parser.add_argument('--lr_model', type=float, default=1e-4, help="Set the learning rate for the actor network")
parser.add_argument('--lr_critic', type=float, default=1e-4, help="Set the learning rate for the critic network")
parser.add_argument('--lr_decay', type=float, default=1.0, help='Learning rate decay per epoch')
parser.add_argument('--eval_only', action='store_true', help='Set this value to only evaluate model')
parser.add_argument('--n_epochs', type=int, default=100, help='The number of epochs to train')
parser.add_argument('--seed', type=int, default=1234, help='Random seed to use')
@@ -41,6 +44,9 @@ def get_options(args=None):
help="Baseline to use: 'rollout', 'critic' or 'exponential'. Defaults to no baseline.")
parser.add_argument('--bl_alpha', type=float, default=0.05,
help='Significance in the t-test for updating rollout baseline')
parser.add_argument('--bl_warmup_epochs', type=int, default=None,
help='Number of epochs to warm up the baseline; default None means 1 for rollout (the exponential '
'baseline is used during warmup), 0 otherwise. Can only be used with the rollout baseline.')
parser.add_argument('--eval_batch_size', type=int, default=1024,
help="Batch size to use during (baseline) evaluation")

@@ -51,7 +57,10 @@ def get_options(args=None):
parser.add_argument('--output_dir', default='outputs', help='Directory to write output models to')
parser.add_argument('--epoch_start', type=int, default=0,
help='Start at epoch # (relevant for learning rate decay)')
parser.add_argument('--checkpoint_epochs', type=int, default=1,
help='Save checkpoint every n epochs (default 1), 0 to save no checkpoints')
parser.add_argument('--load_path', help='Path to load model parameters and optimizer state from')
parser.add_argument('--resume', help='Resume from previous checkpoint file')
parser.add_argument('--no_tensorboard', action='store_true', help='Disable logging TensorBoard files')
parser.add_argument('--no_progress_bar', action='store_true', help='Disable progress bar')

@@ -61,8 +70,11 @@ def get_options(args=None):
opts.run_name = "{}_{}".format(opts.run_name, time.strftime("%Y%m%dT%H%M%S"))
opts.save_dir = os.path.join(
opts.output_dir,
"tsp_{}".format(opts.graph_size),
"{}_{}".format(opts.problem, opts.graph_size),
opts.run_name
)
if opts.bl_warmup_epochs is None:
opts.bl_warmup_epochs = 1 if opts.baseline == 'rollout' else 0
assert (opts.bl_warmup_epochs == 0) or (opts.baseline == 'rollout')
assert opts.epoch_size % opts.batch_size == 0, "Epoch size must be integer multiple of batch size!"
return opts
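
For example, a hedged invocation combining the new options (the checkpoint path follows the `{problem}_{graph_size}` directory scheme set above; `{datetime}` is a placeholder as in the README):
```bash
python run.py --problem vrp --graph_size 20 --baseline rollout \
    --checkpoint_epochs 10 --resume 'outputs/vrp_20/vrp20_rollout_{datetime}/epoch-10.pt'
```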