[T5] T5 in ParlAI (facebookresearch#3519)
* t5

* lint

* slight cleanup

* overfit task

* try catch

* install transformers for cpu

* always install transformers

* migrate t5 to hf

* black

* import again

* change to import error

* set device

* bump tf versions

* get version right

* handle gpt2

* add tear down

* change vectorize

* update comment for vectorize; move version check to

* reduce bsz for dialogpt

* address spencer comments
klshuster authored Mar 19, 2021
1 parent 6e5218f commit a8fe17c
Showing 8 changed files with 798 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -101,7 +101,7 @@ commands:
python -m pip install --progress-bar off 'fairscale~=0.3.0'
python -m pip install --progress-bar off 'torchtext==0.7.0'
python -m pip install --progress-bar off pytorch-pretrained-bert
-python -m pip install --progress-bar off 'transformers<4.0.0'
+python -m pip install --progress-bar off 'transformers==4.3.3'
python -m pip install --progress-bar off 'fairseq==0.10.0'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
@@ -115,7 +115,7 @@ commands:
python -m pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --progress-bar off 'torchtext==0.7.0'
python -m pip install --progress-bar off pytorch-pretrained-bert
-python -m pip install --progress-bar off 'transformers<4.0.0'
+python -m pip install --progress-bar off 'transformers==4.3.3'
python -m pip install --progress-bar off 'fairseq==0.10.0'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
29 changes: 28 additions & 1 deletion parlai/agents/hugging_face/README.md
@@ -46,4 +46,31 @@ Enter Your Message: What do you think of parrots?
```bash
parlai train_model -m hugging_face/dialogpt --add-special-tokens True --delimiter '\n' --add-start-token True --gpt2-size medium -t convai2 -bs 2 -mf <modelfile>
```
_Note:_ In the above command, we override the default delimiter (`--delimiter '<|endoftext|>'`) with `'\n'` as a personal choice.


## T5

"Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer"

See https://arxiv.org/abs/1910.10683.


### Implementation

The T5 model in ParlAI is built on the `T5ForConditionalGeneration` model provided by the [HuggingFace Transformers](https://github.com/huggingface/transformers) library. The model can be instantiated with any of the architectures provided there (a minimal loading sketch follows the list):

- `t5-small`: 60 million parameters
- `t5-base`: 220 million parameters
- `t5-large`: 770 million parameters
- `t5-3b`: 3 billion parameters
- `t5-11b`: 11 billion parameters
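
For reference, here is a minimal sketch of loading one of these checkpoints directly through the Transformers library; this is roughly what the ParlAI agent wraps, and the translation prompt is only an illustration:

```python
from transformers import T5ForConditionalGeneration, T5TokenizerFast

# "t5-small" stands in for any of the architectures listed above.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello!", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```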

**Model Parallel**: HuggingFace has implemented model parallelism for T5; however, it is an experimental feature, so proceed at your own risk. You can enable it by specifying `--t5-model-parallel`.
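
For context, a sketch of what this flag builds on, assuming HuggingFace's experimental `parallelize()`/`deparallelize()` API for T5:

```python
from transformers import T5ForConditionalGeneration

# Experimental HF model parallelism: spread the T5 blocks across the
# visible GPUs instead of keeping the whole model on one device.
model = T5ForConditionalGeneration.from_pretrained("t5-3b")
model.parallelize()    # distribute layers across GPUs
# ... run training or generation here ...
model.deparallelize()  # move the model back to a single device
```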

### Basic Examples

#### Train `t5-large` on ConvAI2
```bash
parlai train_model -m hugging_face/t5 -mf /tmp/model_file -t convai2 -bs 24 --fp16 true -eps 1 -lr 1e-5 --optimizer adam --t5-model-arch t5-large
```
137 changes: 113 additions & 24 deletions parlai/agents/hugging_face/dict.py
@@ -6,15 +6,17 @@

import os

-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
from collections import defaultdict
from typing import List

from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt
from parlai.utils.io import PathManager


try:
-    from transformers import GPT2Tokenizer
+    from transformers import GPT2Tokenizer, T5TokenizerFast
except ImportError:
    raise ImportError(
        "Need to install Hugging Face transformers repository. "
@@ -31,15 +33,30 @@ class HuggingFaceDictionaryAgent(DictionaryAgent, ABC):
    Use Hugging Face tokenizers.
    """

-    def __init__(self, opt):
-        super().__init__(opt)
-        # initialize from vocab path
-        self.tokenizer = self.get_tokenizer(opt)
+    def __init__(self, opt: Opt, shared=None):
+        if not shared:
+            self.hf_tokenizer = self.get_tokenizer(opt)
+            self.tok2ind = self.hf_tokenizer.get_vocab()
+            self.ind2tok = {v: k for k, v in self.tok2ind.items()}
+        else:
+            self.hf_tokenizer = shared['hf_tokenizer']
+            self.tok2ind = shared['tok2ind']
+            self.ind2tok = shared['ind2tok']
+
+        self.freq = defaultdict(int)
+        for tok in self.tok2ind:
+            self.freq[tok] = 1
+        self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)
+
+        self._unk_token_idx = self.hf_tokenizer.unk_token_id
        self.override_special_tokens(opt)
-        for i in range(self.tokenizer.vocab_size):
-            token = self.tokenizer._convert_id_to_token(i)
-            self.add_token(token)
-            self.freq[token] = 1
+
+        self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
+        self.tokenizer = 'hf'
+        self.opt = opt
+        self.max_length = (
+            self.opt.get('text_truncate') or self.hf_tokenizer.model_max_length
+        )

    @abstractmethod
    def get_tokenizer(self, opt):
@@ -55,20 +72,46 @@ def override_special_tokens(opt):
"""
pass

    @abstractproperty
    def add_special_tokens(self) -> bool:
        """
        Whether to add special tokens when tokenizing.
        """

    @abstractproperty
    def skip_decode_special_tokens(self) -> bool:
        """
        Whether to skip special tokens when converting tokens to text.
        """

    def share(self):
        shared = super().share()
        shared['hf_tokenizer'] = self.hf_tokenizer
        shared['ind2tok'] = self.ind2tok
        shared['tok2ind'] = self.tok2ind
        return shared

    def format_text(self, text: str) -> str:
        """
        Format text prior to encoding with tokenizer.
        """
        return text

    def txt2vec(self, text, vec_type=list):
-        tokens = self.tokenizer.tokenize(text)
-        tokens_id = self.tokenizer.convert_tokens_to_ids(tokens)
-        return tokens_id
+        return self.hf_tokenizer.encode(
+            self.format_text(text),
+            add_special_tokens=self.add_special_tokens,
+            max_length=self.max_length,
+            pad_to_max_length=False,
+            truncation='longest_first',
+        )

-    def vec2txt(self, vec):
-        return self.tokenizer.decode(
-            vec, skip_special_tokens=False, clean_up_tokenization_spaces=True
+    def vec2txt(self, vec, **kwargs):
+        return self.hf_tokenizer.decode(
+            vec, skip_special_tokens=self.skip_decode_special_tokens, **kwargs
        )

    def act(self):
        """
        Dummy override.
        """
        return {}
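
As a standalone illustration (not part of the diff), these are the encode/decode semantics that `txt2vec`/`vec2txt` now delegate to the HuggingFace tokenizer, shown here with T5's:

```python
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-small")
ids = tok.encode("hello world", add_special_tokens=True)  # appends </s> (id 1)
text = tok.decode(ids, skip_special_tokens=True)          # strips </s> again
print(ids, "->", text)
```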


@@ -79,6 +122,20 @@ def is_prebuilt(self):
"""
return True

    @property
    def add_special_tokens(self) -> bool:
        """
        Whether to add special tokens when tokenizing.
        """
        return True

    @property
    def skip_decode_special_tokens(self) -> bool:
        """
        Whether to skip special tokens when converting tokens to text.
        """
        return False

    def get_tokenizer(self, opt):
        """
        Instantiate tokenizer.
@@ -107,7 +164,7 @@ def add_additional_special_tokens(self, additional_special_tokens: List[str]):
        Add additional special tokens to the dictionary.
        """
        self.additional_special_tokens = additional_special_tokens
-        self.tokenizer.add_special_tokens(
+        self.hf_tokenizer.add_special_tokens(
            {'additional_special_tokens': additional_special_tokens}
        )
        for tok in self.additional_special_tokens:
@@ -116,7 +173,7 @@ def add_additional_special_tokens(self, additional_special_tokens: List[str]):
    def _define_special_tokens(self, opt):
        if opt["add_special_tokens"]:
            # Add additional start/end/pad tokens
-            self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
+            self.hf_tokenizer.add_special_tokens(SPECIAL_TOKENS)
            self.start_token = SPECIAL_TOKENS["bos_token"]
            self.end_token = SPECIAL_TOKENS["eos_token"]
            self.null_token = SPECIAL_TOKENS["pad_token"]
@@ -130,9 +187,9 @@ def override_special_tokens(self, opt):
        # define special tokens
        self._define_special_tokens(opt)
        # now override
-        self.start_idx = self.tokenizer.convert_tokens_to_ids([self.start_token])[0]
-        self.end_idx = self.tokenizer.convert_tokens_to_ids([self.end_token])[0]
-        self.null_idx = self.tokenizer.convert_tokens_to_ids([self.null_token])[0]
+        self.start_idx = self.hf_tokenizer.convert_tokens_to_ids([self.start_token])[0]
+        self.end_idx = self.hf_tokenizer.convert_tokens_to_ids([self.end_token])[0]
+        self.null_idx = self.hf_tokenizer.convert_tokens_to_ids([self.null_token])[0]
        # set tok2ind for special tokens
        self.tok2ind[self.end_token] = self.end_idx
        self.tok2ind[self.start_token] = self.start_idx
@@ -151,3 +208,35 @@ def get_tokenizer(self, opt):
        model_sz = opt["gpt2_size"]
        fle_key = f"microsoft/DialoGPT-{model_sz}"
        return GPT2Tokenizer.from_pretrained(fle_key)


class T5DictionaryAgent(HuggingFaceDictionaryAgent):
    def get_tokenizer(self, opt):
        return T5TokenizerFast.from_pretrained(opt['t5_model_arch'], truncation=True)

    @property
    def add_special_tokens(self) -> bool:
        """
        Whether to add special tokens when tokenizing.
        """
        return True

    @property
    def skip_decode_special_tokens(self) -> bool:
        """
        Whether to skip special tokens when converting tokens to text.
        """
        return True

    def override_special_tokens(self, opt):
        # now override
        self.start_token = self.hf_tokenizer.pad_token
        self.end_token = self.hf_tokenizer.eos_token
        self.null_token = self.hf_tokenizer.pad_token
        self.unk_token = self.hf_tokenizer.unk_token

        self._unk_token_idx = self.hf_tokenizer.unk_token_id

        self.start_idx = self[self.start_token]
        self.end_idx = self[self.end_token]
        self.null_idx = self[self.null_token]
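
For context (an illustration, not part of the diff): T5 has no dedicated BOS token and decoding starts from the pad token, which is why `start_token` is mapped to `pad_token` above:

```python
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-small")
print(tok.pad_token, tok.pad_token_id)  # <pad> 0, doubles as the start token
print(tok.eos_token, tok.eos_token_id)  # </s> 1
```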
10 changes: 8 additions & 2 deletions parlai/agents/hugging_face/gpt2.py
@@ -40,7 +40,7 @@ def __init__(self, opt, dict):
        # add special tokens
        if opt["add_special_tokens"]:
            size_before = self.transformer.wte.weight.size(0)
-            self.transformer.resize_token_embeddings(len(dict.tokenizer))
+            self.transformer.resize_token_embeddings(len(dict.hf_tokenizer))
            with torch.no_grad():
                # first reduce the random jitter of the initialization
                self.transformer.wte.weight[size_before:] *= 0.1
@@ -185,7 +185,13 @@ def output(self, tensor):
    def reorder_decoder_incremental_state(self, incremental_state, inds):
        new_incr_state = []
        for layer_past in incremental_state:
-            new_incr_state.append(torch.index_select(layer_past, 1, inds))
+            if torch.is_tensor(layer_past):
+                new_incr_state.append(torch.index_select(layer_past, 1, inds))
+            else:
+                # newer versions of HF split up the intermediate outputs
+                assert isinstance(layer_past, tuple)
+                layer_past = torch.stack(layer_past, dim=0)
+                new_incr_state.append(torch.index_select(layer_past, 1, inds))

        return tuple(new_incr_state)
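
A standalone sketch (illustrative shapes only, not part of the diff) of the two `past` layouts this branch handles: older transformers releases return one stacked tensor per layer, while newer ones return a tuple of per-projection tensors:

```python
import torch

inds = torch.tensor([1, 0])  # new beam order after search reordering

# old layout: a single stacked tensor with batch on dim 1
old_style = torch.randn(2, 2, 12, 5, 64)
reordered_old = torch.index_select(old_style, 1, inds)

# new layout: a tuple of tensors, stacked before reordering
new_style = tuple(torch.randn(2, 12, 5, 64) for _ in range(2))
reordered_new = torch.index_select(torch.stack(new_style, dim=0), 1, inds)
print(reordered_old.shape, reordered_new.shape)
```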

8 changes: 6 additions & 2 deletions parlai/agents/hugging_face/hugging_face.py
@@ -12,14 +12,18 @@
hugging_face/dialogpt`.
"""
try:
-    import transformers  # noqa: F401
+    import transformers
except ImportError:
    raise ImportError('Please run `pip install transformers`.')


HF_VERSION = float('.'.join(transformers.__version__.split('.')[:2]))
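
For illustration, `HF_VERSION` keeps only the major.minor prefix of the installed version string:

```python
# '4.3.3'.split('.')[:2] -> ['4', '3']; '.'.join(...) -> '4.3'; float -> 4.3
print(float('.'.join('4.3.3'.split('.')[:2])))  # 4.3
```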


class HuggingFaceAgent:
    def __init__(self, opt, shared=None):
        raise RuntimeError(
            '`-m hugging_face` is not a valid choice. Please run with '
-            '`-m hugging_face/gpt2` or `-m hugging_face/dialogpt`.'
+            '`-m hugging_face/gpt2`, `-m hugging_face/dialogpt`, '
+            'or `-m hugging_face/t5`'
        )
