Code accompanying the paper
DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.

The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.
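As a rough illustration of the continuous relaxation (a minimal sketch, not the code in this repository), each edge of a cell can be viewed as a softmax-weighted mixture of candidate operations, and the mixing logits alpha are the architecture parameters optimized by gradient descent. The candidate set below is a made-up subset of the one used in the paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """One edge of a cell: a softmax-weighted mixture of candidate ops (illustrative sketch)."""
    def __init__(self, channels):
        super().__init__()
        # Hypothetical candidate set; the paper searches over a larger set of conv/pool/identity ops.
        self.ops = nn.ModuleList([
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Identity(),
        ])
        # Architecture parameters: one logit per candidate op, learned by gradient descent.
        self.alpha = nn.Parameter(torch.zeros(len(self.ops)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)  # continuous relaxation of the discrete choice
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```

At the end of search, each edge keeps only its highest-weighted operation, which yields the discrete genotype that is then trained from scratch.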
Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0
NOTE: PyTorch 0.4 is not supported at the moment and will lead to OOM.
Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to an SSD) following the instructions here.
The easiest way to check whether your environment is compatible with ours is to evaluate the pretrained models.
CIFAR-10 (cifar10_model.pt)
cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
Expected result: 2.63% test error rate with 3.3M model params.
PTB (ptb_model.pt)
cd rnn && python test.py --model_path ptb_model.pt
Expected result: 55.68 test perplexity with 23M model params.
ImageNet (imagenet_model.pt)
cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
Expected result: 26.7% top-1 error and 8.7% top-5 error with 4.7M model params.
To carry out architecture search using 2nd-order approximation, run
cd cnn && python train_search.py --unrolled # for conv cells on CIFAR-10
cd rnn && python train_search.py --unrolled # for recurrent cells on PTB
Note the validation performance in this step does not indicate the final performance of the architecture. One must train the obtained genotype/architecture from scratch using full-sized models, as described in the next section.
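For intuition, the --unrolled flag corresponds to the 2nd-order approximation: the architecture gradient is taken through one virtual SGD step on the network weights. Below is a naive, simplified sketch of that idea written against recent PyTorch's torch.func API (the repository targets PyTorch 0.3.1 and approximates the second-order term with a finite-difference Hessian-vector product instead); the xi value and the assumption that architecture parameters contain "alpha" in their names are made up for the sketch:

```python
import torch
from torch.func import functional_call  # requires a recent PyTorch

XI = 0.01  # virtual (inner) step size, an assumed value

def unrolled_arch_grad(model, alphas, loss_fn, train_batch, val_batch):
    """Naive sketch of the unrolled (2nd-order) architecture gradient."""
    x_tr, y_tr = train_batch
    x_val, y_val = val_batch
    # Assumed naming convention: architecture parameters have "alpha" in their name.
    weights = {k: p for k, p in model.named_parameters() if "alpha" not in k}

    # Virtual inner step on the weights: w' = w - xi * grad_w L_train(w, alpha),
    # kept differentiable so that w' depends on alpha.
    train_loss = loss_fn(functional_call(model, weights, (x_tr,)), y_tr)
    grads = torch.autograd.grad(train_loss, list(weights.values()), create_graph=True)
    unrolled = {k: w - XI * g for (k, w), g in zip(weights.items(), grads)}

    # Outer objective: differentiate L_val(w'(alpha), alpha) w.r.t. alpha.
    val_loss = loss_fn(functional_call(model, unrolled, (x_val,)), y_val)
    return torch.autograd.grad(val_loss, alphas)
```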
Snapshots of the most likely normal conv, reduction conv, and recurrent cells during arch search.
To evaluate our best cells by training from scratch, run
cd cnn && python train.py --auxiliary --cutout # CIFAR-10
cd rnn && python train.py # PTB
cd rnn && python train.py --data ../data/wikitext-2 \
    --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7 # WT2
cd cnn && python train_imagenet.py --auxiliary # ImageNet
Customized architectures are supported through the --arch flag once they are specified in genotypes.py.
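For example, assuming genotypes.py keeps the (operation, input node) edge-list format used by its existing entries, a custom architecture could be registered like the hypothetical entry below (not a searched architecture) and then trained via the --arch flag, e.g. cd cnn && python train.py --auxiliary --cutout --arch MY_ARCH:

```python
# genotypes.py (excerpt) -- MY_ARCH is a hypothetical entry for illustration only
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

MY_ARCH = Genotype(
    # Each pair is (operation, index of the node feeding that edge):
    # 0 and 1 are the two cell inputs, 2+ are intermediate nodes.
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 1), ('skip_connect', 0),
            ('skip_connect', 0), ('dil_conv_3x3', 2)],
    normal_concat=[2, 3, 4, 5],
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 0),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('skip_connect', 2), ('avg_pool_3x3', 0)],
    reduce_concat=[2, 3, 4, 5],
)
```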
The CIFAR-10 result at the end of training is subject to variance due to the non-determinism of cuDNN back-prop kernels. It would be misleading to report the result of only a single run. By training our best cell from scratch, one should expect the average test error of 10 independent runs to fall within the range of 2.76 +/- 0.09% with high probability.
Expected learning curves on CIFAR-10 (4 runs), ImageNet and PTB.
Package graphviz is required to visualize the learned cells:
python visualize.py DARTS
where DARTS can be replaced by any customized architecture in genotypes.py.
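To inspect a genotype programmatically rather than through visualize.py, the graphviz Python package can render the edge list directly. The sketch below is illustrative only (it assumes the (operation, input node) genotype format and is unrelated to the repository's own script); run it from a directory containing genotypes.py:

```python
from graphviz import Digraph
from genotypes import DARTS  # any genotype defined in genotypes.py

def draw_cell(edges, filename):
    """Render a cell's (op, input_node) edge list as a graph (illustrative sketch)."""
    g = Digraph(format='pdf')
    g.node('c_{k-2}')
    g.node('c_{k-1}')
    steps = len(edges) // 2  # two input edges per intermediate node
    for i in range(steps):
        g.node(str(i))
    for i in range(steps):
        for op, j in edges[2 * i: 2 * i + 2]:
            src = 'c_{k-2}' if j == 0 else 'c_{k-1}' if j == 1 else str(j - 2)
            g.edge(src, str(i), label=op)
    g.render(filename, view=False)

draw_cell(DARTS.normal, 'normal_cell')
```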
If you use any part of this code in your research, please cite our paper:
@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}