# helper.py
import copy
import json
import random
import re


def _tokenize(x):
    """Split x into word and punctuation tokens, and build a per-character
    map to token indices (-1 for whitespace characters)."""
    # drop any zero-length matches (guards against a hanging trailing space)
    tokens = [v for v in re.findall(r"\w+|[^\w]", x, re.UNICODE) if len(v)]
    token_shifts = []
    char_token_map = []
    j = 0
    for v in tokens:
        if v.strip():
            token_shifts.append(j)
            j += 1
        else:
            token_shifts.append(-1)
        char_token_map += [token_shifts[-1]] * len(v)
    # remove whitespace-only tokens and strip the remaining ones
    tokens = [v.strip() for v in tokens if v.strip()]
    assert len(tokens) == max(char_token_map) + 1, \
        'number of tokens must equal max(char_token_map) + 1, but %d vs %d' % (len(tokens), max(char_token_map) + 1)
    assert len(char_token_map) == len(x), \
        'length of char_token_map must equal the length of the original string, but %d vs %d' % (len(char_token_map), len(x))
    return tokens, char_token_map
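
# Example (a minimal sketch; the values below follow from the regex above):
#   tokens, cmap = _tokenize("hello, world")
#   tokens -> ['hello', ',', 'world']
#   cmap   -> [0, 0, 0, 0, 0, 1, -1, 2, 2, 2, 2, 2]  # -1 marks the space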


def _char_token_start_end(char_start, answer_text, char_token_map, full_tokens=None):
    """Map a character-level span to token indices; to recover the tokens,
    slice full_tokens[start_id: end_id + 1]."""
    start_id = char_token_map[char_start]
    end_id = char_token_map[char_start + len(answer_text) - 1]
    if full_tokens:
        # sanity check: the selected tokens must re-tokenize to the answer text
        ans = ' '.join(full_tokens[start_id: (end_id + 1)])
        ans_gold = ' '.join(_tokenize(answer_text)[0])
        assert ans == ans_gold, 'answers are not identical "%s" vs "%s"' % (ans, ans_gold)
    return start_id, end_id
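
# Example, continuing the sketch above ("world" starts at character 7):
#   _char_token_start_end(7, "world", cmap, full_tokens=tokens)  -> (2, 2)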


def _dump_to_json(sample):
    # serialize one sample into a JSON byte string
    return json.dumps(sample).encode()


def _load_from_json(batch):
    # deserialize a batch of JSON (byte) strings back into Python objects
    return [json.loads(d) for d in batch]


def _parse_line(line):
    # parse one JSON-encoded line, e.g. from a .jsonl file
    return json.loads(line.strip())
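
# Round-trip sketch (illustrative values): samples serialized with
# _dump_to_json can be restored in batches with _load_from_json:
#   batch = [_dump_to_json({'id': 1}), _dump_to_json({'id': 2})]
#   _load_from_json(batch)  -> [{'id': 1}, {'id': 2}]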


def _do_padding(token_ids, token_lengths, pad_id):
    # pad (or truncate) every id sequence to the longest length in the batch
    pad_len = max(token_lengths)
    return [(ids + [pad_id] * (pad_len - len(ids)))[: pad_len] for ids in token_ids]
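
# Example (hand-checked against the list comprehension above):
#   _do_padding([[1, 2, 3], [4]], [3, 1], pad_id=0)  -> [[1, 2, 3], [4, 0, 0]]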


def _do_char_padding(char_ids, token_lengths, pad_id, char_pad_id):
    # pad along two axes: every token to the longest character length,
    # then every sequence to the longest token length in the batch
    pad_token_len = max(token_lengths)
    pad_char_len = max(len(xx) for x in char_ids for xx in x)
    pad_empty_token = [char_pad_id] * pad_char_len
    return [[(ids + [pad_id] * (pad_char_len - len(ids)))[: pad_char_len] for ids in x] +
            [pad_empty_token] * (pad_token_len - len(x)) for x in char_ids]
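
# Example (two sequences of character ids; pad_id and char_pad_id both 0):
#   _do_char_padding([[[1, 2], [3]], [[4, 5, 6]]], [2, 1], 0, 0)
#   -> [[[1, 2, 0], [3, 0, 0]], [[4, 5, 6], [0, 0, 0]]]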


def _dropout_word(x, unk_id, dropout_keep_prob):
    # keep each token id with probability dropout_keep_prob, otherwise
    # replace it with unk_id (word-level dropout)
    return [v if random.random() < dropout_keep_prob else unk_id for v in x]
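
# Example (stochastic, so the output varies from run to run):
#   _dropout_word([5, 6, 7], unk_id=1, dropout_keep_prob=0.5)  -> e.g. [5, 1, 7]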


def _fast_copy(x, ignore_keys):
    # deep-copy every field except those in ignore_keys, which are shared
    # by reference -- cheaper when some fields are large and read-only
    y = {}
    for k, v in x.items():
        if k in ignore_keys:
            y[k] = v
        else:
            y[k] = copy.deepcopy(v)
    return y
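
# Example (`emb` here stands for a hypothetical large, read-only field):
#   y = _fast_copy({'tokens': [1, 2], 'emb': emb}, ignore_keys={'emb'})
#   y['emb'] is emb  -> True (shared reference)
#   y['tokens']      -> an independent deep copy of the original list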


def build_vocab(embd_files):
    from utils.vocab import Vocab
    # a pickled vocab loads much faster than rebuilding from embedding files
    if embd_files[0].endswith('pickle'):
        return Vocab.load_from_pickle(embd_files[0])
    return Vocab(embd_files, lower=True)
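
# Usage sketch (assumes the utils.vocab.Vocab interface imported above; the
# file names are illustrative):
#   vocab = build_vocab(['glove.840B.300d.txt'])  # build from embedding text files
#   vocab = build_vocab(['vocab.pickle'])         # or load a cached pickle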