Skip to content

Commit

Permalink
thread remake
Browse files (browse the repository at this point in the history)
  • Loading branch information
ximinng committed May 21, 2019
1 parent aafcf7c commit 651b122
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ ENV/
# large file
*.pkl
*.conv
/raw_data/

# trained model
model/
Expand Down
6 changes: 3 additions & 3 deletions src/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import random
import numpy as np
from tensorflow.python.client import device_lib
from src.word_sequence import WordSequence
from word_sequence import WordSequence

# 处理词向量临界值
VOCAB_SIZE_THRESHOLD_GPU = 50000
Expand Down Expand Up @@ -216,15 +216,15 @@ def batch_flow_bucket(data, ws, batch_size, raw=False, add_end=True, n_bucket=5,


def test_batch_flow():
from src.fake_data import generate
from fake_data import generate
x_data, y_data, ws_input, ws_target = generate(size=10000)
flow = batch_flow([x_data, y_data], [ws_input, ws_target], 4)
x, xl, y, yl = next(flow)
print(x.shape, y.shape, xl.shape, yl.shape)


def test_batch_flow_bucket():
from src.fake_data import generate
from fake_data import generate
x_data, y_data, ws_input, ws_target = generate(size=10000)
flow = batch_flow_bucket([x_data, y_data], [ws_input, ws_target], batch_size=4, debug=True)
for _ in range(10):
Expand Down
8 changes: 4 additions & 4 deletions src/extract_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm


# 去掉非法字符
# 去掉非法字符,合并句子
def make_split(line):
if re.match(r'.*([,…?!\.,!?])$', ''.join(line)):
return []
Expand All @@ -35,11 +35,11 @@ def regular(sen):


def main(limit=20, x_limit=3, y_limit=6):
from src.word_sequence import WordSequence
from word_sequence import WordSequence

# 解压文件
print('extract lines')
fp = open("dgk_shooter_min.conv", 'r', errors='ignore', encoding='utf-8')
fp = open("raw_data/dgk_shooter_min.conv", 'r', errors='ignore', encoding='utf-8')
# 保存全部句子列表
groups = []
# 保存一行
Expand All @@ -59,7 +59,7 @@ def main(limit=20, x_limit=3, y_limit=6):
line = line[:-1]

group.append(list(regular(''.join(line))))
# E开头句子
# E开头句子---line.startswith('E ')
else:
if group:
groups.append(group)
Expand Down
2 changes: 1 addition & 1 deletion src/fake_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import random
import numpy as np
from src.word_sequence import WordSequence
from word_sequence import WordSequence


def generate(max_len=10, size=1000, same_len=False, seed=0):
Expand Down
4 changes: 2 additions & 2 deletions src/seq_to_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tensorflow.contrib.rnn import ResidualWrapper
# from tensorflow.contrib.rnn import LSTMStateTuple

from src.word_sequence import WordSequence
from src.data_utils import _get_embed_device
from word_sequence import WordSequence
from data_utils import _get_embed_device


class SequenceToSequence(object):
Expand Down
4 changes: 2 additions & 2 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


def test(params):
from src.seq_to_seq import SequenceToSequence
from src.data_utils import batch_flow
from seq_to_seq import SequenceToSequence
from data_utils import batch_flow

x_data, _ = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))
Expand Down
10 changes: 8 additions & 2 deletions src/thread_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# -*- coding: utf-8 -*-
"""
Description : 线程生成器
Author : xxm
"""
from threading import Thread
from queue import Queue


class ThreadedGenerator(object):

def __init__(self, iterator,
Expand Down Expand Up @@ -56,8 +62,8 @@ def __next__(self):
raise StopIteration()
return value

def test():

def test():
def gene():
i = 0
while True:
Expand All @@ -72,6 +78,6 @@ def gene():

test.close()


if __name__ == '__main__':
test()

14 changes: 7 additions & 7 deletions src/train_anti.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from tqdm import tqdm


def test(params):
from src.seq_to_seq import SequenceToSequence
from src.data_utils import batch_flow_bucket as batch_flow
from src.word_sequence import WordSequence
from src.thread_generator import ThreadedGenerator
def train(params):
from seq_to_seq import SequenceToSequence
from data_utils import batch_flow_bucket as batch_flow
from word_sequence import WordSequence
from thread_generator import ThreadedGenerator

# 加载数据
x_data, y_data = pickle.load(open('chatbot.pkl', 'rb'))
Expand Down Expand Up @@ -89,7 +89,6 @@ def test(params):
model.save(sess, save_path)
flow.close()

# 训练2
tf.reset_default_graph()
model_pred = SequenceToSequence(
input_vocab_size=len(ws),
Expand Down Expand Up @@ -121,6 +120,7 @@ def test(params):
if t >= 3:
break

# 训练2
tf.reset_default_graph()
model_pred = SequenceToSequence(
input_vocab_size=len(ws),
Expand Down Expand Up @@ -155,7 +155,7 @@ def test(params):

def main():
import json
test(json.load(open('params.json')))
train(json.load(open('params.json')))


if __name__ == '__main__':
Expand Down

0 comments on commit 651b122

Please sign in to comment.