Skip to content

Latest commit

 

History

History
83 lines (68 loc) · 4.08 KB

README_CN.md

File metadata and controls

83 lines (68 loc) · 4.08 KB

PureT

论文 End-to-End Transformer Based Model for Image Captioning [PDF/AAAI] [PDF/Arxiv] 的PyTorch实现 [AAAI 2022]

architecture

环境要求 (Our Main Enviroment)

  • Python 3.7.4
  • PyTorch 1.5.1
  • TorchVision 0.6.0
  • coco-caption
  • numpy
  • tqdm

预处理

1. coco-caption

参考coco-caption的 README.md, 主要是需要下载SPICE指标需要使用的Stanford CoreNLP 3.6.0代码和模型。 直接使用脚本下载即可:

cd coco_caption
bash get_stanford_models.sh

2. 数据准备

训练和验证过程所需要的重要数据都存储在 mscoco 路径下,文件夹组织结构如下:

mscoco/
|--feature/
    |--coco2014/
       |--train2014/
       |--val2014/
       |--test2014/
       |--annotations/
|--misc/
|--sent/
|--txt/

MSCOCO 2014 数据集的所有源图像和注释文件置于mscoco/feature/coco2014路径下。其余文件可通过GoogleDrive或者百度网盘(提取码: hryh)进行下载。

注意: 为了进一步加快训练速度,也可以将数据集中所有图像的特征提取出来并保存为npz文件,可以在mscoco/feature路径下新建目录存储特征文件,训练和验证时需要将数据集读取改为coco_dataset.pydata_loader.py中的方式。同时也需要修改pure_transformer.py(主要就是删除掉Backbone模块的定义,其余类和函数的接口应该是通用的)。

模型训练

注意: 代码实现主要基于JDAI-CV/image-captioning,直接复用了他们的配置文件没做太多修改,所以里面会有一些对我们模型无用的超参数设置。(需要进一步整理删除)

1. XE损失下训练

首先从GoogleDrive或者百度网盘(提取码: hryh)中下载Backbone(Swin-Transformer)的预训练模型,并将其存储在当前工程的根目录下。(预训练模型来源于SwinTransformer官方开源库,去除了其中的head模块部分的权重)

在训练前,可能还需要检查和修改config.ymltrain.sh文件以适应你的运行环境。然后直接开训:

# for XE training
bash experiments_PureT/PureT_XE/train.sh

2. SCST训练

将XE训练后相对较好的模型复制并存储于experiments_PureT/PureT_SCST/snapshot/中。然后继续训练:

# for SCST training
bash experiments_PureT/PureT_SCST/train.sh

模型测试

可以直接从GoogleDrive百度网盘(提取码: hryh)下载论文中报告结果所对应的预训练模型。

CUDA_VISIBLE_DEVICES=0 python main_test.py --folder experiments_PureT/PureT_SCST/ --resume 27
BLEU-1 BLEU-2 BLEU-3 BLEU-4 METEOR ROUGE-L CIDEr SPICE
82.1 67.3 52.0 40.9 30.2 60.1 138.2 24.2

Reference

If you find this repo useful, please consider citing (no obligation at all):

@inproceedings{wangyiyu2022PureT,
  title={End-to-End Transformer Based Model for Image Captioning},
  author={Yiyu Wang and Jungang Xu and Yingfei Sun},
  booktitle={AAAI},
  year={2022}
}

Acknowledgements

This repository is based on JDAI-CV/image-captioning, ruotianluo/self-critical.pytorch and microsoft/Swin-Transformer.