
Commit

Refactored a large amount of code and improved the framework.
920232796 committed Oct 24, 2020
1 parent 5d932fa commit 1913fab
Showing 11 changed files with 266 additions and 15 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -51,6 +51,8 @@ PyTorch implementation of BERT for seq2seq tasks, using the unilm scheme. If you like it, you are welcome to
Thanks a lot for the support. In addition, the site has some articles introducing the unilm paper and how the special mask is implemented; go to http://www.blog.zhxing.online/#/ and search for "unilm".

### Update log
2020.10.24: Refactored a large amount of code and added an automatic summarization example for the THUCNews dataset. Training should now work well; previously the pretrained parameters could fail to load properly (the trimmed vocabulary was not applied to the pretrained embedding weights), which sometimes made results very poor.

2020.10.23: Adjusted some of the code structure, turned several per-example variables into globals, and simplified the beam-search code. Rhyming in the poetry-writing example is temporarily unsupported; it will be added back later.

2020.09.29: Added a training example for the Tianchi medical NER competition (医学ner_train.py); see the competition page for details: https://tianchi.aliyun.com/competition/entrance/531824/information
4 changes: 2 additions & 2 deletions bert_seq2seq/seq2seq_model.py
@@ -30,7 +30,7 @@ def __init__(self, word2ix, model_name="roberta"):
raise Exception("model_name_err")

self.hidden_dim = config.hidden_size
self.vocab_size = config.vocab_size
self.vocab_size = len(word2ix)


def compute_loss(self, predictions, labels, target_mask):
@@ -73,7 +73,7 @@ def forward(self, input_tensor, token_type_id, position_enc=None, labels=None, d
else :
return predictions

def generate(self, text, out_max_length=80, beam_size=1, device="cpu", is_poem=False, max_length=256):
def generate(self, text, out_max_length=40, beam_size=1, device="cpu", is_poem=False, max_length=256):
        # generate the corresponding result for a single sentence
        ## derive the maximum input length from the maximum output length; this is fine here, since inputs beyond the maximum are truncated
self.out_max_length = out_max_length
2 changes: 1 addition & 1 deletion bert_seq2seq/tokenizer.py
@@ -36,7 +36,7 @@ def load_chinese_base_vocab(vocab_path, simplfied=False, startswith=["[PAD]", "[
keep_tokens.append(word2idx[t])

print("精简后的词表大小为:" + str(len(keep_tokens)))
return new_token_dict
return new_token_dict, keep_tokens
else:
return word2idx
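With simplfied=True the loader now returns the trimmed vocabulary together with the indices of the kept tokens. A minimal sketch of the intended call pattern (paths are placeholders, mirroring the examples updated in this commit):

```python
from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params

# Trimmed word2idx plus the positions of the surviving tokens in the original vocab.
word2idx, keep_tokens = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt", simplfied=True)
bert_model = load_bert(word2idx, model_name="roberta")
# keep_tokens lets load_model_params slice the pretrained embedding rows to match the trimmed vocab.
load_model_params(bert_model, "./state_dict/roberta_wwm_pytorch_model.bin", keep_tokens=keep_tokens)
```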

11 changes: 9 additions & 2 deletions bert_seq2seq/utils.py
@@ -39,11 +39,18 @@ def load_bert(word2ix, model_name="roberta", model_class="seq2seq", target_size=
else :
raise Exception("model_name_err")

def load_model_params(model, pretrain_model_path):
def load_model_params(model, pretrain_model_path, keep_tokens=None):

checkpoint = torch.load(pretrain_model_path)
    # when training starts from scratch, the pretrained BERT weights need to be loaded
checkpoint = {k[5:]: v for k, v in checkpoint.items()
if keep_tokens is not None:
        ## the vocabulary was trimmed, so the embedding layer has to be filtered as well
embedding_weight_name = "bert.embeddings.word_embeddings.weight"

checkpoint[embedding_weight_name] = checkpoint[embedding_weight_name][keep_tokens]

# checkpoint = {k[5:]: v for k, v in checkpoint.items()
checkpoint = {k: v for k, v in checkpoint.items()
if k[:4] == "bert" and "pooler" not in k}
model.load_state_dict(checkpoint, strict=False)
torch.cuda.empty_cache()
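The keep_tokens filter simply selects rows of the pretrained word-embedding matrix so that its first dimension matches the trimmed vocabulary. A small self-contained sketch of that indexing, with toy sizes rather than the real checkpoint:

```python
import torch

# Toy stand-in for a pretrained embedding matrix: 10 tokens, hidden size 4.
pretrained_embeddings = torch.arange(40, dtype=torch.float32).reshape(10, 4)
# Indices of the vocabulary entries that survived trimming, in trimmed-vocab order.
keep_tokens = [0, 1, 2, 5, 9]
# Row indexing keeps only those embeddings; the result has len(keep_tokens) rows.
trimmed = pretrained_embeddings[keep_tokens]
print(trimmed.shape)  # torch.Size([5, 4])
```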
212 changes: 212 additions & 0 deletions examples/THUCNews自动摘要.py
@@ -0,0 +1,212 @@
## THUCNews raw dataset
import torch
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import numpy as np
import os
import json
import time
import glob
import bert_seq2seq
from torch.utils.data import Dataset, DataLoader
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params, load_recent_model

vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # location of the roberta vocab file
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
model_name = "roberta"  # choose the model name
model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # location of the pretrained model
recent_model_path = "./state_dict/bert_auto_title_model.bin"  # used to resume training from an already fine-tuned model
model_save_path = "./state_dict/bert_auto_title_model.bin"
batch_size = 16
lr = 1e-5
maxlen = 256

class BertDataset(Dataset):
"""
针对特定数据集,定义一个相关的取数据的方式
"""
def __init__(self) :
## 一般init函数是加载所有数据
super(BertDataset, self).__init__()
## 拿到所有文件名字
self.txts = glob.glob('./state_dict/THUCNews/*/*.txt')

self.idx2word = {k: v for v, k in word2idx.items()}
self.tokenizer = Tokenizer(word2idx)

def __getitem__(self, i):
        ## fetch a single sample
# print(i)
text_name = self.txts[i]
with open(text_name, "r", encoding="utf-8") as f:
text = f.read()
text = text.split('\n')
if len(text) > 1:
title = text[0]
content = '\n'.join(text[1:])
token_ids, token_type_ids = self.tokenizer.encode(
content, title, max_length=maxlen
)
output = {
"token_ids": token_ids,
"token_type_ids": token_type_ids,
}
return output

        return self.__getitem__(i + 1)  # fall back to the next file when this one has no title/content split

def __len__(self):

return len(self.txts)

def collate_fn(batch):
"""
    Dynamic padding; batch is a list of samples.
"""

def padding(indice, max_length, pad_idx=0):
"""
        padding helper
"""
pad_indice = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
return torch.tensor(pad_indice)

token_ids = [data["token_ids"] for data in batch]
max_length = max([len(t) for t in token_ids])
token_type_ids = [data["token_type_ids"] for data in batch]

token_ids_padded = padding(token_ids, max_length)
token_type_ids_padded = padding(token_type_ids, max_length)
target_ids_padded = token_ids_padded[:, 1:].contiguous()

return token_ids_padded, token_type_ids_padded, target_ids_padded

class Trainer:
def __init__(self):
        # check whether a GPU is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("device: " + str(self.device))
        # build the model
        self.bert_model = load_bert(word2idx, model_name=model_name)
        ## load the pretrained model parameters

        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # load a previously fine-tuned model and continue training
        # load_recent_model(self.bert_model, recent_model_path)

        # move the model to the compute device (GPU or CPU)
        self.bert_model.to(self.device)
        # declare the parameters to optimize
        self.optim_parameters = list(self.bert_model.parameters())
        self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
        # declare the custom data loader
dataset = BertDataset()
self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

def train(self, epoch):
        # train for one epoch
        self.bert_model.train()
        self.iteration(epoch, dataloader=self.dataloader, train=True)

    def save(self, save_path):
        """
        Save the model.
        """
        torch.save(self.bert_model.state_dict(), save_path)
        print("{} saved!".format(save_path))

    def iteration(self, epoch, dataloader, train=True):
        total_loss = 0
        start_time = time.time()  ## record the start time
step = 0
report_loss = 0
for token_ids, token_type_ids, target_ids in tqdm(dataloader,position=0, leave=True):
step += 1
if step % 1000 == 0:
self.bert_model.eval()
test_data = ["夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及 时 就医 。",
"2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
"8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午 ,华 住集 团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。"]
for text in test_data:
print(self.bert_model.generate(text, beam_size=3,device=self.device))
print("loss is " + str(report_loss))
report_loss = 0
# self.eval(epoch)
self.bert_model.train()
if step % 8000 == 0:
self.save(model_save_path)

token_ids = token_ids.to(self.device)
token_type_ids = token_type_ids.to(self.device)
target_ids = target_ids.to(self.device)
            # because target labels are passed in, the loss is computed and returned
predictions, loss = self.bert_model(token_ids,
token_type_ids,
labels=target_ids,
device=self.device
)
report_loss += loss.item()
            # backpropagation
            if train:
                # clear previous gradients
                self.optimizer.zero_grad()
                # backpropagate to obtain fresh gradients
                loss.backward()
                # update the model parameters with those gradients
                self.optimizer.step()

            # accumulate for this epoch's average loss
total_loss += loss.item()

end_time = time.time()
spend_time = end_time - start_time
        # print training info
print("epoch is " + str(epoch)+". loss is " + str(total_loss) + ". spend time is "+ str(spend_time))
        # save the model
self.save(model_save_path)

if __name__ == '__main__':

# src, tgt = read_file("./data/train.src", "./data/train.tgt")

trainer = Trainer()
train_epoches = 20

for epoch in range(train_epoches):
        # train one epoch
trainer.train(epoch)

# test the custom dataset
# vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # location of the roberta vocab file
# # sents_src, sents_tgt = read_file("./corpus/auto_title/train.src", "./corpus/auto_title/train.tgt")
# sents_src= torch.load("./corpus/auto_title/train_clean.src")
# sents_tgt = torch.load("./corpus/auto_title/train_clean.tgt")
# import time
# dataset = BertDataset(sents_src, sents_tgt, vocab_path)
# word2idx = load_chinese_base_vocab(vocab_path)
# tokenier = Tokenizer(word2idx)
# dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
# for token_ids, token_type_ids, target_ids in dataloader:
# # print(token_ids.shape)
# print(tokenier.decode(token_ids[0].tolist()))
# print(tokenier.decode(token_ids[1].tolist()))
# # print(token_type_ids)
# # print(target_ids.shape)
# # print(tokenier.decode(target_ids[0].tolist()))
# # print(tokenier.decode(target_ids[1].tolist()))
# break


# src, tgt = read_file("./corpus/auto_title/train.src", "./corpus/auto_title/train.tgt")
# save_src, save_tgt = [], []
# for src_i, tgt_i in zip(src, tgt):
# src_i = src_i.replace("“", "").replace("”", "").replace("——", "-").replace("—", "-")
# tgt_i = tgt_i.replace("“", "").replace("”", "").replace("——", "-").replace("—", "-")

# save_src.append(src_i)
# save_tgt.append(tgt_i)

# torch.save(save_src, "./corpus/auto_title/train_clean.src")
# torch.save(save_tgt, "./corpus/auto_title/train_clean.tgt")
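For reference, a minimal inference sketch for the auto-title model this example trains. This is not part of the commit; the checkpoint path and beam size are assumptions taken from the training code above.

```python
import torch
from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.utils import load_bert

# Rebuild the trimmed vocabulary exactly as during training.
word2idx, keep_tokens = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt", simplfied=True)
bert_model = load_bert(word2idx, model_name="roberta")
# Load the fine-tuned checkpoint saved by Trainer.save().
bert_model.load_state_dict(torch.load("./state_dict/bert_auto_title_model.bin", map_location="cpu"))
bert_model.eval()

with torch.no_grad():
    # generate() now defaults to out_max_length=40; beam_size=3 matches the in-training samples.
    title = bert_model.generate("新闻正文文本", beam_size=3, device="cpu")  # placeholder news text
print(title)
```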
4 changes: 2 additions & 2 deletions examples/auto_title.py
@@ -17,7 +17,7 @@


vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # location of the roberta vocab file
word2idx = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
model_name = "roberta"  # choose the model name
model_path = "./state_dict/roberta_wwm_pytorch_model.bin"  # location of the pretrained model
recent_model_path = "./state_dict/bert_auto_title_model.bin"  # used to resume training from an already fine-tuned model
@@ -110,7 +110,7 @@ def __init__(self):
self.bert_model = load_bert(word2idx, model_name=model_name)
        ## load the pretrained model parameters

        load_model_params(self.bert_model, model_path)
        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # load a previously fine-tuned model and continue training
# load_recent_model(self.bert_model, self.recent_model_path)

4 changes: 2 additions & 2 deletions examples/三元组抽取_train.py
@@ -18,7 +18,7 @@
batch_size = 16
lr = 1e-5

word2idx = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)

def load_data(filename):
D = []
@@ -162,7 +162,7 @@ def __init__(self):
        # build the model
        self.bert_model = load_bert(word2idx, model_name=model_name, model_class="relation_extrac", target_size=len(predicate2id))
        ## load the pretrained model parameters
        load_model_params(self.bert_model, model_path)
        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # move the model to the compute device (GPU or CPU)
        self.bert_model.to(self.device)
        # declare the parameters to optimize
4 changes: 2 additions & 2 deletions examples/写诗_train.py
@@ -22,7 +22,7 @@
batch_size = 16
lr = 1e-5

word2idx = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)

def read_corpus(dir_path):
"""
@@ -140,7 +140,7 @@ def __init__(self):
        # build the model
        self.bert_model = load_bert(word2idx, model_name=model_name)
        ## load the pretrained model parameters
        load_model_params(self.bert_model, model_path)
        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # move the model to the compute device (GPU or CPU)
        self.bert_model.to(self.device)
        # declare the parameters to optimize
4 changes: 2 additions & 2 deletions examples/医学ner_train.py
@@ -36,7 +36,7 @@
lr = 1e-5
crf_lr = 1e-2  ## learning rate of the CRF layer is 0.01
# load the vocabulary
word2idx = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)


def from_ann2dic(w_path):
@@ -280,7 +280,7 @@ def __init__(self):
        # build the model
        self.bert_model = load_bert(word2idx, model_name=model_name, model_class="sequence_labeling_crf", target_size=len(target))
        ## load the pretrained model parameters
        load_model_params(self.bert_model, model_path)
        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # move the model to the compute device (GPU or CPU)
        self.bert_model.to(self.device)
        # declare the parameters to optimize
4 changes: 2 additions & 2 deletions examples/诗词对联_train.py
@@ -22,7 +22,7 @@
model_save_path = "./bert_model_poem_ci_duilian.bin"
batch_size = 8
lr = 1e-5
word2idx = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)

def read_corpus(dir_path):
"""
@@ -282,7 +282,7 @@ def __init__(self):
        # build the model
        self.bert_model = load_bert(word2idx, model_name=model_name)
        ## load the pretrained model parameters
        load_model_params(self.bert_model, model_path)
        load_model_params(self.bert_model, model_path, keep_tokens=keep_tokens)
        # move the model to the compute device (GPU or CPU)
        self.bert_model.to(self.device)
        # declare the parameters to optimize
30 changes: 30 additions & 0 deletions test/test.py
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
import sys
sys.path.append("/Users/xingzhaohu/Downloads/code/python/ml/ml_code/bert/bert_seq2seq")
from torch.optim import Adam
import pandas as pd
import numpy as np
import os
import json
import time
import bert_seq2seq
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params, load_recent_model

auto_title_model = "./state_dict/bert_model_poem.bin"

if __name__ == "__main__":
    vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # location of the roberta vocab file
    model_name = "roberta"  # choose the model name
    # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin"  # roberta model location
    # load the vocabulary
    word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
    # build the model
bert_model = load_bert(word2idx, model_name=model_name)
load_model_params(bert_model, "./state_dict/roberta_wwm_pytorch_model.bin", keep_tokens=keep_tokens)

for name, params in bert_model.named_parameters():
print(name)

