forked from facebookresearch/GENRE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test_example.py
43 lines (33 loc) · 1.2 KB
/
test_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import pickle
from unittest.mock import ANY
import pytest
from genre.trie import Trie
from genre.fairseq_model import GENRE, GENREHubInterface
@pytest.fixture(scope="session")
def kilt_trie():
    """Build the prefix tree (trie) over KILT titles, once per test session.

    Reads ``./data/kilt_titles_trie_dict.pkl``. The relative path is resolved
    against the current working directory. The file is unpickled and the
    resulting object is handed to ``Trie.load_from_dict``.
    """
    # NOTE(review): pickle.load can execute arbitrary code. Only point this at
    # a trusted, locally generated file.
    with open("./data/kilt_titles_trie_dict.pkl", "rb") as pickled_trie:
        return Trie.load_from_dict(pickle.load(pickled_trie))
@pytest.fixture(scope="session")
def fairseq_wikipage_retrieval():
    """Load the GENRE wikipage-retrieval checkpoint, once per test session.

    Loads from ``./models/fairseq_wikipage_retrieval`` (relative to the current
    working directory) and returns the result of ``.eval()`` on it. Per the
    test's annotation this is presumably a ``GENREHubInterface``; confirm.
    """
    # .eval() is presumably the torch ``nn.Module.eval`` switch to inference
    # mode (e.g. dropout disabled). It returns the object, which is what the
    # fixture yields.
    loaded = GENRE.from_pretrained("./models/fairseq_wikipage_retrieval")
    return loaded.eval()
# Expected output of ``sample`` for a batch containing one sentence: a single
# ordered list of candidate page titles. Every score is ``unittest.mock.ANY``
# so that only the titles and their order are compared, not the score values.
EXPECTED_RESULTS_DOCUMENT_RETRIEVAL = [
    [
        {"text": title, "score": ANY}
        for title in (
            "Albert Einstein",
            "Werner Bruschke",
            "Werner von Habsburg",
            "Werner von Moltke",
            "Werner von Eichstedt",
        )
    ]
]
def test_example_document_retrieval(
    kilt_trie: Trie, fairseq_wikipage_retrieval: GENREHubInterface
):
    """Constrained retrieval for one sentence yields the expected ranked titles.

    Decoding is restricted through ``prefix_allowed_tokens_fn``, which looks up
    the tokens produced so far in the KILT title trie. The full result is
    compared to ``EXPECTED_RESULTS_DOCUMENT_RETRIEVAL``, where scores are
    wildcards.
    """

    def allowed_next_tokens(batch_id, sent):
        # ``sent`` is presumably the tensor of token ids decoded so far for
        # this hypothesis (it is converted with ``.tolist()``); confirm.
        # ``batch_id`` is unused, but the callback signature requires it.
        return kilt_trie.get(sent.tolist())

    predictions = fairseq_wikipage_retrieval.sample(
        ["Einstein was a German physicist."],
        prefix_allowed_tokens_fn=allowed_next_tokens,
    )
    assert predictions == EXPECTED_RESULTS_DOCUMENT_RETRIEVAL