ood_metrics.py (forked from MIR-MU/regemt)
from typing import Dict, Tuple, List, Callable, Iterable, Any, Optional
from functools import lru_cache
from contextlib import redirect_stdout, redirect_stderr
import os
import numpy as np
import spacy
from tqdm.autonotebook import tqdm
from common import Judgements, ReferenceFreeMetric

taggers: Dict[str, Callable[[str], Iterable[Tuple[str, str]]]] = {}


class TransitionModel:

    def __init__(self, dataset: List[str], lang: str):
        if lang not in taggers:
            tagger = self._init_tagger(lang)
            taggers[lang] = tagger
        else:
            tagger = taggers[lang]
        self.corpus_words_tagged = [list(tagger(text)) for text in dataset]
        corpus_tags = [[tag for token, tag in tagged_seq] for tagged_seq in self.corpus_words_tagged]
        self.all_tags, self.transition_probs = self._transition_graph_from_tags(corpus_tags)

    def distance(self, other: "TransitionModel") -> float:
        # restrict both transition matrices to the tags shared by the two models;
        # both tag lists are sorted, so the matching indices stay aligned
        matching_indices = [i for i, tag in enumerate(self.all_tags) if tag in other.all_tags]
        self_transitions_subset = self.transition_probs[matching_indices, :][:, matching_indices]
        other_matching_indices = [i for i, tag in enumerate(other.all_tags) if tag in self.all_tags]
        other_transitions = other.transition_probs
        other_transitions_subset = other_transitions[other_matching_indices, :][:, other_matching_indices]
        return np.linalg.norm(self_transitions_subset - other_transitions_subset)

    @staticmethod
    def _transition_graph_from_tags(tag_sequences: List[List[str]]) -> Tuple[List[str], np.ndarray]:
        # construct 2-grams from the sequences of tags and count the occurrences
        # of each 2-gram for the transition graph
        counts: Dict[Tuple[str, str], int] = {}
        for sequence in tag_sequences:
            for i in range(2, len(sequence) + 1):
                tags_from_to = tuple(sequence[i - 2:i])
                try:
                    counts[tags_from_to] += 1
                except KeyError:
                    counts[tags_from_to] = 1
        all_tags = sorted(set([k[0] for k in counts.keys()] + [k[1] for k in counts.keys()]))
        # rows are indexed by the target tag, columns by the source tag
        transition_matrix = [[counts.get((tag_x, tag_y), 0) for tag_x in all_tags] for tag_y in all_tags]
        if not transition_matrix:
            # the text yields fewer than two tags, so there are no transitions
            # - can happen in initial training phases; we need to keep the dimensionality
            transition_matrix = [[]]
        transition_matrix_np = np.array(transition_matrix)
        # normalise by the total number of 2-grams (at least 1, to avoid division by zero)
        return all_tags, transition_matrix_np / max(transition_matrix_np.sum(), 1)
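
    # Worked example (illustrative, not from the original source): for a single tag
    # sequence ["DET", "NOUN", "VERB"], the counted 2-grams are ("DET", "NOUN") and
    # ("NOUN", "VERB"). With all_tags == ["DET", "NOUN", "VERB"], the matrix row for
    # "NOUN" gets a count of 1 in the "DET" column and the row for "VERB" gets 1 in
    # the "NOUN" column; the whole matrix is then divided by the total number of
    # 2-grams (here 2), so the entries are relative 2-gram frequencies rather than
    # per-row conditional probabilities.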

    @staticmethod
    def supports(lang: str) -> bool:
        return lang in ("no", "en", "de", "zh")

    @staticmethod
    def _init_tagger(lang: str) -> Callable[[str], Iterable[Tuple[str, str]]]:
        if lang == "no":
            model_id = "nb_core_news_md"
        elif lang == "en":
            model_id = "en_core_web_md"
        elif lang == "de":
            model_id = "de_core_news_md"
        elif lang == "zh":
            model_id = "zh_core_web_md"
        else:
            raise ValueError("Language '%s' has no defined tagger" % lang)
        # silence spaCy's loading and download output
        with open(os.devnull, 'w') as f, redirect_stdout(f), redirect_stderr(f):
            try:
                spacy_tagger = spacy.load(model_id)
            except OSError:
                # tagger not yet downloaded
                # spacy.cli.download(model_id, False, "-q")
                spacy.cli.download(model_id)
                spacy_tagger = spacy.load(model_id)

        def _spacy_pos_tagger_wrapper(text: str) -> Iterable[Tuple[str, str]]:
            tokens_tagged = spacy_tagger(text)
            for token in tokens_tagged:
                yield token.text, token.pos_

        return _spacy_pos_tagger_wrapper
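
# Example (illustrative, not from the original source): once the English model is
# loaded, ``taggers["en"]("The cat sat")`` yields (token, PoS) pairs roughly like
# ("The", "DET"), ("cat", "NOUN"), ("sat", "VERB"); the exact tags depend on the
# spaCy model version.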


class SyntacticCompositionality(ReferenceFreeMetric):

    pos_tagger: Callable[[str], Iterable[Tuple[str, str]]]
    src_lang: Optional[str] = None
    label = "Compositionality"

    def __init__(self, tgt_lang: str, src_lang: Optional[str] = None, reference_free: bool = False):
        """
        Compares syntactic compositionality between the base texts (source texts in the
        reference-free setting, otherwise the first references) and the evaluated
        translations, as an out-of-distribution signal.
        Syntactic compositionality is modelled as a transition matrix of PoS tags.
        """
        self.tgt_lang = tgt_lang
        self.reference_free = reference_free
        if reference_free:
            self.src_lang = src_lang

    @staticmethod
    def supports(src_lang: str, tgt_lang: str, reference_free: bool) -> bool:
        return TransitionModel.supports(tgt_lang) and (not reference_free or TransitionModel.supports(src_lang))

    @lru_cache(maxsize=None)
    def compute(self, judgements: Judgements) -> List[float]:
        # build a per-text transition model for the base texts (sources or references)
        # and for the translations, then return the pairwise distances
        if self.reference_free:
            base_transitions = [TransitionModel([src_text], self.src_lang)
                                for src_text in tqdm(judgements.src_texts, desc=self.label)]
        else:
            base_transitions = [TransitionModel([ref_texts[0]], self.tgt_lang)
                                for ref_texts in tqdm(judgements.references, desc=self.label)]
        translated_models = [TransitionModel([t_text], self.tgt_lang)
                             for t_text in tqdm(judgements.translations, desc=self.label)]
        distances = [base_t.distance(translated_t)
                     for base_t, translated_t in zip(base_transitions, translated_models)]
        return distances

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, SyntacticCompositionality):
            return NotImplemented
        return all([
            self.reference_free == other.reference_free,
            self.src_lang == other.src_lang,
            self.tgt_lang == other.tgt_lang,
        ])

    def __hash__(self) -> int:
        return hash((self.reference_free, self.src_lang, self.tgt_lang))
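

# Minimal usage sketch (illustrative, not part of the original module): it builds two
# TransitionModel instances for short English texts and prints their distance, which is
# the quantity SyntacticCompositionality.compute returns per judgement. Running it
# downloads the spaCy model "en_core_web_md" on first use; the example texts are
# arbitrary. A full metric call would additionally need a ``common.Judgements`` object
# exposing ``references``/``translations`` (and ``src_texts`` in the reference-free
# setting), whose construction is not shown here.
if __name__ == "__main__":
    base = TransitionModel(["The quick brown fox jumps over the lazy dog."], "en")
    hypothesis = TransitionModel(["Dog the over jumps fox brown quick lazy the."], "en")
    # smaller values mean more similar PoS-transition structure
    print(base.distance(hypothesis))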