-
Notifications
You must be signed in to change notification settings - Fork 7
/
datas_utilis.py
89 lines (79 loc) · 2.71 KB
/
datas_utilis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import copy
import pickle
from collections import Counter

import numpy as np
import tensorflow as tf
def batch_generator(arr, num_seqs, num_steps):
    '''
    Infinite generator of (x, y) training batches for a char-RNN.

    The input sequence is truncated to a multiple of num_seqs * num_steps,
    reshaped to (num_seqs, -1), and sliced along axis 1 into chunks of
    num_steps columns.  E.g. "12345678" with num_seqs=2 becomes
        1234
        5678
    and each yielded x has shape (num_seqs, num_steps).

    y is x shifted left by one position: each target is the next symbol of
    x, with the last column wrapping back to the chunk's first column
    (the standard char-RNN approximation at chunk boundaries).

    Rows (sequences) are reshuffled once per pass over the data.
    '''
    # np.array both copies the input -- so the in-place shuffle below never
    # mutates the caller's data -- and accepts plain Python sequences,
    # not only ndarrays.
    arr = np.array(arr)
    batch_size = num_seqs * num_steps
    # Floor division: drop the tail that doesn't fill a whole batch.
    n_batches = len(arr) // batch_size
    arr = arr[:batch_size * n_batches].reshape((num_seqs, -1))
    while True:
        # Shuffle along axis 0, i.e. reorder the sequences between passes.
        np.random.shuffle(arr)
        for n in range(0, arr.shape[1], num_steps):
            x = arr[:, n:n + num_steps]
            # Target = input shifted left by one, last column wraps.
            y = np.zeros_like(x)
            y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]
            yield x, y
#### 文本数据向量化
class TextConverter(object):
    """Maps text symbols (characters/words) to integer ids and back.

    The vocabulary is either built from ``text`` -- keeping the
    ``max_vocab`` most frequent symbols -- or unpickled from ``filename``.
    One extra id, ``len(vocab)``, is reserved for out-of-vocabulary
    symbols and renders back as ``'<unk>'``.
    """

    def __init__(self, text=None, max_vocab=5000, filename=None):
        """Build the vocab from ``text``, or load it from ``filename``.

        ``filename`` takes precedence when both are given.
        """
        if filename is not None:
            # NOTE(review): pickle.load on an untrusted file can execute
            # arbitrary code -- only load vocab files you produced yourself
            # (e.g. via save_to_file).
            with open(filename, 'rb') as f:
                self.vocab = pickle.load(f)
        else:
            # Keep the max_vocab most frequent symbols, highest count first.
            counts = Counter(text)
            self.vocab = [word for word, _ in counts.most_common(max_vocab)]
        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}
        self.int_to_word_table = dict(enumerate(self.vocab))

    @property
    def vocab_size(self):
        # +1 reserves one id for the out-of-vocabulary token.
        return len(self.vocab) + 1

    def word_to_int(self, word):
        """Return the id of ``word``; unknown symbols map to len(vocab)."""
        return self.word_to_int_table.get(word, len(self.vocab))

    def int_to_word(self, index):
        """Return the symbol for ``index``; the reserved id gives '<unk>'.

        Raises:
            Exception: if ``index`` exceeds the reserved unknown id.
        """
        index = int(index)
        if index == len(self.vocab):
            return '<unk>'
        elif index < len(self.vocab):
            return self.int_to_word_table[index]
        else:
            raise Exception('Unknown index!')

    def text_to_arr(self, text):
        """Vectorize ``text`` into a numpy array of integer ids."""
        return np.array([self.word_to_int(word) for word in text])

    def arr_to_text(self, arr):
        """Inverse of text_to_arr: join the symbols for the ids in ``arr``."""
        return "".join(self.int_to_word(index) for index in arr)

    def save_to_file(self, filename):
        """Pickle the vocabulary list to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump(self.vocab, f)