
Commit: simbert
920232796 committed Dec 9, 2021
1 parent aad46c7 commit caaed4a
Showing 4 changed files with 49 additions and 26 deletions.
Binary file .DS_Store modified (contents not shown).
2 changes: 1 addition & 1 deletion README.md
@@ -167,4 +167,4 @@
2020.04.01: Refactored the code; setting up a new training task now takes less time.

python setup.py sdist
-twine upload dist/bert_seq2seq-2.3.2.tar.gz
+twine upload dist/bert_seq2seq-2.3.3.tar.gz
71 changes: 47 additions & 24 deletions examples/simbert_train.py
@@ -4,8 +4,9 @@
from torch.utils.data import Dataset, DataLoader
from bert_seq2seq import Tokenizer, load_chinese_base_vocab
from bert_seq2seq import load_bert

data_path = "./corpus/相似句/simtrain_to05sts.txt"
import random

data_path = "./sim_ques.txt"
vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置
word2idx = load_chinese_base_vocab(vocab_path, simplfied=False)
tokenizer = Tokenizer(word2idx)
@@ -18,35 +19,50 @@
maxlen = 256

def read_corpus():
+    data = []
+    tmp = []
    with open(data_path, encoding="utf-8") as f:
        lines = f.readlines()
-    inputs = []
-    outputs = []

    for line in lines:
-        print(line)
-        line = line.split("\t")
-        inputs.append(line[1])
-        outputs.append(line[3])
-    return inputs, outputs
+        # print(line)
+        line = line.replace("“", "").replace("”", "").replace("\n", "").strip("\t")
+        if line not in tmp:
+            # deduplicate whole lines before splitting into a paraphrase group
+            tmp.append(line)
+            line = line.split("\t")
+            data.append(line)
+
+    return data
+
+data = read_corpus()
+
+print(f"Dataset contains {len(data)} entries")

class BertDataset(Dataset):
"""
针对特定数据集,定义一个相关的取数据的方式
"""

-    def __init__(self, inputs, outputs):
+    def __init__(self):
        ## the init function typically loads all the data
        super(BertDataset, self).__init__()
-        self.inputs = inputs
-        self.outputs = outputs
self.idx2word = {k: v for v, k in word2idx.items()}


def __getitem__(self, i):
        ## fetch a single example
        # print(i)
-        inp = self.inputs[i]
-        out = self.outputs[i]
+        d = data[i]
+        if len(d) < 2:
+            # a group needs at least two sentences to form a pair; wrap
+            # around so the last index does not run off the end
+            return self.__getitem__((i + 1) % len(data))
+
+        d_list = list(range(len(d)))
+        random_1 = random.choice(d_list)
+        d_list.remove(random_1)
+        random_2 = random.choice(d_list)
+        inp = d[random_1]
+        out = d[random_2]

token_ids_1, token_type_ids_1 = tokenizer.encode(
inp, out, max_length=maxlen
)
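
A side note on the pair-sampling step above: random.sample draws without replacement, so the choice/remove/choice bookkeeping can be expressed in one call. A behaviorally equivalent sketch (illustration only; pick_pair is not part of the commit):

import random

def pick_pair(group):
    # Draw two distinct positions and return the corresponding sentences;
    # assumes len(group) >= 2, matching the len(d) < 2 guard above.
    first, second = random.sample(range(len(group)), 2)
    return group[first], group[second]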
@@ -65,7 +81,7 @@ def __getitem__(self, i):


def __len__(self):
-        return len(self.inputs)
+        return len(data)


def collate_fn(batch):
@@ -99,22 +115,22 @@ def padding(indice, max_length, pad_idx=0):
class Trainer:
def __init__(self):
        # check whether a GPU is available
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print("device: " + str(self.device))
        # define the model
        self.bert_model = load_bert(word2idx, model_name=model_name, model_class="simbert")
        ## load the pretrained model parameters
-        inputs, outputs = read_corpus()
        self.bert_model.load_pretrain_params(model_path)
        # to continue training from an already trained model:
        # self.bert_model.load_all_params(model_save_path, device=self.device)

        # move the model to the compute device (GPU or CPU)
        self.bert_model.set_device(self.device)
        # declare the parameters to optimize
        self.optim_parameters = list(self.bert_model.parameters())
        self.optimizer = torch.optim.Adam(self.optim_parameters, lr=lr, weight_decay=1e-3)
        # declare the custom data loader
-        dataset = BertDataset(inputs, outputs)
+        dataset = BertDataset()
self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
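
collate_fn and padding are collapsed in this view. Judging from the visible signature padding(indice, max_length, pad_idx=0) and the three tensors unpacked in iteration below, a typical implementation pads each field to the longest sequence in the batch. A sketch under that assumption (not the collapsed code itself):

import torch

def padding(indice, max_length, pad_idx=0):
    # Right-pad every sequence in `indice` with pad_idx up to max_length.
    padded = [item + [pad_idx] * max(0, max_length - len(item)) for item in indice]
    return torch.tensor(padded)

def collate_fn(batch):
    # Pad token ids, token type ids and target ids to the batch maximum.
    token_ids, token_type_ids, target_ids = zip(*batch)
    max_length = max(len(t) for t in token_ids)
    return (padding(list(token_ids), max_length),
            padding(list(token_type_ids), max_length),
            padding(list(target_ids), max_length))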

def train(self, epoch):
@@ -136,21 +152,28 @@ def iteration(self, epoch, dataloader, train=True):
report_loss = 0
for token_ids, token_type_ids, target_ids in tqdm(dataloader, position=0, leave=True):
step += 1
-            if step % 1000 == 0:
+            # for t in token_ids:
+            #     print(tokenizer.decode(t.cpu().numpy()))
+            # break
+            if step % 100 == 0:
self.bert_model.eval()
test_data = [
"他这个人没个正经的。",
"咱俩谁跟谁呀。"
"微信支付宝哪个好用?",
"今年冬天会不会特别冷?"
]
for text in test_data:
-                    print(self.bert_model.sample_generate(text))
-                    print(self.bert_model.sample_generate(text))
-                    print(self.bert_model.sample_generate(text))
+                    print(self.bert_model.sample_generate(text,
+                                                          out_max_length=40,
+                                                          top_k=30, top_p=0.7,
+                                                          repetition_penalty=1.5,
+                                                          temperature=1.5, sample_num=8))

print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("loss is " + str(report_loss))
report_loss = 0
# self.eval(epoch)
self.bert_model.train()
-                self.save(model_save_path)
+            if step % 5000 == 0:
+                self.save(model_save_path)

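For reference, the new evaluation call exercises sample_generate's decoding controls: top_k=30 keeps the 30 most likely tokens, top_p=0.7 restricts sampling to the smallest nucleus reaching 70% cumulative probability, repetition_penalty=1.5 down-weights already generated tokens, temperature=1.5 flattens the distribution before sampling, and sample_num=8 draws eight candidates per prompt. A generic sketch of how top-k/top-p filtering is conventionally implemented (an illustration, not bert_seq2seq's actual code):

import torch

def top_k_top_p_filtering(logits, top_k=30, top_p=0.7):
    # Keep only the top_k highest-scoring tokens.
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    # Keep the smallest nucleus whose cumulative probability reaches top_p.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        to_remove = cumulative > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()  # always keep the top token
        to_remove[..., 0] = False
        mask = to_remove.scatter(-1, sorted_indices, to_remove)
        logits = logits.masked_fill(mask, float("-inf"))
    return logits

Tokens surviving both filters are renormalized with softmax and sampled; repetition_penalty and temperature are applied to the logits before this filtering step.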
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='bert_seq2seq',
-    version='2.3.2',
+    version='2.3.3',
description='use torch to do bert_seq2seq task',
long_description='bert_seq2seq: https://github.com/920232796/bert_seq2seq',
license='Apache License 2.0',
