Commit 4e3dead

Merge branch 'feature/prompt_tuning' of github.com:yh351016/OFA into feature/prompt_tuning
TheMrYang committed Aug 24, 2022
2 parents ab893b2 + 54494c2 commit 4e3dead
Showing 20 changed files with 42,542 additions and 83 deletions.
58 changes: 19 additions & 39 deletions README.md
@@ -10,24 +10,11 @@ This source code is licensed under the Apache 2.0 license found in the LICENSE file
<br>
<p>
<br>

<p align="center">
<a href="https://github.com/huggingface/transformers/blob/master/LICENSE">
<img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue">
</a>
<a href="https://huggingface.co/ofa-sys">
<img alt="spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue">
</a>
<a href="colab.md"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="DOI"></a>
<a href="checkpoints.md">Checkpoints</a>&nbsp | &nbsp<a href="colab.md">Colab</a>&nbsp | &nbsp<a href="https://huggingface.co/ofa-sys">Demo</a>&nbsp | &nbsp<a href="http://arxiv.org/abs/2202.03052">Paper </a>&nbsp | &nbspBlog
</p>

<h4 align="center">
<p>
<a href="http://arxiv.org/abs/2202.03052">Paper</a> |
<b>Blog</b>
</p>
</h4>
<br></br>

<p align="center">
<br>
<img src="examples/demo.gif" width="800" />
@@ -36,33 +23,38 @@ This source code is licensed under the Apache 2.0 license found in the LICENSE file

[colab]: <https://colab.research.google.com/assets/colab-badge.svg>

OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks
(e.g., image generation, visual grounding, image captioning, image classification, text generation, etc.)
into a simple sequence-to-sequence learning framework. For more information, please refer to our paper: [OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework](http://arxiv.org/abs/2202.03052).

In the following, we provide:
* News about our recent updates;
* Online Demos with links to Huggingface spaces and Colab notebooks;
* Model card, including the official release of pretrained checkpoints (more can be found at [checkpoints.md](checkpoints.md)) and checkpoints for Huggingface Transformers at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), together with experimental results of OFA models of different sizes;
* Step-by-step instructions for pretraining and finetuning (including almost all tasks presented in the paper);
* Case demonstration of OFA.
OFA is a unified sequence-to-sequence pretrained model (supporting **English** and **Chinese**) that unifies modalities (i.e., cross-modality, vision, language) and tasks (both **finetuning** and **prompt tuning** are supported): image captioning (1st at the [MSCOCO Leaderboard](https://competitions.codalab.org/competitions/3221#results)), VQA ([link](https://eval.ai/web/challenges/challenge-page/830/leaderboard/2278)), visual grounding, text-to-image generation, text classification, text generation, image classification, etc. We provide **step-by-step** instructions for pretraining and finetuning, together with the corresponding checkpoints (see the official ckpt \[[EN](checkpoints.md)|[CN](checkpoints_cn.md)\] or the [huggingface ckpt](https://huggingface.co/OFA-Sys)).

We sincerely welcome contributions to our project. Feel free to contact us or send us issues / PRs!
<br></br>


# Online Demos
We provide online demos via Hugging Face Spaces for you to interact with our pretrained and finetuned models. Below are the links to the demos:
* [Image Captioning](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)
* [Visual Grounding](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Grounding)
* [Visual Question Answering](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Question_Answering)
* [Text-to-Image Generation](https://huggingface.co/spaces/OFA-Sys/OFA-Text2Image_Generation)
* [Generic Interface](https://huggingface.co/spaces/OFA-Sys/OFA-Generic_Interface)

We also provide Colab notebooks to help you better understand the procedures. Click [here](colab.md) to check them out!
<br></br>


# News
* 2022.8.5: Released support of prompt tuning for OFA (temporarily maintained at `feature/prompt_tuning`). Check our paper [here](https://arxiv.org/abs/2208.02532)! Please see the [prompt_tuning.md](prompt_tuning.md) for further details.
* 2022.8.16: Released the **Chinese** version of OFA. To use **OFA-CN**, simply switch to `bpe_dir=../../utils/BERT_CN_dict` and use our provided Chinese checkpoints in [checkpoints_cn.md](checkpoints_cn.md). For now, we only provide base-size and large-size pretrained checkpoints, as well as checkpoints finetuned on [MUGE Caption](https://tianchi.aliyun.com/muge) and on the Chinese version of RefCOCO(-/+/g) (to be released soon).
* 2022.8.5: Released support of **prompt tuning** for OFA. Check our paper [here](https://arxiv.org/abs/2208.02532)! Please see the [prompt_tuning.md](prompt_tuning.md) for further details.
* 2022.7.7: Updated support of OFA on **huggingface transformers** (fixed bugs in forward, added the sequence generator from Fairseq to ensure performance, etc.). Refer to the doc [transformers.md](transformers.md) and the branch `feature/add_transformers`.
* 2022.6.17: Released the pretrained checkpoint of **OFA-Huge**. To use it, set `--arch=ofa_huge` in the script.
* 2022.5.15: OFA was accepted by **ICML 2022**.
* 2022.4.28: Add support of inference on **huggingface transformers**. For how to use it, please refer to the doc [transformers.md](transformers.md) and our [huggingface models](https://huggingface.co/OFA-Sys).
* 2022.4.16: Released lightweight pretrained models **OFA-Medium** (~93M params) and **OFA-Tiny** (~33M params) in [checkpoints.md](checkpoints.md). To use them, you just need to load the corresponding checkpoint and set `--arch=ofa_medium` or `--arch=ofa_tiny` in the scripts.
* 2022.3.23: Added [Encouraging Loss](https://arxiv.org/pdf/2110.06537.pdf) as a feature. See [README_EncouragingLoss.md](README_EncouragingLoss.md). Leveraging this feature, OFA-Large has achieved improved results in both VQA (**test-std acc: 80.67**) and Image Classification (**test acc: 85.6**) recently.

<details>
<summary><b>More News</b></summary>
<p>
<ul>
<li>2022.3.23: Added <a href="https://arxiv.org/pdf/2110.06537.pdf">Encouraging Loss</a> as a feature. See <a href="README_EncouragingLoss.md">README_EncouragingLoss.md</a>. Leveraging this feature, OFA-Large has achieved improved results in both VQA (<b>test-std acc: 80.67</b>) and Image Classification (<b>test acc: 85.6</b>) recently.</li>
<li>2022.3.21: Released codes for pretraining OFA.</li>
<li>2022.3.18: Released the finetuned <b>OFA-Base</b> (~180M parameters) checkpoints and running scripts for vision & language tasks, including: <b>Caption (146.4 CIDEr), VQA (78.07 on test-std), SNLI-VE (89.3 on dev), RefCOCO (90.67 on testA), RefCOCO+ (87.15 on testA) and RefCOCOg (82.31 on test-u)</b>.</li>
<li>2022.3.11: Released the finetuning & inference code/checkpoints for <b>Gigaword</b>.</li>
@@ -79,18 +71,6 @@ We sincerely welcome contributions to our project. Feel free to contact us or send us issues / PRs!
<br></br>


# Online Demos
We provide online demos via Hugging Face Spaces for you to interact with our pretrained and finetuned models. Below are the links to the demos:
* [Generic Interface](https://huggingface.co/spaces/OFA-Sys/OFA-Generic_Interface)
* [Image Captioning](https://huggingface.co/spaces/OFA-Sys/OFA-Image_Caption)
* [Text-to-Image Generation](https://huggingface.co/spaces/OFA-Sys/OFA-Text2Image_Generation)
* [Visual Grounding](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Grounding)
* [Visual Question Answering](https://huggingface.co/spaces/OFA-Sys/OFA-Visual_Question_Answering)

We also provide Colab notebooks to help you better understand the procedures. Click [here](colab.md) to check them out!
<br></br>


# Model Card
We list the parameters and pretrained checkpoints of OFAs below. For finetuned checkpoints, please refer to [checkpoints.md](checkpoints.md).

2 changes: 1 addition & 1 deletion checkpoints.md
@@ -1,6 +1,6 @@
# Checkpoints

We provide links for you to download our checkpoints. We will release all the checkpoints including pretrained and finetuned models on different tasks.
We provide links for you to download our checkpoints, including pretrained and finetuned models on different tasks. If you would like to use OFA with Transformers, please download checkpoints at [https://huggingface.co/OFA-Sys](https://huggingface.co/OFA-Sys), and check the code in the branch `feature/add_transformers`.

## Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_huge.pt"> Pre-trained checkpoint (OFA-Huge) </a> (~930M parameters)
82 changes: 82 additions & 0 deletions checkpoints_cn.md
@@ -0,0 +1,82 @@
# Checkpoints (OFA-CN)

We provide checkpoints of OFA-CN, the Chinese version of OFA, in Base and Large sizes, including pretrained models as well as models finetuned on image captioning and referring expression comprehension. Note that we translated the texts of the RefCOCO(-/+/g) datasets into Chinese and finetuned OFA-CN on them. We plan to release these new datasets in the near future.
<br>

## Checkpoints
Below we provide the links for downloading the Chinese OFA checkpoints.

### Pretraining
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_large.pt"> Pretrained checkpoint (OFA-CN-Large) </a> (~443M parameters)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_cn_base.pt "> Pretrained checkpoint (OFA-CN-Base) </a> (~160M parameters)

### Finetuning (OFA-Large)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_large.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_large.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_large.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_large.pt"> Finetuned checkpoint for RefCOCOg-CN </a>

### Finetuning (OFA-Base)
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_cn_base.pt"> Finetuned checkpoint for MUGE Caption (Stage 1) </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_cn_base.pt"> Finetuned checkpoint for RefCOCO-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_cn_base.pt"> Finetuned checkpoint for RefCOCO+-CN </a>
* <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_cn_base.pt"> Finetuned checkpoint for RefCOCOg-CN </a>
<br>

## Model Card
Below we provide the basic information of the base-size and large-size OFA-CN.

<table border="1" width="100%">
<tr align="center">
<th>Model</th><th>#Params</th><th>Backbone</th><th>Hidden Size</th><th>Intermediate Size</th><th>#Heads</th><th>#Enc. Layers</th><th>#Dec. Layers</th>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>160M</td><td>ResNet101</td><td>768</td><td>3072</td><td>12</td><td>6</td><td>6</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td>443M</td><td>ResNet152</td><td>1024</td><td>4096</td><td>16</td><td>12</td><td>12</td>
</tr>
</table>
<br>

## Results
Below we provide the results of OFA-CN and the baselines for comparison.

### [MUGE Caption](https://tianchi.aliyun.com/muge)
<table border="1" width="100%">
<tr align="center">
<td>Model</td><td>BLEU@4</td><td>ROUGE-L</td><td>CIDEr-D</td>
</tr>
<tr align="center">
<td>Trm </td><td>7.33</td><td>51.51</td><td>11.00</td>
</tr>
<tr align="center">
<td>M6</td><td>16.19</td><td>55.06</td><td>30.75</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>26.23</td><td>58.95</td><td>50.70</td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td><b>27.32</b></td><td><b>59.20</b></td><td><b>53.51</b></td>
</tr>
</table>

### RefCOCO-CN Series
<table border="1" width="100%">
<tr align="center">
<td>Model</td><td>RefCOCO(val/testA/testB)</td><td>RefCOCO+(val/testA/testB)</td><td>RefCOCOg(val/test-u)</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub>(random-init)</td><td>30.13/35.07/25.03</td><td>17.89/20.90/15.83</td><td>20.30/20.45</td>
</tr>
<tr align="center">
<td>OFA<sub>Base</sub></td><td>82.18/86.07/<b>76.68</b></td><td>69.38/77.26/60.14</td><td><b>73.57/72.53</b></td>
</tr>
<tr align="center">
<td>OFA<sub>Large</sub></td><td><b>82.84/86.54</b>/76.50</td><td><b>71.30/78.56/61.85</b></td><td>71.96/71.30</td>
</tr>
</table>
<br>


7 changes: 6 additions & 1 deletion data/mm_data/caption_dataset.py
@@ -113,6 +113,11 @@ def __init__(
transforms.Normalize(mean=mean, std=std),
])

if type(bpe).__name__ == 'GPT2BPE':
self.prompt = " what does the image describe?"
elif type(bpe).__name__ == 'BertBPE':
self.prompt = "图片描述了什么内容?"

def __getitem__(self, index):
uniq_id, image, caption = self.dataset[index]

@@ -128,7 +133,7 @@ def __getitem__(self, index):
caption = ' '.join(caption.strip().split())
caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
tgt_caption = '&&'.join(caption_list)
src_item = self.encode_text(" what does the image describe?")
src_item = self.encode_text(self.prompt)
tgt_item = self.encode_text(" {}".format(tgt_caption))

src_item = torch.cat([self.bos_item, src_item, self.eos_item])
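
The hunk above selects a language-appropriate prompt from the tokenizer type, and `__getitem__` then wraps the encoded prompt with BOS/EOS. Below is a minimal, self-contained sketch of that flow; `build_caption_source` and its arguments are illustrative stand-ins, not repository code.

```python
import torch

def build_caption_source(bpe, encode_text, bos_item, eos_item):
    # Pick the prompt that matches the tokenizer, mirroring the
    # GPT2BPE / BertBPE branch added above.
    if type(bpe).__name__ == 'GPT2BPE':
        prompt = " what does the image describe?"
    elif type(bpe).__name__ == 'BertBPE':
        prompt = "图片描述了什么内容?"
    else:
        raise ValueError("unsupported BPE: " + type(bpe).__name__)

    src_item = encode_text(prompt)
    # Wrap the encoded prompt with BOS/EOS, as __getitem__ does.
    return torch.cat([bos_item, src_item, eos_item])
```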
7 changes: 6 additions & 1 deletion data/mm_data/refcoco_dataset.py
@@ -118,6 +118,11 @@ def __init__(
T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
])

if type(bpe).__name__ == 'GPT2BPE':
self.prompt = ' which region does the text " {} " describe?'
elif type(bpe).__name__ == 'BertBPE':
self.prompt = '这段文字" {} "描述的是哪个区域?'

def __getitem__(self, index):
uniq_id, base64_str, text, region_coord = self.dataset[index]

@@ -139,7 +144,7 @@ def __getitem__(self, index):
quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
src_caption = self.pre_caption(text, self.max_src_length)
src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
src_item = self.encode_text(self.prompt.format(src_caption))
tgt_item = self.encode_text(region_coord, use_bpe=False)

src_item = torch.cat([self.bos_item, src_item, self.eos_item])
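
For intuition, the target region above is serialized as four `<bin_*>` tokens, with each coordinate scaled to `num_bins - 1` discrete positions. Here is a small sketch of that quantization, assuming coordinates already normalized to [0, 1]; the helper name and the `num_bins=1000` default are illustrative.

```python
def quantize_box(x0, y0, x1, y1, num_bins=1000):
    # Map each normalized coordinate to one of num_bins discrete tokens,
    # matching the "<bin_{}>" format used above.
    def q(v):
        return "<bin_{}>".format(int(round(v * (num_bins - 1))))
    return "{} {} {} {}".format(q(x0), q(y0), q(x1), q(y1))

# A box covering the left half of the image:
print(quantize_box(0.0, 0.0, 0.5, 1.0))  # <bin_0> <bin_0> <bin_500> <bin_999>
```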
9 changes: 7 additions & 2 deletions data/nlg_data/summary_dataset.py
@@ -81,6 +81,11 @@ def __init__(
self.num_bins = num_bins
self.noise_ratio = noise_ratio

if type(bpe).__name__ == 'GPT2BPE':
self.prompt = ' what is the summary of article " {} "?'
elif type(bpe).__name__ == 'BertBPE':
self.prompt = "{} 请用一个句子简单总结上文:"

def __getitem__(self, index):
source, target = self.dataset[index]
target_str = target.lower()
@@ -91,10 +96,10 @@ def __getitem__(self, index):
target = target.replace('<unk>', 'unk')

src_item = self.encode_text(
' what is the summary of article " {} "?'.format(source),
self.prompt.format(source),
length=self.max_src_length
)
tgt_item = self.encode_text(' {}'.format(target))
tgt_item = self.encode_text('{}'.format(target))
noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)

src_item = torch.cat([self.bos_item, src_item, self.eos_item])
8 changes: 4 additions & 4 deletions data/ofa_dataset.py
@@ -42,7 +42,7 @@ def encode_text(self, text, length=None, append_bos=False, append_eos=False, use
s = torch.cat([s, self.eos_item])
return s

def pre_question(self, question, max_ques_words):
def pre_question(self, question, max_ques_words=None):
question = question.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')

question = re.sub(
@@ -55,12 +55,12 @@ def pre_question(self, question, max_ques_words):

# truncate question
question_words = question.split(' ')
if len(question_words) > max_ques_words:
if max_ques_words is not None and len(question_words) > max_ques_words:
question = ' '.join(question_words[:max_ques_words])

return question

def pre_caption(self, caption, max_words):
def pre_caption(self, caption, max_words=None):
caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

caption = re.sub(
@@ -73,7 +73,7 @@ def pre_caption(self, caption, max_words):

# truncate caption
caption_words = caption.split(' ')
if len(caption_words) > max_words:
if max_words is not None and len(caption_words) > max_words:
caption = ' '.join(caption_words[:max_words])

return caption
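
With this change, `pre_question` and `pre_caption` truncate only when a length limit is actually passed; calling them without one just cleans the text. A simplified sketch of the resulting behavior (this stand-in function is not the full implementation):

```python
import re

def clean_caption(caption, max_words=None):
    # Lowercase, normalize separators, collapse whitespace; truncate only
    # when a limit is provided (None means keep everything).
    caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
    caption = re.sub(r"\s{2,}", ' ', caption).rstrip('\n').strip(' ')
    words = caption.split(' ')
    if max_words is not None and len(words) > max_words:
        caption = ' '.join(words[:max_words])
    return caption

print(clean_caption("A man - riding / a horse", max_words=3))  # "a man riding"
print(clean_caption("A man - riding / a horse"))               # no truncation
```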
15 changes: 7 additions & 8 deletions data/pretrain_data/unify_dataset.py
@@ -627,11 +627,10 @@ def collater(self, samples, pad_to_length=None):
for sample_tuple in samples:
samples_v1 += sample_tuple[0]
samples_v2 += sample_tuple[1]
if samples_v2 == []:
samples_v2 += self.process_pure_text(0) if self.pure_text_dataset else []
samples_v2 += self.process_pure_image(0) if self.pure_image_dataset else []
samples_v2 += self.process_detection(0) if self.detection_dataset else []

res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
return res_v1, res_v2
if samples_v2 != []:
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
res_v2 = collate(samples_v2, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
return res_v1, res_v2
else:
res_v1 = collate(samples_v1, pad_idx=self.src_dict.pad(), eos_idx=self.eos)
return res_v1
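
The collater now returns a single collated batch when the second sample stream is empty, instead of padding it with synthetic pure-text, pure-image, or detection samples as before. A schematic sketch of the new control flow; the function and argument names here are placeholders:

```python
def collate_two_streams(samples_v1, samples_v2, collate, pad_idx, eos_idx):
    # Always collate the primary stream; collate and return the secondary
    # stream only when it actually contains samples.
    res_v1 = collate(samples_v1, pad_idx=pad_idx, eos_idx=eos_idx)
    if samples_v2:
        res_v2 = collate(samples_v2, pad_idx=pad_idx, eos_idx=eos_idx)
        return res_v1, res_v2
    return res_v1
```

Callers therefore need to handle both return shapes: a single batch or a pair of batches.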
15 changes: 15 additions & 0 deletions examples/OFA_logo_tp_path.svg
Binary file modified examples/open_vqa.png
4 changes: 3 additions & 1 deletion models/ofa/ofa.py
@@ -385,6 +385,8 @@ def ofa_large_architecture(args):
args.scale_heads = getattr(args, "scale_heads", False)
args.scale_resids = getattr(args, "scale_resids", False)

args.orig_patch_image_size = getattr(args, "orig_patch_image_size", 256)


@register_model_architecture("ofa", "ofa_base")
def ofa_base_architecture(args):
@@ -423,7 +425,7 @@ def ofa_medium_architecture(args):


@register_model_architecture("ofa", "ofa_tiny")
def ofa_medium_architecture(args):
def ofa_tiny_architecture(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 256)
args.encoder_layers = getattr(args, "encoder_layers", 4)
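
For context, each `*_architecture` function only supplies defaults for hyperparameters the command line did not set, which is what the `getattr(args, name, default)` calls above do. A small illustrative sketch using the tiny-size values shown; `Namespace` merely stands in for the parsed fairseq args:

```python
from argparse import Namespace

def ofa_tiny_defaults(args):
    # Fall back to tiny-size defaults only for values the user did not set.
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 256)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    return args

args = ofa_tiny_defaults(Namespace(encoder_embed_dim=512))  # CLI override
print(args.encoder_embed_dim)  # 512 -- the user's value wins
print(args.encoder_layers)     # 4   -- the tiny default
```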