-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
112 lines (92 loc) · 3.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import re
def get_2d_spans(text, tokenss):
    """
    Locate every token of every sentence in text as character spans.

    Tokens are searched strictly left to right: each token must occur at or
    after the end of the previously located token, including across
    sentence boundaries.

    :param text: the string the tokens were extracted from.
    :param tokenss: list of sentences, each a list of token strings.
    :return: one list per sentence of ``(start, end)`` character offsets,
        ``end`` exclusive, so that ``text[start:end] == token``.
    :raises ValueError: if a token cannot be found at or after the current
        search position.
    """
    spanss = []
    cur_idx = 0
    for tokens in tokenss:
        spans = []
        for token in tokens:
            # Search once and reuse the result (the original searched twice).
            found = text.find(token, cur_idx)
            if found < 0:
                raise ValueError(
                    "token {!r} not found in text at or after index {} "
                    "(sentence tokens: {!r})".format(token, cur_idx, tokens)
                )
            end = found + len(token)
            spans.append((found, end))
            # Continue searching after this token so repeated words map to
            # successive occurrences.
            cur_idx = end
        spanss.append(spans)
    return spanss
def get_word_span(context, wordss, start, stop):
    """
    Convert a character range of context into a word-level span.

    A word is selected when its character span overlaps the half-open range
    ``[start, stop)``.

    :param context: text the words in wordss were taken from.
    :param wordss: list of sentences, each a list of word strings.
    :param start: first character offset of the range (inclusive).
    :param stop: character offset ending the range (exclusive).
    :return: ``(first, end)`` where ``first`` is ``(sent_idx, word_idx)`` of
        the first overlapping word and ``end`` is ``(sent_idx, word_idx + 1)``
        for the last overlapping word.
    :raises AssertionError: if no word overlaps the range.
    """
    spanss = get_2d_spans(context, wordss)
    overlapping = [
        (sent_idx, word_idx)
        for sent_idx, spans in enumerate(spanss)
        for word_idx, span in enumerate(spans)
        # Overlap test for half-open intervals [span[0], span[1]) and [start, stop).
        if span[0] < stop and start < span[1]
    ]
    assert overlapping, "{} {} {} {}".format(context, spanss, start, stop)
    last_sent_idx, last_word_idx = overlapping[-1]
    return overlapping[0], (last_sent_idx, last_word_idx + 1)
def get_phrase(context, wordss, span):
    """
    Obtain a phrase as a substring of context from a word-level span.

    Words are located in context by scanning left to right, so each word
    must appear after the previous one.

    :param context: the text the words in wordss were tokenized from.
    :param wordss: list of sentences, each a list of word strings.
    :param span: pair ``(start, stop)``. ``start`` is ``(sent_idx, word_idx)``
        of the first word (inclusive); ``stop`` is ``(sent_idx, word_idx)``
        one past the last word (exclusive).
    :return: the slice of context from the first character of the first word
        to the last character of the last word, inclusive of any text between them.
    :raises AssertionError: if a word needed to locate the span is not found
        in context, or if either boundary of the span is never reached.
    """
    start, stop = span
    flat_start = get_flat_idx(wordss, start)
    flat_stop = get_flat_idx(wordss, stop)
    # Linear flatten; sum(wordss, []) is quadratic and requires list sentences.
    words = [word for sent_words in wordss for word in sent_words]
    char_idx = 0
    char_start, char_stop = None, None
    for word_idx, word in enumerate(words):
        char_idx = context.find(word, char_idx)
        assert char_idx >= 0
        if word_idx == flat_start:
            char_start = char_idx
        char_idx += len(word)
        if word_idx == flat_stop - 1:
            char_stop = char_idx
        # Words after both boundaries cannot affect the result.
        if char_start is not None and char_stop is not None:
            break
    assert char_start is not None
    assert char_stop is not None
    return context[char_start:char_stop]
def get_flat_idx(wordss, idx):
    """
    Convert a ``(sent_idx, word_idx)`` pair into an index into the
    concatenation of all sentences in wordss.

    :param wordss: list of sentences, each a sized sequence of words.
    :param idx: ``(sent_idx, word_idx)``; only the first two items are used.
    :return: total length of the sentences before ``sent_idx`` plus ``word_idx``.
    """
    sent_idx, word_idx = idx[0], idx[1]
    offset = 0
    for preceding_sentence in wordss[:sent_idx]:
        offset += len(preceding_sentence)
    return offset + word_idx
def get_word_idx(context, wordss, idx):
    """
    Return the character offset in context at which a word begins.

    :param context: text the words in wordss were taken from.
    :param wordss: list of sentences, each a list of word strings.
    :param idx: ``(sent_idx, word_idx)`` of the word.
    :return: start offset of that word's character span.
    """
    sentence_spans = get_2d_spans(context, wordss)[idx[0]]
    word_span = sentence_spans[idx[1]]
    return word_span[0]
# Characters on which tokens are further split; each occurrence is kept as a
# token of its own. Hyphen, minus sign (U+2212), em dash (U+2014), en dash
# (U+2013, e.g. "number to number" ranges), slash, tilde, straight and curly
# quotes, and the degree sign (U+00B0).
_SPLIT_CHARS = (
    "-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'",
    "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0",
)
# Compiled once; the capturing group makes re.split keep the separators.
# re.escape keeps characters such as "-" literal inside the character class.
_SPLIT_PATTERN = re.compile("([{}])".format(re.escape("".join(_SPLIT_CHARS))))


def process_tokens(temp_tokens):
    """
    Split each token on the characters in ``_SPLIT_CHARS``.

    Separators are kept as separate tokens. As with ``re.split`` using a
    capturing group, empty strings appear where a separator starts or ends a
    token (e.g. ``"-x"`` -> ``["", "-", "x"]``).

    :param temp_tokens: iterable of token strings.
    :return: flat list of the resulting sub-tokens, in order.
    """
    tokens = []
    for token in temp_tokens:
        tokens.extend(_SPLIT_PATTERN.split(token))
    return tokens
def get_best_span(ypi, yp2i):
    """
    Find the highest-scoring span within a single sentence.

    The score of a span is ``ypi[s][start] * yp2i[s][end]`` with
    ``start <= end``. For each candidate end, the best start so far is
    tracked with a running argmax; ties keep the earliest start and the
    earliest span.

    :param ypi: per-sentence sequences of start scores.
    :param yp2i: per-sentence sequences of end scores, aligned with ypi.
    :return: ``(((sent, start), (sent, end + 1)), score)`` with ``score`` as a
        float. If no span scores above zero, the default
        ``(((0, 0), (0, 2)), 0.0)`` is returned.
    """
    best_score = 0
    best_sent, best_start, best_end = 0, 0, 1
    for sent_idx, (start_scores, end_scores) in enumerate(zip(ypi, yp2i)):
        start_arg = 0
        for end_idx, start_score in enumerate(start_scores):
            if start_scores[start_arg] < start_score:
                start_arg = end_idx
            candidate = start_scores[start_arg] * end_scores[end_idx]
            if candidate > best_score:
                best_sent, best_start, best_end = sent_idx, start_arg, end_idx
                best_score = candidate
    return ((best_sent, best_start), (best_sent, best_end + 1)), float(best_score)
def get_span_score_pairs(ypi, yp2i):
    """
    Enumerate every in-sentence span together with its score.

    :param ypi: per-sentence sequences of start scores.
    :param yp2i: per-sentence sequences of end scores, aligned with ypi.
    :return: list of ``(((sent, start), (sent, end + 1)), ypi[sent][start] *
        yp2i[sent][end])`` for every ``start <= end``. The list is ordered by
        sentence, then start, then end.
    """
    return [
        (((sent_idx, start), (sent_idx, end + 1)), start_scores[start] * end_scores[end])
        for sent_idx, (start_scores, end_scores) in enumerate(zip(ypi, yp2i))
        for start in range(len(start_scores))
        for end in range(start, len(end_scores))
    ]