Make the sacremoses dependency optional (huggingface#17049)
* Make sacremoses optional

* Pickle
LysandreJik authored May 2, 2022
1 parent bb2e088 commit 30ca529
Showing 4 changed files with 63 additions and 12 deletions.
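The diffs below all follow the standard optional-dependency recipe: the module-level import sacremoses as sm moves into __init__, the module object is kept on self.sm, and, because module objects cannot be pickled, __getstate__ drops the handle and __setstate__ re-imports it. A minimal sketch of the pattern (hypothetical class name, not code from this repository):

    class NeedsSacremoses:
        def __init__(self):
            try:
                import sacremoses  # deferred: only needed once an instance is created
            except ImportError:
                raise ImportError(
                    "You need to install sacremoses. "
                    "See https://pypi.org/project/sacremoses/ for installation."
                )
            self.sm = sacremoses

        def __getstate__(self):
            state = self.__dict__.copy()
            state["sm"] = None  # module objects cannot be pickled
            return state

        def __setstate__(self, d):
            self.__dict__ = d
            import sacremoses  # restore the handle when unpickling
            self.sm = sacremoses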
3 changes: 1 addition & 2 deletions setup.py
@@ -288,6 +288,7 @@ def run(self):
         "nltk",
         "GitPython",
         "hf-doc-builder",
+        "sacremoses",
     )
     + extras["retrieval"]
     + extras["modelcreation"]
@@ -365,7 +366,6 @@ def run(self):
        "protobuf",
        "regex",
        "requests",
-       "sacremoses",
        "sentencepiece",
        "torch",
        "tokenizers",
@@ -383,7 +383,6 @@ def run(self):
     deps["pyyaml"],  # used for the model cards metadata
     deps["regex"],  # for OpenAI GPT
     deps["requests"],  # for downloading models over HTTPS
-    deps["sacremoses"],  # for XLM
     deps["tokenizers"],
     deps["tqdm"],  # progress bars in model download and training scripts
 ]
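The install-time effect: sacremoses is no longer pulled in by a plain install and is only needed for the tokenizers that actually use it (plus the test suite, via the extra it was added to in the first hunk). Assuming that list is the testing extra, the consequences look like:

    # sketch of install behavior after this commit (extra name assumed from the hunk above)
    #   pip install transformers              -> no sacremoses
    #   pip install "transformers[testing]"   -> includes sacremoses
    #   pip install sacremoses                -> enables the XLM/FSMT tokenizers on an existing install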
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_check.py
@@ -23,7 +23,7 @@
 # order specific notes:
 # - tqdm must be checked before tokenizers
 
-pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
 if sys.version_info < (3, 7):
     pkgs_to_check_at_runtime.append("dataclasses")
 if sys.version_info < (3, 8):
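Since sacremoses is no longer guaranteed to be present, it also leaves the list of packages verified at import time. A rough standard-library sketch of what such a runtime check does (transformers' real helper additionally enforces the pinned version ranges from its dependency table):

    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

    def check_installed(pkg):
        try:
            version(pkg)  # raises if the distribution is not installed
        except PackageNotFoundError:
            raise ImportError(f"transformers requires {pkg}, but it was not found.")

    for pkg in "tqdm regex requests packaging filelock numpy tokenizers".split():
        check_installed(pkg)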
36 changes: 31 additions & 5 deletions src/transformers/models/fsmt/tokenization_fsmt.py
@@ -21,8 +21,6 @@
 import unicodedata
 from typing import Dict, List, Optional, Tuple
 
-import sacremoses as sm
-
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging

@@ -212,6 +210,16 @@ def __init__(
             **kwargs,
         )
 
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use FSMTTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
         self.src_vocab_file = src_vocab_file
         self.tgt_vocab_file = tgt_vocab_file
         self.merges_file = merges_file
@@ -254,21 +262,21 @@ def vocab_size(self) -> int:
 
     def moses_punct_norm(self, text, lang):
         if lang not in self.cache_moses_punct_normalizer:
-            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
             self.cache_moses_punct_normalizer[lang] = punct_normalizer
         return self.cache_moses_punct_normalizer[lang].normalize(text)
 
     def moses_tokenize(self, text, lang):
         if lang not in self.cache_moses_tokenizer:
-            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
             self.cache_moses_tokenizer[lang] = moses_tokenizer
         return self.cache_moses_tokenizer[lang].tokenize(
             text, aggressive_dash_splits=True, return_str=False, escape=True
         )
 
     def moses_detokenize(self, tokens, lang):
         if lang not in self.cache_moses_detokenizer:
-            moses_detokenizer = sm.MosesDetokenizer(lang=lang)
+            moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
             self.cache_moses_detokenizer[lang] = moses_detokenizer
         return self.cache_moses_detokenizer[lang].detokenize(tokens)
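The self.sm indirection leaves the per-language caching untouched: Moses normalizer/tokenizer/detokenizer objects are built once per language and reused on every call. The same pattern in isolation (standalone sketch, assuming sacremoses is installed):

    import sacremoses

    _tokenizers = {}

    def moses_tokenize(text, lang):
        # build one MosesTokenizer per language, then reuse it on later calls
        if lang not in _tokenizers:
            _tokenizers[lang] = sacremoses.MosesTokenizer(lang=lang)
        return _tokenizers[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=True
        )

    print(moses_tokenize("Hello, world!", "en"))  # -> ['Hello', ',', 'world', '!']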

@@ -516,3 +524,21 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             index += 1
 
         return src_vocab_file, tgt_vocab_file, merges_file
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use FSMTTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
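With __getstate__/__setstate__ in place, the tokenizer now survives a pickle round trip, e.g. when handed to worker processes. A quick sanity check (sketch; any FSMT checkpoint works, and sacremoses must be installed on both ends):

    import pickle

    from transformers import FSMTTokenizer

    tok = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")
    blob = pickle.dumps(tok)       # __getstate__ drops the unpicklable module handle
    restored = pickle.loads(blob)  # __setstate__ re-imports sacremoses
    assert restored.moses_tokenize("Hello world!", lang="en") == tok.moses_tokenize("Hello world!", lang="en")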
34 changes: 30 additions & 4 deletions src/transformers/models/xlm/tokenization_xlm.py
@@ -22,8 +22,6 @@
 import unicodedata
 from typing import List, Optional, Tuple
 
-import sacremoses as sm
-
 from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging

@@ -629,6 +627,16 @@ def __init__(
             **kwargs,
         )
 
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
         # cache of sm.MosesTokenizer instance
@@ -659,15 +667,15 @@ def do_lower_case(self):
 
     def moses_punct_norm(self, text, lang):
         if lang not in self.cache_moses_punct_normalizer:
-            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
             self.cache_moses_punct_normalizer[lang] = punct_normalizer
         else:
             punct_normalizer = self.cache_moses_punct_normalizer[lang]
         return punct_normalizer.normalize(text)
 
     def moses_tokenize(self, text, lang):
         if lang not in self.cache_moses_tokenizer:
-            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
             self.cache_moses_tokenizer[lang] = moses_tokenizer
         else:
             moses_tokenizer = self.cache_moses_tokenizer[lang]
@@ -970,3 +978,21 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             index += 1
 
         return vocab_file, merge_file
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
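The state["sm"] = None line is what makes the pickling fix work: without it, pickle chokes on the stored module object. A short demonstration (assuming sacremoses is installed):

    import pickle

    import sacremoses

    try:
        pickle.dumps(sacremoses)
    except TypeError as err:
        print(err)  # cannot pickle 'module' object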
