Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#128 from lcy-seso/clean_rnn_lm_codes
Browse files Browse the repository at this point in the history
refactor codes of the language model example.
  • Loading branch information
lcy-seso authored Jun 28, 2017
2 parents 08ab956 + 7d3a8cd commit c075ae2
Show file tree
Hide file tree
Showing 23 changed files with 998 additions and 940 deletions.
3 changes: 3 additions & 0 deletions generate_sequence_by_rnn_lm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.pyc
*.tar.gz
models
162 changes: 162 additions & 0 deletions generate_sequence_by_rnn_lm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# 使用循环神经网语言模型生成文本

语言模型(Language Model)是一个概率分布模型,简单来说,就是用来计算一个句子的概率的模型。利用它可以确定哪个词序列的可能性更大,或者给定若干个词,可以预测下一个最可能出现的词。语言模型是自然语言处理领域里一个重要的基础模型。

## 应用场景
**语言模型被应用在很多领域**,如:

* **自动写作**:语言模型可以根据上文生成下一个词,递归下去可以生成整个句子、段落、篇章。
* **QA**:语言模型可以根据Question生成Answer。
* **机器翻译**:当前主流的机器翻译模型大多基于Encoder-Decoder模式,其中Decoder就是一个待条件的语言模型,用来生成目标语言。
* **拼写检查**:语言模型可以计算出词序列的概率,一般在拼写错误处序列的概率会骤减,可以用来识别拼写错误并提供改正候选集。
* **词性标注、句法分析、语音识别......**

## 关于本例
本例实现基于RNN的语言模型,以及利用语言模型生成文本,本例的目录结构如下:

```text
.
├── data
│ └── train_data_examples.txt # 示例数据,可参考示例数据的格式,提供自己的数据
├── config.py # 配置文件,包括data、train、infer相关配置
├── generate.py # 预测任务脚本,即生成文本
├── beam_search.py # beam search 算法实现
├── network_conf.py # 本例中涉及的各种网络结构均定义在此文件中,希望进一步修改模型结构,请修改此文件
├── reader.py # 读取数据接口
├── README.md
├── train.py # 训练任务脚本
└── utils.py # 定义通用的函数,例如:构建字典、加载字典等
```

## RNN 语言模型
### 简介

RNN是一个序列模型,基本思路是:在时刻$t$,将前一时刻$t-1$的隐藏层输出和$t$时刻的词向量一起输入到隐藏层从而得到时刻$t$的特征表示,然后用这个特征表示得到$t$时刻的预测输出,如此在时间维上递归下去。可以看出RNN善于使用上文信息、历史知识,具有“记忆”功能。理论上RNN能实现“长依赖”(即利用很久之前的知识),但在实际应用中发现效果并不理想,研究提出了LSTM和GRU等变种,通过引入门机制对传统RNN的记忆单元进行了改进,弥补了传统RNN在学习长序列时遇到的难题。本例模型使用了LSTM或GRU,可通过配置进行修改。下图是RNN(广义上包含了LSTM、GRU等)语言模型“循环”思想的示意图:

<p align=center><img src='images/rnn.png' width='500px'/></p>

### 模型实现

本例中RNN语言模型的实现简介如下:

- **定义模型参数**`config.py`中定义了模型的参数变量。
- **定义模型结构**`network_conf.py`中的`rnn_lm`**函数**中定义了模型的**结构**,如下:
- 输入层:将输入的词(或字)序列映射成向量,即词向量层: `embedding`
- 中间层:根据配置实现RNN层,将上一步得到的`embedding`向量序列作为输入。
- 输出层:使用`softmax`归一化计算单词的概率。
- loss:定义多类交叉熵作为模型的损失函数。
- **训练模型**`train.py`中的`main`方法实现了模型的训练,实现流程如下:
- 准备输入数据:建立并保存词典、构建train和test数据的reader。
- 初始化模型:包括模型的结构、参数。
- 构建训练器:demo中使用的是Adam优化算法。
- 定义回调函数:构建`event_handler`来跟踪训练过程中loss的变化,并在每轮训练结束时保存模型的参数。
- 训练:使用trainer训练模型。

- **生成文本**`generate.py` 实现了文本的生成,实现流程如下:
- 加载训练好的模型和词典文件。
- 读取`gen_file`文件,每行是一个句子的前缀,用[柱搜索算法(Beam Search)](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.cn.md#柱搜索算法)根据前缀生成文本。
- 将生成的文本及其前缀保存到文件`gen_result`

## 使用说明

运行本例的方法如下:

* 1,运行`python train.py`命令,开始train模型(默认使用RNN),待训练结束。
* 2,运行`python generate.py`运行文本生成。(输入的文本默认为`data/train_data_examples.txt`,生成的文本默认保存到`data/gen_result.txt`中。)


**如果需要使用自己的语料、定制模型,需要修改`config.py`中的配置,细节和适配工作详情如下:**


### 语料适配

* 清洗语料:去除原文中空格、tab、乱码,按需去除数字、标点符号、特殊符号等。
* 内容格式:每个句子占一行;每行中的各词之间使用一个空格符分开。
* 按需要配置`config.py`中的如下参数:

```python
train_file = "data/train_data_examples.txt"
test_file = ""

vocab_file = "data/word_vocab.txt"
model_save_dir = "models"
```
1. `train_file`:指定训练数据的路径,**需要预先分词**
2. `test_file`:指定测试数据的路径,如果训练数据不为空,将在每个 `pass` 训练结束对指定的测试数据进行测试。
3. `vocab_file`:指定字典的路径,如果字典文件不存在,将会对训练语料进行词频统计,构建字典。
4. `model_save_dir`:指定模型保存的路径,如果指定的文件夹不存在,将会自动创建。

### 构建字典的策略
- 当指定的字典文件不存在时,将对训练数据进行词频统计,自动构建字典`config.py` 中有如下两个参数与构建字典有关:

```python
max_word_num = 51200 - 2
cutoff_word_fre = 0
```
1. `max_word_num`:指定字典中含有多少个词。
2. `cutoff_word_fre`:字典中词语在训练语料中出现的最低频率。
- 加入指定了 `max_word_num = 5000`,并且 `cutoff_word_fre = 10`,词频统计发现训练语料中出现频率高于10次的词语仅有3000个,那么最终会取3000个词构成词典。
- 构建词典时,会自动加入两个特殊符号:
1. `<unk>`:不出现在字典中的词
2. `<e>`:句子的结束符

*注:需要注意的是,词典越大生成的内容越丰富,但训练耗时越久。一般中文分词之后,语料中不同的词能有几万乃至几十万,如果`max_word_num`取值过小则导致`<unk>`占比过高,如果`max_word_num`取值较大,则严重影响训练速度(对精度也有影响)。所以,也有“按字”训练模型的方式,即:把每个汉字当做一个词,常用汉字也就几千个,使得字典的大小不会太大、不会丢失太多信息,但汉语中同一个字在不同词中语义相差很大,有时导致模型效果不理想。建议多试试、根据实际情况选择是“按词训练”还是“按字训练”。*

### 模型适配、训练

* 按需调整`config.py`中如下配置,来修改 rnn 语言模型的网络结果:

```python
rnn_type = "lstm" # "gru" or "lstm"
emb_dim = 256
hidden_size = 256
stacked_rnn_num = 2
```
1. `rnn_type`:支持 ”gru“ 或者 ”lstm“ 两种参数,选择使用何种 RNN 单元。
2. `emb_dim`:设置词向量的维度。
3. `hidden_size`:设置 RNN 单元隐层大小。
4. `stacked_rnn_num`:设置堆叠 RNN 单元的个数,构成一个更深的模型。

* 运行`python train.py`命令训练模型,模型将被保存到`model_save_dir`指定的目录。

### 按需生成文本

* 按需调整`config.py`中以下变量,详解如下:

```python
gen_file = "data/train_data_examples.txt"
gen_result = "data/gen_result.txt"
max_gen_len = 25 # the max number of words to generate
beam_size = 5
model_path = "models/rnn_lm_pass_00000.tar.gz"
```
1. `gen_file`:指定输入数据文件,每行是一个句子的前缀,**需要预先分词**
2. `gen_result`:指定输出文件路径,生成结果将写入此文件。
3. `max_gen_len`:指定每一句生成的话最长长度,如果模型无法生成出`<e>`,当生成 `max_gen_len` 个词语后,生成过程会自动终止。
4. `beam_size`:Beam Search 算法每一步的展开宽度。
5. `model_path`:指定训练好的模型的路径。

其中,`gen_file` 中保存的是待生成的文本前缀,每个前缀占一行,形如:

```text
若隐若现 地像 幽灵 , 像 死神
```
将需要生成的文本前缀按此格式存入文件即可;

* 运行`python generate.py`命令运行beam search 算法为输入前缀生成文本,下面是模型生成的结果:

```text
81 若隐若现 地像 幽灵 , 像 死神
-12.2542 一样 。 他 是 个 怪物 <e>
-12.6889 一样 。 他 是 个 英雄 <e>
-13.9877 一样 。 他 是 我 的 敌人 <e>
-14.2741 一样 。 他 是 我 的 <e>
-14.6250 一样 。 他 是 我 的 朋友 <e>
```
其中:
1. 第一行 `81 若隐若现 地像 幽灵 , 像 死神``\t`为分隔,共有两列:
- 第一列是输入前缀在训练样本集中的序号。
- 第二列是输入的前缀。
2. 第二 ~ `beam_size + 1` 行是生成结果,同样以 `\t` 分隔为两列:
- 第一列是该生成序列的对数概率(log probability)。
- 第二列是生成的文本序列,正常的生成结果会以符号`<e>`结尾,如果没有以`<e>`结尾,意味着超过了最大序列长度,生成强制终止。
174 changes: 174 additions & 0 deletions generate_sequence_by_rnn_lm/beam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#!/usr/bin/env python
# coding=utf-8
import os
import math
import numpy as np

import paddle.v2 as paddle

from utils import logger, load_reverse_dict

__all__ = ["BeamSearch"]


class BeamSearch(object):
"""
Generating sequence by beam search
NOTE: this class only implements generating one sentence at a time.
"""

def __init__(self, inferer, word_dict_file, beam_size=1, max_gen_len=100):
"""
constructor method.
:param inferer: object of paddle.Inference that represents the entire
network to forward compute the test batch
:type inferer: paddle.Inference
:param word_dict_file: path of word dictionary file
:type word_dict_file: str
:param beam_size: expansion width in each iteration
:type param beam_size: int
:param max_gen_len: the maximum number of iterations
:type max_gen_len: int
"""
self.inferer = inferer
self.beam_size = beam_size
self.max_gen_len = max_gen_len
self.ids_2_word = load_reverse_dict(word_dict_file)
logger.info("dictionay len = %d" % (len(self.ids_2_word)))

try:
self.eos_id = next(x[0] for x in self.ids_2_word.iteritems()
if x[1] == "<e>")
self.unk_id = next(x[0] for x in self.ids_2_word.iteritems()
if x[1] == "<unk>")
except StopIteration:
logger.fatal(("the word dictionay must contain an ending mark "
"in the text generation task."))

self.candidate_paths = []
self.final_paths = []

def _top_k(self, softmax_out, k):
"""
get indices of the words with k highest probablities.
NOTE: <unk> will be excluded if it is among the top k words, then word
with (k + 1)th highest probability will be returned.
:param softmax_out: probablity over the dictionary
:type softmax_out: narray
:param k: number of word indices to return
:type k: int
:return: indices of k words with highest probablities.
:rtype: list
"""
ids = softmax_out.argsort()[::-1]
return ids[ids != self.unk_id][:k]

def _forward_batch(self, batch):
"""
forward a test batch.
:params batch: the input data batch
:type batch: list
:return: probablities of the predicted word
:rtype: ndarray
"""
return self.inferer.infer(input=batch, field=["value"])

def _beam_expand(self, next_word_prob):
"""
In every iteration step, the model predicts the possible next words.
For each input sentence, the top k words is added to end of the original
sentence to form a new generated sentence.
:param next_word_prob: probablities of the next words
:type next_word_prob: ndarray
:return: the expanded new sentences.
:rtype: list
"""
assert len(next_word_prob) == len(self.candidate_paths), (
"Wrong forward computing results!")
top_beam_words = np.apply_along_axis(self._top_k, 1, next_word_prob,
self.beam_size)
new_paths = []
for i, words in enumerate(top_beam_words):
old_path = self.candidate_paths[i]
for w in words:
log_prob = old_path["log_prob"] + math.log(next_word_prob[i][w])
gen_ids = old_path["ids"] + [w]
if w == self.eos_id:
self.final_paths.append({
"log_prob": log_prob,
"ids": gen_ids
})
else:
new_paths.append({"log_prob": log_prob, "ids": gen_ids})
return new_paths

def _beam_shrink(self, new_paths):
"""
to return the top beam_size generated sequences with the highest
probabilities at the end of evey generation iteration.
:param new_paths: all possible generated sentences
:type new_paths: list
:return: a state flag to indicate whether to stop beam search
:rtype: bool
"""

if len(self.final_paths) >= self.beam_size:
max_candidate_log_prob = max(
new_paths, key=lambda x: x["log_prob"])["log_prob"]
min_complete_path_log_prob = min(
self.final_paths, key=lambda x: x["log_prob"])["log_prob"]
if min_complete_path_log_prob >= max_candidate_log_prob:
return True

new_paths.sort(key=lambda x: x["log_prob"], reverse=True)
self.candidate_paths = new_paths[:self.beam_size]
return False

def gen_a_sentence(self, input_sentence):
"""
generating sequence for an given input
:param input_sentence: one input_sentence
:type input_sentence: list
:return: the generated word sequences
:rtype: list
"""
self.candidate_paths = [{"log_prob": 0., "ids": input_sentence}]
input_len = len(input_sentence)

for i in range(self.max_gen_len):
next_word_prob = self._forward_batch(
[[x["ids"]] for x in self.candidate_paths])
new_paths = self._beam_expand(next_word_prob)

min_candidate_log_prob = min(
new_paths, key=lambda x: x["log_prob"])["log_prob"]

path_to_remove = [
path for path in self.final_paths
if path["log_prob"] < min_candidate_log_prob
]
for p in path_to_remove:
self.final_paths.remove(p)

if self._beam_shrink(new_paths):
self.candidate_paths = []
break

gen_ids = sorted(
self.final_paths + self.candidate_paths,
key=lambda x: x["log_prob"],
reverse=True)[:self.beam_size]
self.final_paths = []

def _to_str(x):
text = " ".join(self.ids_2_word[idx]
for idx in x["ids"][input_len:])
return "%.4f\t%s" % (x["log_prob"], text)

return map(_to_str, gen_ids)
46 changes: 46 additions & 0 deletions generate_sequence_by_rnn_lm/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# coding=utf-8
import os

################## for building word dictionary ##################

max_word_num = 51200 - 2
cutoff_word_fre = 0

################## for training task #########################
# path of training data
train_file = "data/train_data_examples.txt"
# path of testing data, if testing file does not exist,
# testing will not be performed at the end of each training pass
test_file = ""
# path of word dictionary, if this file does not exist,
# word dictionary will be built from training data.
vocab_file = "data/word_vocab.txt"
# directory to save the trained model
# create a new directory if the directoy does not exist
model_save_dir = "models"

batch_size = 32 # the number of training examples in one forward/backward pass
num_passes = 20 # how many passes to train the model

log_period = 50
save_period_by_batches = 50

use_gpu = True # to use gpu or not
trainer_count = 1 # number of trainer

################## for model configuration ##################
rnn_type = "lstm" # "gru" or "lstm"
emb_dim = 256
hidden_size = 256
stacked_rnn_num = 2

################## for text generation ##################
gen_file = "data/train_data_examples.txt"
gen_result = "data/gen_result.txt"
max_gen_len = 25 # the max number of words to generate
beam_size = 5
model_path = "models/rnn_lm_pass_00000.tar.gz"

if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
5 changes: 5 additions & 0 deletions generate_sequence_by_rnn_lm/data/train_data_examples.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
我们 不会 伤害 你 的 。 他们 也 这么 说 。
你 拥有 你 父亲 皇室 的 血统 。 是 合法 的 继承人 。
叫 什么 你 可以 告诉 我 。
你 并 没有 留言 说 要 去 哪里 。 是 的 , 因为 我 必须 要 去 完成 这件 事 。
你 查出 是 谁 住 在 隔壁 房间 吗 ?
Loading

0 comments on commit c075ae2

Please sign in to comment.