
Commit

Fix some bugs; add code for producing sentence embeddings with BERT
920232796 committed Dec 15, 2021
1 parent caaed4a commit 15b635f
Showing 9 changed files with 63 additions and 20 deletions.
24 changes: 20 additions & 4 deletions bert_seq2seq/basic_bert.py
@@ -1,6 +1,7 @@

import torch
import torch.nn as nn
from bert_seq2seq.tokenizer import Tokenizer

def get_model(model_name, word2ix):
    if model_name == "roberta":
@@ -38,15 +39,20 @@ def get_model(model_name, word2ix):
    return config, bert, layer_norm_cond, CLS

class BasicBert(nn.Module):
    def __init__(self, word2ix, model_name="roberta"):
    def __init__(self, word2ix, model_name="roberta", tokenizer=None):
        super().__init__()
        self.config = ""
        self.word2ix = word2ix

        if tokenizer is None:
            self.tokenizer = Tokenizer(word2ix)
        else:
            self.tokenizer = tokenizer

        self.model_name = model_name

        self.config, self.bert, self.layer_norm_cond, self.cls = get_model(model_name, word2ix)

        self.device = torch.device("cpu")

    def load_pretrain_params(self, pretrain_model_path, keep_tokens=None, strict=False):
@@ -74,8 +80,18 @@ def load_all_params(self, model_path, device="cuda"):
        torch.cuda.empty_cache()
        print(str(model_path) + " loaded!")

    def forward(self, x):
        raise NotImplemented
    def forward(self, input_text):
        ## return the vectors produced by the BERT encoder for the input text
        input_ids, _ = self.tokenizer.encode(input_text, max_length=512)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=self.device).view(1, -1)

        enc_layers, _ = self.bert(input_ids, position_ids=None, token_type_ids=None,
                                  output_all_encoded_layers=True)
        sequence_out = enc_layers[-1]  ## take the output of the last encoder layer, shape (batch, seq_len, 768)

        tokens_hidden_state, _ = self.cls(sequence_out)

        return tokens_hidden_state

    def set_device(self, device):
        self.device = torch.device(device)
1 change: 1 addition & 0 deletions bert_seq2seq/bert_cls_classifier.py
@@ -25,6 +25,7 @@ def compute_loss(self, predictions, labels):
    def forward(self, text, position_enc=None, labels=None, use_layer_num=-1):
        if use_layer_num != -1:
            raise Exception("暂时只支持用最后一层进行分类")

        text = text.to(self.device)
        if position_enc is not None:
            position_enc = position_enc.to(self.device)
2 changes: 1 addition & 1 deletion bert_seq2seq/gpt2_generate_model.py
@@ -23,7 +23,7 @@ def __init__(self, word2ix, tokenizer=None,
        self.model = GPT2LMHeadModel(self.config)

    def sample_generate(self, text, input_max_length=256, out_max_length=200,
                        top_k=30, top_p=0.0, add_eos=False, repetition_penalty=1.0,
                        top_k=30, top_p=1.0, add_eos=False, repetition_penalty=1.0,
                        temperature=1.0):

        lp = [RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
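Note on the top_p default change above: under the usual nucleus-sampling convention (assumed here; the repository's own TopPLogitsProcessor may differ in detail), top_p=1.0 keeps the full distribution, while 0.0 prunes everything except the single most likely token, effectively turning sampling into greedy decoding. A minimal, self-contained sketch of top-p filtering under that convention (the function top_p_filter is illustrative, not part of the repository):

import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
    # Nucleus (top-p) filtering over a 1-D logits vector: drop the lowest-probability
    # tokens once the cumulative probability of the kept set exceeds top_p.
    if top_p >= 1.0:
        return logits  # 1.0 disables filtering entirely
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum_probs > top_p
    remove[1:] = remove[:-1].clone()  # shift right so the token crossing the threshold is kept
    remove[0] = False                 # always keep the most probable token
    filtered = logits.clone()
    filtered[sorted_idx[remove]] = float("-inf")
    return filtered

Under this convention the old default of 0.0 would have collapsed sampling to the top token regardless of top_k and temperature, which is presumably why the default moved to 1.0 (no nucleus filtering unless the caller asks for it).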
3 changes: 2 additions & 1 deletion bert_seq2seq/model/bert_model.py
@@ -4,6 +4,7 @@

import torch
from torch import nn
from torch._C import device
from torch.nn import CrossEntropyLoss, MSELoss

def swish(x):
@@ -464,7 +465,7 @@ def forward(
            extended_attention_mask = attention_mask * extended_attention_mask

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
2 changes: 1 addition & 1 deletion bert_seq2seq/model/roberta_model.py
@@ -506,7 +506,7 @@ def forward(
            extended_attention_mask = attention_mask * extended_attention_mask

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
            token_type_ids = torch.zeros_like(input_ids, dtype=torch.long, device=input_ids.device)

        # print(extended_attention_mask)
        # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
7 changes: 1 addition & 6 deletions bert_seq2seq/seq2seq_model.py
@@ -46,12 +46,7 @@ class Seq2SeqModel(BasicBert):
"""
"""
def __init__(self, word2ix, model_name="roberta", tokenizer=None):
super(Seq2SeqModel, self).__init__(word2ix=word2ix, model_name=model_name)
self.word2ix = word2ix
if tokenizer is None:
self.tokenizer = Tokenizer(word2ix)
else:
self.tokenizer = tokenizer
super(Seq2SeqModel, self).__init__(word2ix=word2ix, model_name=model_name, tokenizer=tokenizer)

self.hidden_dim = self.config.hidden_size
self.vocab_size = len(word2ix)
7 changes: 1 addition & 6 deletions bert_seq2seq/simbert_model.py
@@ -47,13 +47,8 @@ class SimBertModel(BasicBert):
"""
"""
def __init__(self, word2ix, model_name="roberta", tokenizer=None):
super(SimBertModel, self).__init__(word2ix=word2ix, model_name=model_name)
super(SimBertModel, self).__init__(word2ix=word2ix, model_name=model_name, tokenizer=tokenizer)
self.word2ix = word2ix
if tokenizer is None:
self.tokenizer = Tokenizer(word2ix)
else:
self.tokenizer = tokenizer

self.hidden_dim = self.config.hidden_size
self.vocab_size = len(word2ix)

6 changes: 5 additions & 1 deletion bert_seq2seq/utils.py
@@ -9,6 +9,7 @@
from bert_seq2seq.bert_cls_multi_seq2seq import ClsMultiSeq2SeqModel
from bert_seq2seq.simbert_model import SimBertModel
from bert_seq2seq.gpt2_generate_model import GPT2
from bert_seq2seq.basic_bert import BasicBert


def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2seq", target_size=0, target=None):
@@ -44,14 +45,17 @@ def load_bert(word2ix, tokenizer=None, model_name="roberta", model_class="seq2se
        bert_model = BertRelationExtrac(word2ix, target_size, model_name=model_name)
        return bert_model
    elif model_class == "simbert":
        bert_model = SimBertModel(word2ix, model_name=model_name)
        bert_model = SimBertModel(word2ix, model_name=model_name, tokenizer=tokenizer)
        return bert_model
    elif model_class == "multi_label_cls":
        bert_model = BertClsMultiClassifier(word2ix, target_size, model_name=model_name)
        return bert_model
    elif model_class == "multi_label_cls_seq2seq":
        bert_model = ClsMultiSeq2SeqModel(word2ix, target, model_name=model_name)
        return bert_model
    elif model_class == "embedding":
        bert_model = BasicBert(word2ix, model_name=model_name, tokenizer=tokenizer)
        return bert_model
    else:
        raise Exception("model_name_err")

31 changes: 31 additions & 0 deletions test/get_bert_embedding.py
@@ -0,0 +1,31 @@
## Use BERT to encode a single sentence

import torch
from bert_seq2seq import Tokenizer, load_chinese_base_vocab
from bert_seq2seq import load_bert

model_path = "./state_dict/roberta_wwm_pytorch_model.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    vocab_path = "./state_dict/roberta_wwm_vocab.txt"  # path of the RoBERTa vocabulary file
    model_name = "roberta"  # model name to use
    # load the vocabulary
    word2idx = load_chinese_base_vocab(vocab_path)
    # build the model
    bert_model = load_bert(word2idx, model_name=model_name, model_class="embedding")
    bert_model.set_device(device)
    bert_model.eval()
    ## load the pretrained model parameters
    bert_model.load_pretrain_params(model_path)

    test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
                 "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
                 "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
    for text in test_data:
        with torch.no_grad():
            print(bert_model(text).shape)
            print("\n")


