Skip to content

Commit

Permalink
[Whisper] Make tokenizer normalization public (huggingface#28136)
Browse files Browse the repository at this point in the history
* [Whisper] Make tokenizer normalization public

* add to docs
  • Loading branch information
sanchit-gandhi authored Jan 29, 2024
1 parent e694e98 commit da3c79b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/whisper.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ python convert_hf_to_openai.py \
- save_vocabulary
- batch_decode
- decode
- basic_normalize
- normalize

## WhisperTokenizerFast

Expand All @@ -113,6 +115,8 @@ python convert_hf_to_openai.py \
- save_vocabulary
- batch_decode
- decode
- basic_normalize
- normalize

## WhisperFeatureExtractor

Expand Down
21 changes: 18 additions & 3 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tokenization classes for Whisper."""
import json
import os
import warnings
from functools import lru_cache
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -507,6 +508,20 @@ def _convert_id_to_token(self, index):
return self.decoder.get(index, "")

def _normalize(self, text):
warnings.warn(
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper English normalizer using the `normalize` method."
)
return self.normalize(text)

def _basic_normalize(self, text, remove_diacritics=False):
warnings.warn(
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
)
return self.basic_normalize(text, remove_diacritics=remove_diacritics)

def normalize(self, text):
"""
Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on
english text.
Expand All @@ -515,7 +530,7 @@ def _normalize(self, text):
return normalizer(text)

@staticmethod
def _basic_normalize(text, remove_diacritics=False):
def basic_normalize(text, remove_diacritics=False):
"""
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
multilingual text.
Expand Down Expand Up @@ -745,10 +760,10 @@ def _decode(
text = "".join(sub_texts)

if normalize:
clean_text = self._normalize(text)
clean_text = self.normalize(text)
return clean_text
elif basic_normalize:
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics)
return clean_text
else:
return text
Expand Down
21 changes: 19 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
import os
import re
import warnings
from functools import lru_cache
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -427,6 +428,22 @@ def _decode(

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._normalize
def _normalize(self, text):
warnings.warn(
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper English normalizer using the `normalize` method."
)
return self.normalize(text)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
def _basic_normalize(self, text, remove_diacritics=False):
warnings.warn(
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers."
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
)
return self.basic_normalize(text, remove_diacritics=remove_diacritics)

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.normalize
def normalize(self, text):
"""
Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on
english text.
Expand All @@ -435,8 +452,8 @@ def _normalize(self, text):
return normalizer(text)

@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
def _basic_normalize(text, remove_diacritics=False):
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.basic_normalize
def basic_normalize(text, remove_diacritics=False):
"""
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
multilingual text.
Expand Down

0 comments on commit da3c79b

Please sign in to comment.