
Commit

精简代码 (Simplify the code)
920232796 committed Aug 18, 2021
1 parent 6e3dd36 commit b5683be
Showing 16 changed files with 30 additions and 34 deletions.
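
Every hunk below makes the same kind of change: the test scripts stop importing helpers from the internal submodules bert_seq2seq.utils and bert_seq2seq.tokenizer and import them from the package root instead. A representative before/after sketch, using only names that appear in the hunks below (it assumes the package root already re-exports these helpers; bert_seq2seq/__init__.py itself is not part of this diff):

# before this commit: helpers imported from internal submodules
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from bert_seq2seq.utils import load_bert, load_gpt

# after this commit: the same names imported from the package root
from bert_seq2seq import Tokenizer, load_chinese_base_vocab, load_bert, load_gpt
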
8 changes: 3 additions & 5 deletions test/auto_title_test.py
@@ -8,7 +8,6 @@
 if __name__ == "__main__":
     vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置
     model_name = "roberta" # 选择模型名字
-    # model_path = "./state_dict/bert-base-chinese-pytorch_model.bin" # roberta模型位
     # 加载字典
     word2idx, keep_tokens = load_chinese_base_vocab(vocab_path, simplfied=True)
     # 定义模型
@@ -18,10 +17,9 @@
     ## 加载训练的模型参数~
     bert_model.load_all_params(model_path=auto_title_model, device=device)

-    # test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
-    # "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
-    # "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
-    test_data = ["重庆潼南县的8位村民一年前在河道里挖出一根30米长乌木,卖得19.6万元,大家分了这笔数额不小的意外之财。如今,当地财政局将他们起诉到法院,称乌木在河道中发现,其所有权应归国家。法院一审二审都判决村民们还钱"]
+    test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
+                 "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
+                 "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
     for text in test_data:
         with torch.no_grad():
             print(bert_model.generate(text, beam_size=3))
2 changes: 1 addition & 1 deletion test/bert_english_autotitle_test.py
@@ -3,7 +3,7 @@
 import glob
 import json
 from rouge import Rouge
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import load_bert
 from transformers import AutoTokenizer

 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
4 changes: 2 additions & 2 deletions test/gpt_ancient_translation_test.py
@@ -1,7 +1,7 @@

 import torch
-from bert_seq2seq.utils import load_gpt
-from bert_seq2seq.tokenizer import load_chinese_base_vocab
+from bert_seq2seq import load_gpt
+from bert_seq2seq import load_chinese_base_vocab

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4 changes: 2 additions & 2 deletions test/gpt_article_continued_test.py
@@ -1,7 +1,7 @@

 import torch
-from bert_seq2seq.utils import load_gpt
-from bert_seq2seq.tokenizer import load_chinese_base_vocab
+from bert_seq2seq import load_gpt
+from bert_seq2seq import load_chinese_base_vocab

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2 changes: 1 addition & 1 deletion test/gpt_english_story_test.py
@@ -1,6 +1,6 @@

 import torch
-from bert_seq2seq.utils import load_gpt
+from bert_seq2seq import load_gpt
 from transformers import AutoTokenizer

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4 changes: 2 additions & 2 deletions test/gpt_explain_dream_test.py
@@ -1,7 +1,7 @@

 import torch
-from bert_seq2seq.utils import load_gpt
-from bert_seq2seq.tokenizer import load_chinese_base_vocab
+from bert_seq2seq import load_gpt
+from bert_seq2seq import load_chinese_base_vocab

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2 changes: 1 addition & 1 deletion test/gpt_test_english.py
@@ -1,6 +1,6 @@

 import torch
-from bert_seq2seq.utils import load_gpt
+from bert_seq2seq import load_gpt
 from transformers import AutoTokenizer

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5 changes: 2 additions & 3 deletions test/nezha_auto_title_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 auto_title_model = "./state_dict/nezha_auto_title.bin"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,7 +20,6 @@
 test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
              "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
              "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
-# test_data = ["重庆潼南县的8位村民一年前在河道里挖出一根30米长乌木,卖得19.6万元,大家分了这笔数额不小的意外之财。如今,当地财政局将他们起诉到法院,称乌木在河道中发现,其所有权应归国家。法院一审二审都判决村民们还钱"]
 for text in test_data:
     with torch.no_grad():
         print(bert_model.generate(text, beam_size=3))
4 changes: 2 additions & 2 deletions test/nezha_relation_extract_test.py
@@ -1,8 +1,8 @@
 import torch
 import numpy as np
 import json
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 relation_extrac_model = "./state_dict/nezha_relation_extract.bin"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5 changes: 2 additions & 3 deletions test/poem_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 auto_title_model = "./state_dict/bert_model_poem_ci_duilian.bin"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,7 +25,6 @@
 if text[-1] == "句" or text[-1] == "诗":
     print(bert_model.generate(text, beam_size=3, is_poem=True))
 else:
-    # print(bert_model.generate_random(text, beam_size=5))
     print(bert_model.generate(text, beam_size=3, is_poem=False))


4 changes: 2 additions & 2 deletions test/relation_extract_test.py
@@ -2,8 +2,8 @@

 import numpy as np
 import json
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 relation_extrac_model = "./state_dict/bert_model_relation_extrac.bin"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4 changes: 2 additions & 2 deletions test/semantic_matching_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 target = ["0", "1"]

4 changes: 2 additions & 2 deletions test/做数学题_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 vocab_path = "./state_dict/roberta_wwm_vocab.txt" # roberta模型字典的位置

4 changes: 2 additions & 2 deletions test/新闻标题文本分类_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 target = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]

4 changes: 2 additions & 2 deletions test/粗粒度ner_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 target = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"]

4 changes: 2 additions & 2 deletions test/细粒度ner_test.py
@@ -1,6 +1,6 @@
 import torch
-from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
-from bert_seq2seq.utils import load_bert
+from bert_seq2seq import Tokenizer, load_chinese_base_vocab
+from bert_seq2seq import load_bert

 target = ["other", "address", "book", "company", "game", "government", "movie", "name", "organization", "position", "scene"]

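
Since every file above only swaps import paths, downstream scripts that must run against both older and newer releases of the library can guard the import. A minimal compatibility sketch (it assumes the pre-refactor submodule paths bert_seq2seq.utils and bert_seq2seq.tokenizer are still present in older releases):

try:
    # new-style imports used by the test scripts after this commit
    from bert_seq2seq import Tokenizer, load_chinese_base_vocab, load_bert, load_gpt
except ImportError:
    # fall back to the pre-refactor submodule paths
    from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
    from bert_seq2seq.utils import load_bert, load_gpt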
