This repo is a PyTorch implementation of VAN (Visual Attention Network) applied to semantic segmentation. The code is based on MMSegmentation.
More details can be found in Visual Attention Network.
@article{guo2022visual,
title={Visual Attention Network},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2202.09741},
year={2022}
}
@inproceedings{ham,
title={Is Attention Better Than Matrix Decomposition?},
author={Zhengyang Geng and Meng-Hao Guo and Hongxu Chen and Xia Li and Ke Wei and Zhouchen Lin},
booktitle={International Conference on Learning Representations},
year={2021},
}
Notes: Pre-trained models can be found in Visual Attention Network for Classification.
Method | Backbone | Iters | mIoU | Params | FLOPs | Config | Download |
---|---|---|---|---|---|---|---|
Light-Ham-D256 | VAN-Tiny | 160K | 40.9 | 4.2M | 6.5G | config | Google Drive |
- | - | - | - | - | - | - | - |
HamNet | VAN-Tiny-OS8 | 160K | 41.5 | 11.9M | 50.8G | config | Google Drive |
HamNet | VAN-Small-OS8 | 160K | 45.1 | 24.2M | 100.6G | config | Google Drive |
HamNet | VAN-Base-OS8 | 160K | 48.7 | 36.9M | 153.6G | config | Google Drive |
HamNet | VAN-Large-OS8 | 160K | 50.2 | 55.1M | 227.7G | config | Google Drive |
- | - | - | - | - | - | - | - |
UperNet | VAN-Tiny | 160K | 41.1 | 32.1M | 214.7G | config | |
UperNet | VAN-Small | 160K | 44.9 | 43.8M | 224.0G | config | |
UperNet | VAN-Base | 160K | 48.3 | 56.6M | 237.1G | config | |
UperNet | VAN-Large | 160K | 50.1 | 74.7M | 257.7G | config | |
Notes: In this scheme, we use multi-scale validation following Swin-Transformer. FLOPs are measured with an input size of 512 × 512.
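Multi-scale validation typically runs each validation image through the model several times, at rescaled sizes and optionally flipped horizontally, and averages the predictions. The sketch below only enumerates those test-time views. The base scale (2048, 512) and the ratios 0.5 to 1.75 are assumptions taken from Swin-Transformer's ADE20K configs, not values read from this repo's configs, and `multi_scale_views` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

# ASSUMED values, taken from Swin-Transformer's ADE20K configs; check the
# test_pipeline of the config you actually evaluate.
BASE_SCALE = (2048, 512)                        # (width, height) rescale target
IMG_RATIOS = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75)  # multi-scale factors


def multi_scale_views(
    base: Tuple[int, int] = BASE_SCALE,
    ratios: Sequence[float] = IMG_RATIOS,
    flip: bool = True,
) -> List[Tuple[Tuple[int, int], bool]]:
    """List one (scale, flipped) pair per forward pass made at test time."""
    w, h = base
    # Each ratio scales the rescale target; sizes are truncated to ints.
    scales = [(int(w * r), int(h * r)) for r in ratios]
    flips = [False, True] if flip else [False]
    return [(scale, flipped) for scale in scales for flipped in flips]


views = multi_scale_views()
print(len(views))  # 12 forward passes per image
print(views[0])    # ((1024, 256), False)
```

In MMSegmentation 0.x these settings live in the `MultiScaleFlipAug` step of a config's `test_pipeline` (`img_ratios`, `flip=True`). The per-view class scores are mapped back to the original image size and averaged before the per-pixel argmax, so evaluation cost grows roughly with the number of views.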
Backbone | Iters | mIoU | Config | Download |
---|---|---|---|---|
VAN-Tiny | 40K | 38.5 | config | Google Drive |
VAN-Small | 40K | 42.9 | config | Google Drive |
VAN-Base | 40K | 46.7 | config | Google Drive |
VAN-Large | 40K | 48.1 | config | Google Drive |
Install MMSegmentation and download ADE20K according to the guidelines in MMSegmentation.
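MMSegmentation 0.x's ADE20K configs expect the data under `data/ade/ADEChallengeData2016`, with `images/` and `annotations/` folders each split into `training` and `validation`. The helper below is a hypothetical sanity check (not part of this repo) that verifies this layout before launching a job. The root path is an assumption; adjust it if your configs point elsewhere.

```python
from pathlib import Path
from typing import Dict

# ASSUMED default location used by MMSegmentation 0.x ADE20K configs.
DATA_ROOT = Path("data/ade/ADEChallengeData2016")
SPLITS = ("training", "validation")


def check_ade20k_layout(root: Path = DATA_ROOT) -> Dict[str, int]:
    """Count image/annotation pairs per split; raise if anything is missing."""
    counts = {}
    for split in SPLITS:
        img_dir = root / "images" / split
        ann_dir = root / "annotations" / split
        for d in (img_dir, ann_dir):
            if not d.is_dir():
                raise FileNotFoundError(f"missing directory: {d}")
        # Images are .jpg and label maps are .png with the same file stem.
        images = {p.stem for p in img_dir.glob("*.jpg")}
        labels = {p.stem for p in ann_dir.glob("*.png")}
        if images != labels:
            unmatched = len(images ^ labels)
            raise ValueError(f"{split}: {unmatched} images without a matching label (or vice versa)")
        counts[split] = len(images)
    return counts
```

On a complete ADE20K download, `check_ade20k_layout()` should report 20,210 training and 2,000 validation images.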
PyTorch >= 1.7
MMSegmentation == v0.12.0 (https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0)
We use 8 GPUs for training by default. Run:
dist_train.sh /path/to/config 8
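The last argument of `dist_train.sh` is the number of GPUs, and the global batch size is that number times the per-GPU batch size set in the config, so the `Iters` column can be translated into epochs. This is a worked sketch: `SAMPLES_PER_GPU = 2` is a hypothetical value (read `data.samples_per_gpu` from the config you use), while 20,210 is the size of the ADE20K training split.

```python
ADE20K_TRAIN_IMAGES = 20_210
NUM_GPUS = 8          # default used in this README
SAMPLES_PER_GPU = 2   # HYPOTHETICAL; config-dependent


def iters_to_epochs(iters: int, num_gpus: int = NUM_GPUS,
                    samples_per_gpu: int = SAMPLES_PER_GPU,
                    dataset_size: int = ADE20K_TRAIN_IMAGES) -> float:
    """Epochs = iterations x global batch size / number of training images."""
    global_batch = num_gpus * samples_per_gpu
    return iters * global_batch / dataset_size


print(round(iters_to_epochs(160_000), 1))  # 126.7 (under the assumed batch size)
print(round(iters_to_epochs(40_000), 1))   # 31.7
```

Training on fewer GPUs without changing the config shrinks the global batch, so the same iteration count then covers fewer epochs.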
To evaluate the model, run:
dist_test.sh /path/to/config /path/to/checkpoint_file 8 --out results.pkl --eval mIoU
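`--eval mIoU` reports mean intersection-over-union: for each class c, IoU_c = TP / (TP + FP + FN), with counts accumulated over the whole validation set, then averaged over classes. The pure-Python sketch below illustrates the metric on flat label lists; it is not MMSegmentation's implementation. The ignore value 255 follows MMSegmentation's ADE20K setup, where `reduce_zero_label` maps label 0 to 255 so those pixels are not scored. Classes that appear in neither the prediction nor the ground truth are skipped.

```python
from typing import Sequence

IGNORE_INDEX = 255  # pixels with this ground-truth label are not scored


def mean_iou(preds: Sequence[Sequence[int]], labels: Sequence[Sequence[int]],
             num_classes: int, ignore_index: int = IGNORE_INDEX) -> float:
    """Dataset-level mIoU from per-image flat lists of class ids.

    Intersections and unions are summed over all images first, then
    IoU_c = intersect_c / union_c is averaged over classes with union_c > 0.
    """
    intersect = [0] * num_classes
    union = [0] * num_classes
    for pred, gt in zip(preds, labels):
        for p, g in zip(pred, gt):
            if g == ignore_index:
                continue  # unlabeled pixel: neither TP, FP nor FN
            if p == g:
                intersect[g] += 1  # true positive for class g
                union[g] += 1
            else:
                union[p] += 1      # false positive for class p
                union[g] += 1      # false negative for class g
    ious = [i / u for i, u in zip(intersect, union) if u > 0]
    if not ious:
        raise ValueError("no labeled pixels to evaluate")
    return sum(ious) / len(ious)


print(round(mean_iou([[0, 1, 1, 2]], [[0, 1, 2, 2]], num_classes=3), 4))  # 0.6667
```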
Install torchprofile with:
pip install torchprofile
To calculate FLOPs for a model, run:
bash tools/flops.sh /path/to/checkpoint_file --shape 512 512
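The FLOPs script needs torchprofile, whose `profile_macs(model, args)` traces the model on a dummy input and sums the multiply-accumulates (MACs) of each operator. The helpers below sketch that per-layer arithmetic for a single `Conv2d`; they are an illustration, not the script's code, and the layer sizes in the example are arbitrary.

```python
from typing import Tuple, Union


def conv_out_size(size: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution (same formula as torch.nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(c_in: int, c_out: int, kernel: Union[int, Tuple[int, int]],
                h_out: int, w_out: int, groups: int = 1) -> int:
    """Multiply-accumulates of one Conv2d layer, bias ignored.

    Each output element needs (c_in / groups) * kh * kw multiply-adds.
    """
    if c_in % groups or c_out % groups:
        raise ValueError("channel counts must be divisible by groups")
    kh, kw = (kernel, kernel) if isinstance(kernel, int) else kernel
    return c_out * h_out * w_out * (c_in // groups) * kh * kw


# Arbitrary example: 64 channels on a 128x128 feature map (512 / 4).
dense = conv2d_macs(64, 64, 3, 128, 128)                # ordinary 3x3 conv
depthwise = conv2d_macs(64, 64, 3, 128, 128, groups=64)  # depth-wise 3x3 conv
print(dense)      # 603979776 (about 0.60 GMACs)
print(depthwise)  # 9437184
print(conv_out_size(512, kernel=3, stride=2, padding=1))  # 256
```

Papers often report MACs under the name FLOPs, and the two conventions differ by a factor of two, so compare numbers only when they were counted the same way.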
Our implementation is mainly based on MMSegmentation, Swin-Transformer, PoolFormer, and Enjoy-Hamburger. We thank their authors.
This repo is under the Apache-2.0 license. For commercial use, please contact the authors.