Skip to content

Commit

Permalink
Final conversion for HF mGENRE
Browse files — browse the repository at this point in the history
Committed by nicola-decao on Jun 8, 2022
1 parent a4d75ef commit 9c720f5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
15 changes: 12 additions & 3 deletions examples_mgenre/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ sentences = ["[START] Einstein [END] era un fisico tedesco."]
model.sample(
sentences,
prefix_allowed_tokens_fn=lambda batch_id, sent: [
e for e in trie.get(sent.tolist()) if e < len(model.task.target_dictionary)
e for e in trie.get(sent.tolist())
if e < len(model.task.target_dictionary)
# for huggingface/transformers
# if e < len(model2.tokenizer) - 1
],
)
```
Expand All @@ -94,7 +97,10 @@ Additionally, we can use the `lang_title2wikidataID` dictionary to map the gener
model.sample(
sentences,
prefix_allowed_tokens_fn=lambda batch_id, sent: [
e for e in trie.get(sent.tolist()) if e < len(model.task.target_dictionary)
e for e in trie.get(sent.tolist())
if e < len(model.task.target_dictionary)
# for huggingface/transformers
# if e < len(model2.tokenizer) - 1
],
text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(" >> ")))], key=lambda y: int(y[1:])),
marginalize=True,
Expand Down Expand Up @@ -155,7 +161,10 @@ trie_of_mention = Trie([
model.sample(
sentences,
prefix_allowed_tokens_fn=lambda batch_id, sent: [
e for e in trie_of_mention.get(sent.tolist()) if e < len(model.task.target_dictionary)
e for e in trie_of_mention.get(sent.tolist())
if e < len(model.task.target_dictionary)
# for huggingface/transformers
# if e < len(model2.tokenizer) - 1
],
text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(" >> ")))], key=lambda y: int(y[1:])),
marginalize=True,
Expand Down
15 changes: 12 additions & 3 deletions examples_mgenre/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@
"model.sample(\n",
" sentences,\n",
" prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
" e for e in trie.get(sent.tolist()) if e < len(model.task.target_dictionary)\n",
" e for e in trie.get(sent.tolist())\n",
" if e < len(model.task.target_dictionary)\n",
" # for huggingface/transformers\n",
" # if e < len(model2.tokenizer) - 1\n",
" ],\n",
")"
]
Expand Down Expand Up @@ -164,7 +167,10 @@
"model.sample(\n",
" sentences,\n",
" prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
" e for e in trie.get(sent.tolist()) if e < len(model.task.target_dictionary)\n",
" e for e in trie.get(sent.tolist())\n",
" if e < len(model.task.target_dictionary)\n",
" # for huggingface/transformers\n",
" # if e < len(model2.tokenizer) - 1\n",
" ],\n",
" text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(\" >> \")))], key=lambda y: int(y[1:])),\n",
" marginalize=True,\n",
Expand Down Expand Up @@ -249,7 +255,10 @@
"model.sample(\n",
" sentences,\n",
" prefix_allowed_tokens_fn=lambda batch_id, sent: [\n",
" e for e in trie_of_mention.get(sent.tolist()) if e < len(model.task.target_dictionary)\n",
" e for e in trie_of_mention.get(sent.tolist())\n",
" if e < len(model.task.target_dictionary)\n",
" # for huggingface/transformers\n",
" # if e < len(model2.tokenizer) - 1\n",
" ],\n",
" text_to_id=lambda x: max(lang_title2wikidataID[tuple(reversed(x.split(\" >> \")))], key=lambda y: int(y[1:])),\n",
" marginalize=True,\n",
Expand Down

0 comments on commit 9c720f5

Please sign in to comment.