Skip to content

Commit

Permalink
Use model config instead for schema tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Dec 14, 2020
1 parent 62679e7 commit 589ef31
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions aitextgen/aitextgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def __init__(
vocab_file: str = None,
merges_file: str = None,
tokenizer_file: str = None,
schema_tokens: List[str] = None,
schema_return: List[str] = None,
cache_dir: str = "aitextgen",
tf_gpt2: str = None,
to_gpu: bool = False,
Expand Down Expand Up @@ -179,6 +181,12 @@ def __init__(
cache_dir=cache_dir,
)

# Store the schema settings on the model config.
# NOTE(review): assumes self.model.config is a Hugging Face PretrainedConfig,
# which takes extra fields via attribute assignment and does not support
# item assignment (config["..."] = ...) -- confirm.
if schema_tokens:
    self.model.config.schema_tokens = schema_tokens

# Guard on schema_return itself: the previous check re-tested schema_tokens,
# so a schema_return passed alone was dropped, and None was stored when only
# schema_tokens was given.
if schema_return:
    self.model.config.schema_return = schema_return

if self.tokenizer is None:
# Update tokenizer settings (if not set already)
args = locals()
Expand Down Expand Up @@ -210,14 +218,6 @@ def __init__(
unk_token=self.unk_token,
pad_token=self.pad_token,
)
with open(tokenizer_file, "r", encoding="utf-8") as f:
data = json.load(f)
self.schema_tokens = {
x["id"]: x["content"]
for x in data["added_tokens"]
if x["content"]
not in [self.bos_token, self.eos_token, self.unk_token]
}
else:
self.tokenizer = GPT2TokenizerFast(
vocab_file=self.vocab_file,
Expand Down Expand Up @@ -250,6 +250,9 @@ def generate(
return_as_list: bool = False,
seed: int = None,
pad_token_id: str = None,
schema: str = None,
schema_tokens: List[str] = None,
schema_return: List[str] = None,
**kwargs,
) -> Optional[str]:
"""
Expand Down

0 comments on commit 589ef31

Please sign in to comment.