Skip to content

Commit

Permalink
readme修改
Browse files Browse the repository at this point in the history
  • Loading branch information
qukequke committed Mar 6, 2022
1 parent d4bb196 commit bfd3914
Show file tree
Hide file tree
Showing 22 changed files with 33 additions and 19 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,13 @@ bert_crf_token_classification_raw/
2、在config.py中修改dir_name

# bert 系列模型训练曲线(20 epoch)
![image](data/med_data/dev_f1.png)
![image](data/med_data/loss.png)

model | dev f1 | loss
:-------------------------:|:-------------------------:|:-------------------------:
bert | ![](data/cner/bert_dev_f1.png) | ![](data/cner/bert_loss.png)
roberta | ![](data/cner/roberta_dev_f1.png) | ![](data/cner/roberta_loss.png)
ernie | ![](data/cner/ernie_dev_f1.png) | ![](data/cner/ernie_loss.png)
albert | ![](data/cner/albert_dev_f1.png) | ![](data/cner/albert_loss.png)

# 具体参数可看config.py
```
Expand Down
20 changes: 15 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,29 @@
'transformers.BertConfig',
'bert-base-chinese' # 使用模型
),

'ernie': (
'transformers.AutoTokenizer',
'transformers.BertModel',
'transformers.AutoConfig',
"nghuyong/ernie-1.0", # 使用模型参数
),
'roberta': (
'transformers.BertTokenizer',
'transformers.RobertaModel',
'transformers.RobertaConfig',
'hfl/chinese-roberta-wwm-ext',
),
'albert': ('transformers.BertTokenizer',
'albert': ('transformers.AutoTokenizer',
'transformers.AlbertModel',
'transformers.AlbertConfig'
'transformers.AutoConfig',
"voidful/albert_chinese_tiny", # 使用模型参数
),
}
MODEL = 'roberta'
# MODEL = 'ernie'
# MODEL = 'albert'
# MODEL = 'bert'

epochs = 20
batch_size = 32
# batch_size = 1
Expand All @@ -46,7 +57,7 @@
# 切换任务时 数据配置
csv_rows = ['raw_sen', 'label'] # csv的行标题,文本 和 类(目前类必须是列表)

dir_name = 'med_data'
dir_name = 'cner'
train_file = f"data/{dir_name}/train.csv"
dev_file = f"data/{dir_name}/dev.csv"
# dev_file = f"data/{dir_name}/train.csv"
Expand All @@ -59,7 +70,6 @@
test_pred_out = f"data/{dir_name}/test_data_predict.csv"
# csv_encoding = 'gbk'

MODEL = 'bert'

PREFIX = ''
# max_src_length = 400
Expand Down
Binary file added data/cner/albert_dev_f1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/cner/albert_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
Binary file added data/cner/ernie_dev_f1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/cner/ernie_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
Binary file added data/cner/roberta_dev_f1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/cner/roberta_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions data/txt2csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def change(file_name, file_out, dict_file, split_=' ', split_3=False):

if __name__ == '__main__':
json_dict = 'med_data/label_2_id.json'
get_dict('med_data/train.char.bmes', json_dict) # 生成label2id字典,存储到json_dict
change('med_data/train.char.bmes', 'med_data/train.csv', json_dict)
change('med_data/test.char.bmes', 'med_data/test.csv', json_dict)
change('med_data/dev.char.bmes', 'med_data/dev.csv', json_dict)
get_dict('cner/train.char.bmes', json_dict) # 生成label2id字典,存储到json_dict
change('cner/train.char.bmes', 'cner/train.csv', json_dict)
change('cner/test.char.bmes', 'cner/test.csv', json_dict)
change('cner/dev.char.bmes', 'cner/dev.csv', json_dict)
3 changes: 1 addition & 2 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BertTokenizer

Expand All @@ -24,7 +23,7 @@

label_2_id = load_json(json_dict)

dir_name = 'med_data'
dir_name = 'cner'
target_file = f'models/{dir_name}/best.pth.tar' # 模型存储路径
label_file = f'data/{dir_name}/label2id.json'
bert_path_or_name = 'bert-base-chinese' # 使用模型
Expand Down
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# -*- coding: utf-8 -*-
import torch
from torch import nn
# from transformers import BertForSequenceClassification, BertConfig, RobertaForSequenceClassification, RobertaConfig, \
# AutoModelForSequenceClassification, AutoConfig
from config import *
from utils import eval_object

Expand Down Expand Up @@ -435,7 +432,10 @@ def __init__(self):
self.device = torch.device("cuda")
# self.dropout = nn.Dropout(hidden_dropout_prob)
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Linear(768, num_labels)
if MODEL == 'albert':
self.classifier = nn.Linear(312, num_labels)
else:
self.classifier = nn.Linear(768, num_labels)
self.crf = CRF(num_tags=num_labels, batch_first=True)

def forward(self, **input_):
Expand Down
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,8 @@ def my_plot(train_acc_list, losses):
plt.vlines(max_x, min(train_acc_list), max_y, colors='r' if add==0 else 'b', linestyles='dashed')
plt.hlines(max_y, 0, max_x, colors='r' if add==0 else 'b', linestyles='dashed')
plt.legend()
plt.savefig(os.path.join(os.path.dirname(train_file), 'dev_f1.png'))
plt.savefig(os.path.join(os.path.dirname(train_file), f'{MODEL}_dev_f1.png'))
plt.figure()
plt.plot(losses)
plt.savefig(os.path.join(os.path.dirname(train_file), 'loss.png'))
plt.savefig(os.path.join(os.path.dirname(train_file), f'{MODEL}_loss.png'))
logger = get_logger()

0 comments on commit bfd3914

Please sign in to comment.