A deep reinforcement learning model for solving the Euclidean Steiner Tree (EST) problem. The model is trained with REINFORCE using a greedy rollout baseline.
This is the implementation of our paper "Deep-Steiner: Learning to Solve the Euclidean Steiner Tree Problem", which was accepted at EAI WiCON 2022.
- Python>=3.8
- NumPy
- SciPy
- PyTorch>=1.7
- tqdm
- tensorboard_logger
- Matplotlib (optional, only for plotting)
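The following is a minimal sketch, not part of this repository, for checking that the packages listed above can be found before starting a run. The import names (`numpy`, `scipy`, `torch`, `tqdm`, `tensorboard_logger`, `matplotlib`) are assumed to match the packages in the list, and the helper names `missing_packages` and `check_environment` are hypothetical. It only checks the Python version and package presence; it does not verify the PyTorch>=1.7 constraint.

```python
import importlib.util
import sys

# Import names assumed to correspond to the packages listed above.
REQUIRED = ["numpy", "scipy", "torch", "tqdm", "tensorboard_logger"]
OPTIONAL = ["matplotlib"]  # only needed for plotting


def missing_packages(names):
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment():
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append("Python>=3.8 is required")
    for name in missing_packages(REQUIRED):
        problems.append(f"missing required package: {name}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
    for name in missing_packages(OPTIONAL):
        print(f"optional package not installed: {name}")
```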
Training data is generated on the fly. To generate validation and test data for all problems with fixed random seeds:
python generate_data.py --problem all --name validation --seed 4321
python generate_data.py --problem all --name test --seed 1234
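To illustrate why the seed matters, here is a small sketch of seeded instance generation. It is an assumption, not the actual logic of generate_data.py: it supposes each instance is a set of terminal points drawn uniformly from the unit square, and `generate_instances` is a hypothetical name. The point is only that the same seed reproduces the same dataset, so validation and test sets stay fixed across runs.

```python
import random


def generate_instances(num_instances, graph_size, seed):
    """Sample `num_instances` point sets of `graph_size` 2D points each.

    Assumption: points are uniform in the unit square [0, 1) x [0, 1).
    """
    rng = random.Random(seed)  # private generator, so global state is untouched
    return [
        [(rng.random(), rng.random()) for _ in range(graph_size)]
        for _ in range(num_instances)
    ]


# The same seed always yields the same dataset; a different seed does not.
validation = generate_instances(num_instances=5, graph_size=10, seed=4321)
test = generate_instances(num_instances=5, graph_size=10, seed=1234)
assert validation == generate_instances(5, 10, seed=4321)
assert validation != test
```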
To train on EST problem instances with 10 nodes, using rollout as the REINFORCE baseline and the generated validation set:
python run.py --graph_size 10 --batch_size 32 --epoch_size 10240 --val_size 10000 --eval_batch_size 10 --baseline rollout --run_name 'est10' --n_epochs 100 --lr_model 0.00000001 --seed 1111 --embedding_dim 128 --hidden_dim 128 --n_encode_layers 5
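For readers unfamiliar with the rollout baseline, the sketch below shows the standard form of the REINFORCE loss with a baseline. It is an illustration under stated assumptions rather than the code in run.py: `costs` are the tree lengths of sampled solutions, `baseline_costs` are the lengths obtained by a greedy rollout of a baseline policy on the same instances, and `log_probs` are the log-likelihoods of the sampled solutions. Plain Python floats are used to keep the example dependency-free; in training these would be PyTorch tensors so that gradients flow through `log_probs`.

```python
def reinforce_loss(costs, baseline_costs, log_probs):
    """Mean of (cost - baseline) * log_prob over a batch.

    Minimizing this raises the probability of solutions that beat the
    greedy rollout baseline and lowers it for solutions that do worse.
    """
    if not (len(costs) == len(baseline_costs) == len(log_probs)):
        raise ValueError("all inputs must have the same batch size")
    advantages = [c - b for c, b in zip(costs, baseline_costs)]
    return sum(a * lp for a, lp in zip(advantages, log_probs)) / len(costs)


# Two instances: the first sample is worse than the baseline, the second better.
loss = reinforce_loss(
    costs=[3.0, 2.0],
    baseline_costs=[2.5, 2.5],
    log_probs=[-1.0, -2.0],
)
# advantages are [0.5, -0.5], so loss = (0.5 * -1.0 + -0.5 * -2.0) / 2 = 0.25
```

Subtracting the baseline does not change the expected gradient but reduces its variance, which is the usual motivation for using a rollout baseline.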
You can initialize a run from a pretrained model with the --load_path option:
python run.py --graph_size 10 --batch_size 32 --epoch_size 10240 --val_size 10000 --eval_batch_size 10 --baseline rollout --run_name 'est10' --n_epochs 100 --lr_model 0.00000001 --seed 1111 --embedding_dim 128 --hidden_dim 128 --n_encode_layers 5 --load_path /content/drive/MyDrive/attention_completeV3.0.1/outputs/tsp_10/arc9/epoch-80.pt
To evaluate a model, use eval.py to output the results. All generated Steiner trees are saved to the file "select_a.txt":
python eval.py --graph_size 10 --batch_size 32 --epoch_size 10240 --val_size 10000 --eval_batch_size 10 --baseline rollout --run_name 'est10' --n_epochs 100 --lr_model 0.00000001 --seed 1111 --embedding_dim 128 --hidden_dim 128 --n_encode_layers 5 --load_path /content/drive/MyDrive/attention_completeV3.0.1/outputs/tsp_10/arc9/epoch-80.pt
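When inspecting evaluation results, a useful sanity check is to compare the length of a Steiner tree against the minimum spanning tree (MST) over the terminals alone, since an optimal Steiner tree is never longer than the MST. The sketch below is a hedged illustration: the format of "select_a.txt" is not documented here, so it works on in-memory points and edges, and the names `tree_length` and `mst_length` are hypothetical.

```python
import math


def tree_length(points, edges):
    """Total Euclidean length of a tree given as index pairs into `points`."""
    return sum(math.dist(points[i], points[j]) for i, j in edges)


def mst_length(points):
    """Length of the minimum spanning tree over `points` (Prim, O(n^2))."""
    n = len(points)
    if n == 0:
        return 0.0
    in_tree = [False] * n
    best = [math.inf] * n  # cheapest known connection cost to the tree
    best[0] = 0.0
    total = 0.0
    for _ in range(n):
        u = min((i for i in range(n) if not in_tree[i]), key=best.__getitem__)
        in_tree[u] = True
        total += best[u]
        for v in range(n):
            if not in_tree[v]:
                best[v] = min(best[v], math.dist(points[u], points[v]))
    return total


# Equilateral triangle with side 1: one Steiner point at the centroid.
terminals = [(0.0, 0.0), (1.0, 0.0), (0.5, math.sqrt(3) / 2)]
steiner_point = (0.5, math.sqrt(3) / 6)
points = terminals + [steiner_point]
steiner = tree_length(points, [(0, 3), (1, 3), (2, 3)])  # about 1.732
mst = mst_length(terminals)                              # about 2.0
assert steiner <= mst
```

In this example the Steiner tree has length sqrt(3) while the MST has length 2, a saving of roughly 13%.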
For a list of all available options, run:
python run.py -h
python eval.py -h
Thanks to wouterkool/attention-learn-to-route for getting us started with the code for the graph attention model.