forked from baidu/Senta
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
112 lines (97 loc) · 2.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import numpy as np
import random
import paddle.fluid as fluid
import paddle
import io
def get_predict_label(pos_prob, neu_threshold=0.55):
    """
    Turn a positive-class probability into 3-class and 2-class labels.

    Args:
        pos_prob: probability of the positive class; the negative
            probability is taken as ``1 - pos_prob``.
        neu_threshold: confidence a class must strictly exceed to win the
            3-class decision. It should lie in the open interval (0.5, 1)
            so that at most one of the two probabilities can exceed it.
            Defaults to 0.55, the value previously hard-coded here.

    Returns:
        A tuple ``(class3_label, class2_label)``:
        * class3_label is 0 when the negative probability exceeds the
          threshold, 2 when the positive probability exceeds it, and 1
          (the in-between band) otherwise.
        * class2_label is 2 when ``pos_prob >= neg_prob`` and 0 otherwise.
    """
    neg_prob = 1 - pos_prob

    # 3-class decision: only a sufficiently confident side wins, anything
    # inside the band around 0.5 falls back to the middle label.
    if neg_prob > neu_threshold:
        class3_label = 0
    elif pos_prob > neu_threshold:
        class3_label = 2
    else:
        class3_label = 1

    # 2-class decision: plain arg-max, ties go to the positive label.
    class2_label = 2 if pos_prob >= neg_prob else 0

    return class3_label, class2_label
def to_lodtensor(data, place):
    """
    Convert a batch of sequences into a LoDTensor.

    The sequences in ``data`` are concatenated along axis 0, cast to int64
    and reshaped into a single column. The tensor's LoD holds the cumulative
    sequence offsets ``[0, len(seq0), len(seq0) + len(seq1), ...]``.
    ``place`` is handed unchanged to ``LoDTensor.set`` (presumably the
    device the tensor should live on -- confirm against the caller).
    """
    # Running offsets: each entry is where the next sequence starts.
    offsets = [0]
    for seq in data:
        offsets.append(offsets[-1] + len(seq))

    column = np.concatenate(data, axis=0).astype("int64")
    column = column.reshape([len(column), 1])

    tensor = fluid.LoDTensor()
    tensor.set(column, place)
    tensor.set_lod([offsets])
    return tensor
def data2tensor(data, place):
    """
    Build the feed dict for a batch.

    Takes the first element of every item in ``data``, packs those
    sequences into one LoDTensor via ``to_lodtensor`` and returns it under
    the key "words".
    """
    sequences = [sample[0] for sample in data]
    return {"words": to_lodtensor(sequences, place)}
def data_reader(file_path, word_dict, is_shuffle=True):
    """
    Load a labelled corpus and return a reader over (word ids, label).

    Every non-blank line of ``file_path`` (read as UTF-8) is expected to hold
    an integer label, a TAB, and the text as words separated by single
    spaces. Blank or whitespace-only lines (for instance a trailing empty
    line at the end of the file) are skipped instead of aborting the load.

    Args:
        file_path: path of the corpus file.
        word_dict: mapping from word to integer id.
        is_shuffle: when true, the loaded samples are shuffled once, right
            after loading; every call of the returned reader then yields
            them in that same order.

    Returns:
        A zero-argument callable; calling it gives a generator of
        ``(word_ids, label)`` tuples.

    Raises:
        ValueError: if the label column is not an integer.
        IndexError: if a non-blank line has no TAB-separated text column.
    """
    # Words missing from word_dict all share this id.
    # NOTE(review): a dict produced by load_vocab already contains "<unk>"
    # with id len(word_dict) - 1, so this is a different, extra id --
    # confirm the embedding table is sized to include it.
    unk_id = len(word_dict)

    all_data = []
    with io.open(file_path, "r", encoding='utf8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Nothing to parse; int('') would raise on an empty line.
                continue
            cols = line.split("\t")
            label = int(cols[0])
            wids = [word_dict.get(word, unk_id)
                    for word in cols[1].split(" ")]
            all_data.append((wids, label))

    if is_shuffle:
        random.shuffle(all_data)

    def reader():
        for doc, label in all_data:
            yield doc, label

    return reader
def load_vocab(file_path):
    """
    Read a vocabulary file (UTF-8, one token per line) into a dict.

    Each line is stripped of surrounding whitespace and the resulting token
    gets the next free id, in order of first appearance; repeated tokens keep
    the id of their first occurrence. After the file is read, the key
    "<unk>" is set to the current size of the dict.
    """
    token_to_id = {}
    with io.open(file_path, 'r', encoding='utf8') as vocab_file:
        for raw_line in vocab_file:
            token = raw_line.strip()
            # Ids are dense, so the next free id is always the dict size;
            # setdefault leaves already-seen tokens untouched.
            token_to_id.setdefault(token, len(token_to_id))
    token_to_id["<unk>"] = len(token_to_id)
    return token_to_id
def prepare_data(data_path, word_dict_path,
                 batch_size, mode):
    """
    Load the vocabulary and build a batched reader for the given mode.

    The vocabulary file must exist; the data file is checked for existence
    when ``mode`` is "train", "eval" or "infer". In "train" mode the data
    reader is created with shuffling enabled, in every other mode without.

    Returns:
        A tuple ``(word_dict, batched_reader)`` where ``batched_reader`` is
        the result of ``paddle.batch`` over the data reader.
    """
    is_train = mode == "train"

    assert os.path.exists(
        word_dict_path), "The given word dictionary dose not exist."
    if is_train:
        assert os.path.exists(
            data_path), "The given training data does not exist."
    elif mode in ("eval", "infer"):
        assert os.path.exists(
            data_path), "The given test data does not exist."

    word_dict = load_vocab(word_dict_path)

    # Only training data is shuffled; both paths batch the same way.
    samples = data_reader(data_path, word_dict, is_train)
    batched_reader = paddle.batch(samples, batch_size)
    return word_dict, batched_reader