Commit
Optimized some code and added a few test examples.
920232796 committed Dec 2, 2020
1 parent 808090e commit a506242
Showing 13 changed files with 496 additions and 65 deletions.
Binary file modified .DS_Store (binary file not shown)
3 changes: 3 additions & 0 deletions README.md
@@ -58,6 +58,9 @@ PyTorch implementation of BERT for seq2seq tasks using the UniLM scheme. If you like it, you are welcome to
Thanks for the support. The site also has some articles introducing the UniLM paper and how the special mask is implemented; go to http://www.blog.zhxing.online/#/ and search for "unilm".
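For context on the UniLM scheme mentioned above: seq2seq with a single BERT is obtained purely through the attention mask, where source tokens attend bidirectionally within the source, while each target token attends to the whole source plus the target positions up to itself. Below is a minimal sketch of such a mask built from `token_type_ids` (0 = source segment, 1 = target segment); it illustrates the idea and is not necessarily the exact mask code used in `seq2seq_model.py`.

```python
import torch

def unilm_attention_mask(token_type_ids):
    """Build a UniLM-style seq2seq attention mask.

    token_type_ids: (batch, seq_len) tensor, 0 for the source segment,
    1 for the target segment.
    Returns a (batch, seq_len, seq_len) float mask where mask[b, q, k] == 1
    means query position q may attend to key position k.
    """
    # cumulative count of target tokens seen so far at each position
    idxs = torch.cumsum(token_type_ids, dim=1)
    # source queries (cumsum 0) see only source keys; a target query sees the
    # whole source plus target keys up to and including its own position
    mask = (idxs.unsqueeze(1) <= idxs.unsqueeze(2)).float()
    return mask

# toy check: 3 source tokens followed by 2 target tokens
tti = torch.tensor([[0, 0, 0, 1, 1]])
print(unilm_attention_mask(tti)[0])
```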

### Update log

2020.12.02: Adjusted some code and added several test files, which make it easy to load an already-trained model and test it on the corresponding task.

2020.11.20: Added an example; triple extraction currently reaches an F1 of 0.7. Added test code for news-summary text classification.

2020.11.04: Ran the BERT-CRF example for a standard NER task; the results are good.
4 changes: 2 additions & 2 deletions bert_seq2seq/seq2seq_model.py
@@ -252,7 +252,7 @@ def beam_search(self, token_ids, token_type_ids, word2ix, beam_size=1, device="c
best_one = output_scores.argmax()
if end_counts[best_one] == 1:
# this means the best beam has terminated
return output_ids[best_one]
return output_ids[best_one][:-1]
else :
# keep only the unfinished beams
flag = (end_counts < 1) # mark unfinished sequences
@@ -376,7 +376,7 @@ def beam_search_poem(self, text, token_ids, token_type_ids, word2ix, beam_size=1
# this means the best beam has terminated
# print(repeat_word)
# print(yayun_chars)
return output_ids[best_one]
return output_ids[best_one][:-1]
else :
# keep only the unfinished beams
flag = (end_counts < 1) # mark unfinished sequences
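Both hunks above change what `beam_search` / `beam_search_poem` return once the best-scoring beam has emitted the end token: `output_ids[best_one][:-1]` now strips that trailing end id so it no longer leaks into the decoded text. A minimal sketch of that termination step, assuming per-beam tensors shaped like those in the surrounding code (the names and the toy end-token id 2 are only illustrative):

```python
import torch

def pick_finished_beam(output_ids, output_scores, end_counts):
    """Return the ids of the best beam with its end token stripped, or None.

    output_ids:    (num_beams, cur_len) generated token ids per beam
    output_scores: (num_beams,)         cumulative scores per beam
    end_counts:    (num_beams,)         number of end tokens emitted per beam
    """
    best_one = output_scores.argmax()
    if end_counts[best_one] >= 1:
        # the best beam is finished: drop the trailing end token,
        # mirroring the `[:-1]` introduced in this commit
        return output_ids[best_one][:-1]
    return None  # best beam still running; the caller keeps decoding

# toy example: beam 1 scores highest and has already emitted the end token (2)
ids = torch.tensor([[5, 6, 7], [8, 9, 2]])
scores = torch.tensor([-1.2, -0.4])
ends = torch.tensor([0, 1])
print(pick_finished_beam(ids, scores, ends))  # tensor([8, 9])
```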
137 changes: 100 additions & 37 deletions examples/粗粒度NER_CRF_train.py
@@ -5,6 +5,7 @@
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
import unicodedata
import pandas as pd
import numpy as np
import os
@@ -19,14 +20,71 @@
model_name = "roberta" # choose the model name
model_path = "./state_dict/roberta_wwm_pytorch_model.bin" # path to the roberta model weights
recent_model_path = "" # used to resume training from an already-trained model
model_save_path = "./bert_ner_model_crf.bin"
model_save_path = "./bert_粗粒度ner_crf.bin"
batch_size = 4
lr = 1e-5

word2idx = load_chinese_base_vocab(vocab_path)

target = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"]

def _is_punctuation(ch):
"""标点符号类字符判断(全/半角均在此内)
"""
code = ord(ch)
return 33 <= code <= 47 or \
58 <= code <= 64 or \
91 <= code <= 96 or \
123 <= code <= 126 or \
unicodedata.category(ch).startswith('P')

def _cjk_punctuation():
return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\xb7\uff01\uff1f\uff61\u3002'

def _is_cjk_character(ch):
"""CJK类字符判断(包括中文字符也在此列)
参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
"""
code = ord(ch)
return 0x4E00 <= code <= 0x9FFF or \
0x3400 <= code <= 0x4DBF or \
0x20000 <= code <= 0x2A6DF or \
0x2A700 <= code <= 0x2B73F or \
0x2B740 <= code <= 0x2B81F or \
0x2B820 <= code <= 0x2CEAF or \
0xF900 <= code <= 0xFAFF or \
0x2F800 <= code <= 0x2FA1F

def _is_control(ch):
"""Check whether a character is a control character.
"""
return unicodedata.category(ch) in ('Cc', 'Cf')

def word_piece_tokenize(word):
"""word内分成subword
"""
if word in word2idx:
return [word]

tokens = []
start, stop = 0, 0
while start < len(word):
stop = len(word)
while stop > start:
sub = word[start:stop]
if start > 0:
sub = '##' + sub
if sub in word2idx:
break
stop -= 1
if start == stop:
stop += 1
tokens.append(sub)
start = stop

return tokens
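The `word_piece_tokenize` helper above does greedy longest-prefix matching against `word2idx`, prefixing non-initial pieces with `##`. A self-contained toy version of the same loop, with a hypothetical `toy_vocab` standing in for the real vocabulary, just to show the splitting behaviour:

```python
# Illustrative only: a toy vocabulary standing in for the word2idx loaded above.
toy_vocab = {"bert": 0, "seq": 1, "##2": 2, "##seq": 3}

def toy_word_piece_tokenize(word, vocab):
    # same greedy longest-match-first loop as word_piece_tokenize above,
    # parameterised on the vocabulary for demonstration purposes
    if word in vocab:
        return [word]
    tokens, start = [], 0
    while start < len(word):
        stop = len(word)
        while stop > start:
            sub = word[start:stop]
            if start > 0:
                sub = '##' + sub
            if sub in vocab:
                break
            stop -= 1
        if start == stop:  # no prefix matched: fall back to a single character
            stop += 1
        tokens.append(sub)
        start = stop
    return tokens

print(toy_word_piece_tokenize("seq2seq", toy_vocab))  # ['seq', '##2', '##seq']
```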

def read_corpus(data_path):
"""
Read the raw data
@@ -37,22 +95,23 @@ def read_corpus(data_path):
with open(data_path) as f:
lines = f.readlines()
row = ""
t = [0]
t = []
for line in lines:
if line == "\n":
t.append(0)
if len(row) < 500:

if len(row) < 300:
sents_src.append(row)
sents_tgt.append(t)
row = ""
t = [0]
t = []
continue
line = line.split(" ")
row = row + line[0]
t.append(target.index(line[1].strip("\n")))
t.append(line[1].strip("\n"))

return sents_src, sents_tgt
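`read_corpus` expects character-level, CoNLL-style data: one `character label` pair per line, labels drawn from the `target` list above, and a blank line between sentences. A small illustrative sample of that format (the file name and sentence are made up, and the returned values are sketched on the assumption that only the string labels are collected after this change):

```python
# Hypothetical sample in the format read_corpus parses: "char label" per line,
# blank line between sentences, labels drawn from the `target` list above.
sample = (
    "我 O\n"
    "在 O\n"
    "北 B-LOC\n"
    "京 I-LOC\n"
    "\n"
)
with open("toy_ner_sample.txt", "w", encoding="utf-8") as f:
    f.write(sample)

# read_corpus returns parallel lists: the joined sentence strings and, after
# this commit, the raw label strings (index conversion now happens later in
# NERDataset.__getitem__), roughly:
#   sents_src ~ ["我在北京"]
#   sents_tgt ~ [["O", "O", "B-LOC", "I-LOC"]]
```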


## custom Dataset
class NERDataset(Dataset):
"""
@@ -74,7 +133,12 @@ def __getitem__(self, i):
# print(i)
src = self.sents_src[i]
tgt = self.sents_tgt[i]
tgt = ["O"] + tgt + ["O"]
tgt = [target.index(i) for i in tgt ]
token_ids, token_type_ids = self.tokenizer.encode(src)
if len(token_ids) != len(tgt):
print("not equal")
os._exit(0)
output = {
"token_ids": token_ids,
"token_type_ids": token_type_ids,
@@ -176,7 +240,7 @@ def ner_print(model, test_data, device="cpu"):
class Trainer:
def __init__(self):
# load the data
data_path = "./corpus/粗粒度NER/example.train"
data_path = "./state_dict/corase_train_update.txt"
self.sents_src, self.sents_tgt = read_corpus(data_path)

self.tokenier = Tokenizer(word2idx)
@@ -216,7 +280,7 @@ def iteration(self, epoch, dataloader, train=True):
# print(target_ids.shape)
step += 1
if step % 500 == 0:
test_data = ["日寇在京掠夺文物详情。", "以书结缘,把欧美,港台流行的食品类食谱汇集一堂"]
test_data = ["日寇在京掠夺文物详情。", "以书结缘,把欧美,港台流行的食品类食谱汇集一堂。", "明天天津下雨,不知道主任还能不能来学校吃个饭。"]
ner_print(self.bert_model, test_data, device=self.device)
self.bert_model.train()

@@ -254,32 +318,31 @@ def iteration(self, epoch, dataloader, train=True):
# train for one epoch
trainer.train(epoch)

# test the custom dataset
# vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
# sents_src, sents_tgt = read_corpus("./corpus/粗粒度NER/example.train")
# # print(sents_src)
# print(len(sents_src))
# print(len(sents_src) / 8)
# dataset = NERDataset(sents_src, sents_tgt, vocab_path)
# word2idx = load_chinese_base_vocab(vocab_path)
# tokenier = Tokenizer(word2idx)

# dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
# for token_ids, token_type_ids, target_ids in dataloader:


# # print(token_ids.shape)
# print(tokenier.decode(token_ids[0].tolist()))
# print(tokenier.decode(token_ids[1].tolist()))
# # print(token_type_ids)
# print(target_ids)
# break


# bert_model = load_bert(vocab_path, model_class="encoder", target_size=14)
# bert_model(token_ids)

# print(tokenier.decode(target_ids[0].tolist()))
# print(tokenier.decode(target_ids[1].tolist()))
# break

# with open("./state_dict/corase_train_update.txt", "a+") as f:
# with open("./state_dict/人民日报ner数据.txt", "r", encoding="utf-8") as f1 :
# lines = f1.readlines()
# start = 1
# string = ""
# label = ""
# for line in lines:
# if line == "\n":
# f.write("\n")
# continue
# line = line.strip("\n")
# line = line.split(" ")
# if _is_punctuation(line[0]) or _is_cjk_character(line[0]):
# if string != "":
# string = string.lower()
# tokens = word_piece_tokenize(string) # subword pieces
# for t in tokens:
# if "##" in t:
# f.write(t[2:] + " " + label + "\n")
# else :
# f.write(t + " " + label + "\n")
# # f.write(string + " " + label + "\n")
# string = ""
# label = ""
# f.write(line[0] + " " + line[1] + "\n")
# else :
# string += line[0]
# label = line[1]
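The commented-out block above appears to be the one-off script used to rebuild the training file: runs of non-CJK characters are grouped into a word, re-tokenized with `word_piece_tokenize`, and the run's label is copied onto every resulting piece, with the `##` prefix stripped before writing. A standalone sketch of that label-propagation idea (toy tokens and a hypothetical `B-PRO` label; this is not the script that produced `corase_train_update.txt`):

```python
def propagate_label(tokens, label):
    """Copy one NER label onto every subword piece of a tokenized span."""
    pairs = []
    for t in tokens:
        piece = t[2:] if t.startswith("##") else t  # drop the '##' marker
        pairs.append((piece, label))
    return pairs

# e.g. a span "iphone12" labelled B-PRO (hypothetical label) might become
# [('iphone', 'B-PRO'), ('12', 'B-PRO')] after WordPiece splitting.
print(propagate_label(["iphone", "##12"], "B-PRO"))
```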
2 changes: 1 addition & 1 deletion examples/细粒度NER_CRF_train.py
@@ -1,4 +1,4 @@
## example: automatic poetry writing

import sys
sys.path.append("/Users/xingzhaohu/Downloads/code/python/ml/ml_code/bert/bert_seq2seq")
import torch
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='bert_seq2seq',
version='1.0.3',
version='1.0.5',
description='use torch to do bert_seq2seq task',
long_description='bert_seq2seq: https://github.com/920232796/bert_seq2seq',
license='Apache License 2.0',
16 changes: 8 additions & 8 deletions test/auto_title_test.py
@@ -12,7 +12,7 @@
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params, load_recent_model

auto_title_model = "./state_dict/bert_auto_title_model.bin"
auto_title_model = "./state_dict/bert_auto_title_model2.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
@@ -26,14 +26,14 @@
bert_model.eval()
## load the trained model parameters
load_recent_model(bert_model, recent_model_path=auto_title_model, device=device)
test_data = ["针对央视3·15晚会曝光的电信行业乱象工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理"]
# # test_data = [
# # "本文总结了十个可穿戴产品的设计原则而这些原则同样也是笔者认为是这个行业最吸引人的地方1为人们解决重复性问题2从人开始而不是从机器开始3要引起注意但不要刻意4提升用户能力而不是取代人",
# # "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
# # "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元",
# # "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]

test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
"楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
"新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]

for text in test_data:
print(bert_model.generate(text, beam_size=3))
with torch.no_grad():
print(bert_model.generate(text, beam_size=3))



22 changes: 10 additions & 12 deletions test/poem_test.py
@@ -12,32 +12,30 @@
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_model_params, load_recent_model

auto_title_model = "./state_dict/bert_model_poem.bin"
auto_title_model = "./state_dict/bert_model_poem_ci_duilian.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
vocab_path = "./state_dict/roberta_wwm_vocab.txt" # path to the roberta vocab file
model_name = "roberta" # choose the model name
# model_path = "./state_dict/bert-base-chinese-pytorch_model.bin" # roberta model path
# load the vocabulary
word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
# define the model
bert_model = load_bert(word2idx, model_name=model_name)
bert_model.eval()
# ## load the pretrained model parameters
checkpoint = torch.load(auto_title_model, map_location="cpu")
# print(checkpoint)
load_recent_model(bert_model, recent_model_path=auto_title_model, device=device)
# bert_model.load_state_dict(torch.load(auto_title_model, map_location="cpu"), strict=False)
test_data = ["天涯海角##七言绝句"]
# # test_data = [
# # "本文总结了十个可穿戴产品的设计原则而这些原则同样也是笔者认为是这个行业最吸引人的地方1为人们解决重复性问题2从人开始而不是从机器开始3要引起注意但不要刻意4提升用户能力而不是取代人",
# # "2007年乔布斯向人们展示iPhone并宣称它将会改变世界还有人认为他在夸大其词然而在8年后以iPhone为代表的触屏智能手机已经席卷全球各个角落未来智能手机将会成为真正的个人电脑为人类发展做出更大的贡献",
# # "雅虎发布2014年第四季度财报并推出了免税方式剥离其持有的阿里巴巴集团15%股权的计划打算将这一价值约400亿美元的宝贵投资分配给股东截止发稿前雅虎股价上涨了大约7%至5145美元",
# # "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
for text in test_data:
print(bert_model.generate(text, beam_size=3, is_poem=True))
test_data = ["江山竞秀,万里风光入画图##对联",
"北国风光##五言绝句"]
with torch.no_grad():
for text in test_data:
if text[-1] == "句" or text[-1] == "诗":
print(bert_model.generate(text, beam_size=3, is_poem=True))
else:
print(bert_model.generate(text, beam_size=3, is_poem=False))

# print(name[0])
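The updated poem test above feeds inputs of the form `"<text>##<format>"` and turns `is_poem` on only when the format suffix ends in 句 or 诗 (quatrain/poem) rather than 对联 (couplet). A tiny helper sketch of that routing; `generate_with_format` is a hypothetical name, and `bert_model.generate` is assumed to accept the `is_poem` flag exactly as in the diff:

```python
def generate_with_format(bert_model, text, beam_size=3):
    # text looks like "北国风光##五言绝句" or "江山竞秀,万里风光入画图##对联"
    fmt = text.split("##")[-1]
    # poem formats (ending in 句 or 诗) get the rhyme/length constraints
    is_poem = fmt.endswith("句") or fmt.endswith("诗")
    return bert_model.generate(text, beam_size=beam_size, is_poem=is_poem)
```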


(remaining changed files not shown)
