Support training with the transformers BART model; add T5 and BART summarization examples
920232796 committed Oct 20, 2021
1 parent 54154e1 commit 5db1f54
Showing 13 changed files with 549 additions and 28 deletions.
Binary file modified .DS_Store
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ state_dict/bert_seq2seq_save
 build/
 ./test.py
 nouse/*
+test.py
 develop-eggs/
 dist/
 downloads/
2 changes: 1 addition & 1 deletion README.md
@@ -153,4 +153,4 @@
 2020.04.01: Refactored the code; starting training on a new task now takes less time.
 
 python setup.py sdist
-twine upload dist/bert_seq2seq-2.0.2.tar.gz
+twine upload dist/bert_seq2seq-2.1.0.tar.gz
77 changes: 77 additions & 0 deletions bert_seq2seq/bart_chinese.py
@@ -0,0 +1,77 @@
import torch
from bert_seq2seq.model.bart_model import BartConfig, BartForConditionalGeneration, BartModel, shift_tokens_right
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.basic_bert import BasicBart
from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
import torch.nn.functional as F
import torch.nn as nn


class BartGenerationModel(BasicBart):

    def __init__(self, word2idx):
        super().__init__()
        config = BartConfig()
        self.config = config
        self.model = BartModel(config)
        # LM head projects decoder hidden states back onto the shared vocabulary.
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        self.word2idx = word2idx
        self.tokenizer = Tokenizer(self.word2idx)
        self.bos_id = self.word2idx["[CLS]"]
        self.eos_id = self.word2idx["[SEP]"]
        self.unk_id = self.word2idx["[UNK]"]

    def forward(self, input_ids, decoder_input_ids=None, labels=None):
        input_ids = input_ids.to(self.device)
        if labels is not None:
            labels = labels.to(self.device)
            if decoder_input_ids is None:
                # Standard seq2seq teacher forcing: the decoder input is the
                # label sequence shifted one position to the right.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        decoder_input_ids = decoder_input_ids.to(self.device)

        decoder_out, _ = self.model(
            input_ids,
            decoder_input_ids=decoder_input_ids,
        )

        lm_logits = self.lm_head(decoder_out)
        # Padding positions (token id 0) are excluded from the loss.
        target_mask = (decoder_input_ids > 0).float().view(-1)
        masked_lm_loss = None
        if labels is not None:
            # reduction="none" keeps per-token losses so the padding mask can be
            # applied before averaging (the default mean would ignore the mask).
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            masked_lm_loss = (loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) * target_mask).sum() / target_mask.sum()

        output = (lm_logits,)
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    def sample_generate_encoder_decoder(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=True):
        token_out = self.tokenizer.encode(text, max_length=input_max_length)
        # Tokenizer.encode may return (token_ids, segment_ids); keep only the ids.
        if len(token_out) == 2:
            token_ids = token_out[0]
        else:
            token_ids = token_out
        if not add_eos:
            token_ids = token_ids[:-1]
        token_ids = torch.tensor(token_ids, device=self.device, dtype=torch.long).view(1, -1)
        output_ids = []

        input_decoder_ids = torch.tensor(self.bos_id, device=self.device, dtype=torch.long).view(1, -1)
        with torch.no_grad():
            for step in range(out_max_length):
                # BartModel returns hidden states, so the LM head must be applied
                # here as well to obtain vocabulary logits.
                decoder_out = self.model(input_ids=token_ids, decoder_input_ids=input_decoder_ids)[0]
                scores = self.lm_head(decoder_out)
                logit_score = torch.log_softmax(scores[:, -1], dim=-1).squeeze(0)
                # Never sample [UNK].
                logit_score[self.unk_id] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(logit_score, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                if self.eos_id == next_token.item():
                    break
                output_ids.append(next_token.item())
                input_decoder_ids = torch.cat((input_decoder_ids, next_token.long().unsqueeze(0)), dim=1)

        return self.tokenizer.decode(output_ids)
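A minimal usage sketch of the new class (not part of this commit). It assumes load_chinese_base_vocab(vocab_path) returns the word2idx dict the tokenizer expects; the vocab path is illustrative, while the checkpoint path matches the one used in the repo's test code.

# Sketch: one training step and one sampled summary with BartGenerationModel.
import torch
from bert_seq2seq.tokenizer import load_chinese_base_vocab
from bert_seq2seq.bart_chinese import BartGenerationModel

word2idx = load_chinese_base_vocab("./state_dict/vocab.txt")   # illustrative path
model = BartGenerationModel(word2idx)
model.set_device("cuda" if torch.cuda.is_available() else "cpu")
model.load_pretrain_params("./state_dict/bart_model.bin")

# One teacher-forced step: the decoder input is the target without its last
# token, and the labels are the target without its first token.
src = torch.randint(1, 1000, (2, 32))   # dummy source ids
tgt = torch.randint(1, 1000, (2, 16))   # dummy target ids
loss, logits = model(src, decoder_input_ids=tgt[:, :-1], labels=tgt[:, 1:])
loss.backward()

# Top-k sampling decode.
model.eval()
print(model.sample_generate_encoder_decoder("本文介绍了一种新的文本摘要方法。", out_max_length=40))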
29 changes: 29 additions & 0 deletions bert_seq2seq/basic_bert.py
@@ -122,3 +122,32 @@ def set_device(self, device):
    def save_all_params(self, save_path):
        torch.save(self.state_dict(), save_path)


class BasicBart(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")

    def load_pretrain_params(self, pretrain_model_path):
        checkpoint = torch.load(pretrain_model_path, map_location=self.device)
        # Pretrained checkpoints store bare BartModel weights; prefix every key
        # with "model." so it matches the wrapper's submodule name.
        checkpoint = {"model." + k: v for k, v in checkpoint.items()}

        self.load_state_dict(checkpoint, strict=False)
        torch.cuda.empty_cache()
        print("{} loaded!".format(pretrain_model_path))

    def load_all_params(self, model_path, device="cuda"):
        checkpoint = torch.load(model_path, map_location=device)
        self.load_state_dict(checkpoint, strict=False)
        torch.cuda.empty_cache()
        print(str(model_path) + " loaded!")

    def forward(self, x):
        # Subclasses must implement forward. NotImplementedError (not the
        # NotImplemented constant) is the correct exception to raise.
        raise NotImplementedError

    def set_device(self, device):
        self.device = torch.device(device)
        self.to(device)

    def save_all_params(self, save_path):
        torch.save(self.state_dict(), save_path)

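The "model." key rewrite in load_pretrain_params is needed because the pretrained checkpoint holds bare BartModel weights, while the wrapper stores the network under its model attribute. A toy illustration, with hypothetical Inner/Wrapper classes standing in for BartModel and BartGenerationModel:

# Toy example (hypothetical class names) of why the key prefix is needed.
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Inner()   # same submodule name BasicBart uses

inner_ckpt = Inner().state_dict()                           # keys like "linear.weight"
renamed = {"model." + k: v for k, v in inner_ckpt.items()}  # -> "model.linear.weight"
result = Wrapper().load_state_dict(renamed, strict=False)
print(result.missing_keys, result.unexpected_keys)          # both empty: every key lines up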
22 changes: 21 additions & 1 deletion bert_seq2seq/extend_model_method.py
@@ -3,9 +3,11 @@
 import numpy as np
 from bert_seq2seq.seq2seq_model import top_k_top_p_filtering
 import torch.nn.functional as F
+import torch.nn as nn
 
 
 class ExtendModel:
-    def __init__(self, model, tokenizer, bos_id, eos_id, device="cpu") -> None:
+    def __init__(self, model: nn.Module, tokenizer, bos_id, eos_id, device="cpu") -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.device = device
@@ -15,6 +17,24 @@ def __init__(self, model, tokenizer, bos_id, eos_id, device="cpu") -> None:
     def __call__(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
+    def to(self, device):
+        self.model.to(device)
+
+    def load_state_dict(self, model_param, strict=True):
+        self.model.load_state_dict(model_param, strict=strict)
+
+    def state_dict(self):
+        return self.model.state_dict()
+
+    def train(self):
+        self.model.train()
+
+    def eval(self):
+        self.model.eval()
+
+    def parameters(self):
+        return self.model.parameters()
+
     def sample_generate_autoregressive(self, text, input_max_length=256, out_max_length=200, top_k=30, top_p=0.0, add_eos=False):
 
         token_ids = self.tokenizer.encode(text, max_length=input_max_length, truncation=True)
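The newly delegated methods let ExtendModel stand in for the wrapped network in an ordinary training loop. A sketch, not from this commit: the Hugging Face checkpoint name is illustrative, and any causal LM with a BERT-style tokenizer should work the same way.

# Sketch: wrap a Hugging Face model so it exposes the same surface as the repo's models.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from bert_seq2seq.extend_model_method import ExtendModel

name = "uer/gpt2-chinese-cluecorpussmall"   # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
hf_model = AutoModelForCausalLM.from_pretrained(name)

model = ExtendModel(hf_model, tokenizer,
                    bos_id=tokenizer.cls_token_id,
                    eos_id=tokenizer.sep_token_id)
model.to("cpu")                                             # forwarded to hf_model.to
model.train()                                               # forwarded to hf_model.train
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # parameters() is forwarded too
model.eval()
print(model.sample_generate_autoregressive("今天天气", out_max_length=30))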
39 changes: 16 additions & 23 deletions bert_seq2seq/model/bart_model.py
@@ -145,18 +145,18 @@ class BartConfig():
 
     def __init__(
         self,
-        vocab_size=50265,
-        max_position_embeddings=1024,
-        encoder_layers=12,
-        encoder_ffn_dim=4096,
-        encoder_attention_heads=16,
-        decoder_layers=12,
-        decoder_ffn_dim=4096,
-        decoder_attention_heads=16,
+        vocab_size=21128,
+        max_position_embeddings=512,
+        encoder_layers=6,
+        encoder_ffn_dim=3072,
+        encoder_attention_heads=12,
+        decoder_layers=6,
+        decoder_ffn_dim=3072,
+        decoder_attention_heads=12,
         encoder_layerdrop=0.0,
         decoder_layerdrop=0.0,
         activation_function="gelu",
-        d_model=1024,
+        d_model=768,
         dropout=0.1,
         attention_dropout=0.0,
         activation_dropout=0.0,
@@ -167,11 +167,11 @@ def __init__(
         force_bos_token_to_be_generated=False,
         use_cache=True,
         num_labels=3,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
+        pad_token_id=0,
+        bos_token_id=101,
+        eos_token_id=102,
         is_encoder_decoder=True,
-        decoder_start_token_id=2,
+        decoder_start_token_id=102,
     ):
         self.pad_token_id = pad_token_id
         self.bos_token_id = bos_token_id
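(The new defaults switch the config from facebook/bart-large's values to a Chinese BERT-style setup: the 21,128-token bert-base-chinese vocabulary with [PAD]=0, [CLS]=101, and [SEP]=102; a 6-layer encoder and decoder with 768-dim hidden states; and [CLS]/[SEP] doubling as the BOS and EOS/decoder-start tokens.)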
@@ -899,6 +899,7 @@ def forward(
         return_dict (:obj:`bool`, `optional`):
             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
         """
+        device = input_ids.device
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -931,6 +932,7 @@
             combined_attention_mask = _make_causal_mask(
                 input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
             )
+            combined_attention_mask = combined_attention_mask.to(device)
 
         if attention_mask is not None and combined_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
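(These two additions pin the freshly built causal mask to the same device as input_ids, fixing a device mismatch when the model runs on GPU.)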
@@ -1179,13 +1181,4 @@ def forward(
 # out_lm = out[1]
 # print(out_lm.shape)
 names = []
-for name, _ in model.named_parameters():
-    print(name)
-    names.append(name)
-
-print("~~~~~~~~~")
-checkpoint = torch.load("./state_dict/bart_model.bin", map_location="cpu")
-ks = []
-for k, v in checkpoint.items():
-    print(k)
-    ks.append(k)

2 changes: 1 addition & 1 deletion bert_seq2seq/model/nezha_model.py
@@ -100,7 +100,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
+        max_position_embeddings=1024,
         max_relative_position=64,
         type_vocab_size=2,
         initializer_range=0.02,
(The remaining 5 changed files are not shown.)