* add alm model Signed-off-by: BAAI-OpenPlatform <[email protected]> Signed-off-by: ZhaodongYan1 <[email protected]> Co-authored-by: shunxing1234 <[email protected]> Co-authored-by: ZhaodongYan1 <[email protected]>
1 parent 46d646c, commit 65bcdfc
Showing 13 changed files with 405 additions and 65 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# ALM 1.0 | ||
|
||
## Overview | ||
|
||
The Arabic Language Model (ALM) 1.0 is a general-purpose Arabic pretrained language model based on autoregressive blank infilling. Its configuration is listed below. | ||
|
||
| Name | Params | Layers | Hidden Size | FFN Hidden size | Heads | Head Size | | ||
| ------- | ------ | ------ | ----------- | --------------- | ----- | --------- | | ||
| ALM 1.0 | 335M | 24 | 1024 | 4096 | 16 | 64 | | ||
|
||
|
||
|
||
## Training data | ||
|
||
ALM 1.0 is trained on ArabicText 2022, billed as the world's largest open-source Arabic text dataset. See [ArabicText 2022](https://data.baai.ac.cn/details/ArabicText-2022) for details. | ||
|
||
|
||
|
||
## How to use | ||
|
||
### Finetune | ||
|
||
With [FlagAI](https://gitee.com/link?target=https%3A%2F%2Fgithub.com%2FBAAI-Open%2FFlagAI), ALM can be fine-tuned directly on common Seq2seq tasks. | ||
|
||
### Quick start | ||
|
||
examples/alm_seq2seq/train.py provides a fine-tuning example that uses ALM for Seq2seq tasks such as text summarization and short or long text generation. | ||
|
||
examples/alm_seq2seq/generate.py provides an example that uses ALM for masked text prediction in an autoregressive way. | ||
|
||
|
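The figures in the configuration table can be cross-checked with a little arithmetic. The sketch below is only a rough estimate: it assumes a standard Transformer block (four hidden-by-hidden attention projections plus a two-matrix feed-forward network) and ignores biases, layer norms, and embeddings, none of which the README spells out.

```python
# Configuration values copied from the ALM 1.0 table.
hidden_size = 1024
ffn_hidden_size = 4096
num_layers = 24
num_heads = 16

# Per-head width is the hidden size split evenly across the heads.
head_size = hidden_size // num_heads

# Assumption: each block holds Q, K, V and output projections
# (hidden x hidden each) and two feed-forward matrices
# (hidden x ffn and ffn x hidden). Biases and layer norms are ignored.
attention_params = 4 * hidden_size * hidden_size
ffn_params = 2 * hidden_size * ffn_hidden_size
block_params = num_layers * (attention_params + ffn_params)

print(head_size)                      # 64
print(f"{block_params / 1e6:.0f}M")   # 302M
```

Under these assumptions the Transformer blocks account for roughly 302M of the reported 335M parameters; the remainder would presumably come from embeddings and the terms ignored here.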
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright © 2022 BAAI. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License") | ||
import torch | ||
from flagai.auto_model.auto_loader import AutoLoader | ||
from flagai.model.predictor.predictor import Predictor | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
auto_loader = AutoLoader("seq2seq", | ||
model_name="ALM-1.0", | ||
model_dir="./checkpoints") | ||
model = auto_loader.get_model() | ||
tokenizer = auto_loader.get_tokenizer() | ||
|
||
model.to(device) | ||
model.eval() | ||
predictor = Predictor(model, tokenizer) | ||
|
||
test_data = ['"مقالة - سلعة: يُعد الصدق هو الحل الأفضل'] | ||
for text in test_data: | ||
    print('===============================================\n') | ||
    print(text, ":") | ||
    for i in range(1):  # raise the range to generate several samples per prompt | ||
        print("--------------sample %d :-------------------" % (i)) | ||
        print('-----------random sample: --------------') | ||
        print( | ||
            predictor.predict_generate_randomsample(text, | ||
                                                    out_max_length=66, | ||
                                                    top_k=10, | ||
                                                    top_p=.1, | ||
                                                    repetition_penalty=4.0, | ||
                                                    temperature=1.2)) | ||
        print('-----------beam search: --------------') | ||
        print( | ||
            predictor.predict_generate_beamsearch(text, | ||
                                                  out_max_length=66, | ||
                                                  beam_size=10)) | ||
        print() |
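The random-sample call above combines `temperature`, `top_k`, and `top_p`. As a rough illustration of how such settings interact, here is a generic, standard-library sketch of temperature scaling followed by top-k and top-p (nucleus) filtering. It is not FlagAI's implementation: the function name `filter_top_k_top_p`, the toy logits, and the order in which the filters are applied are all assumptions made for this example.

```python
import math


def filter_top_k_top_p(logits, top_k, top_p, temperature=1.0):
    """Return {token_index: renormalized probability} for the tokens kept."""
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep only the k most probable tokens.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = ranked[:top_k]

    # Top-p: keep the shortest prefix whose cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens so they sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


logits = [2.0, 1.0, 0.5, -1.0]  # toy values, not model output
narrow = filter_top_k_top_p(logits, top_k=10, top_p=0.1, temperature=1.2)
wide = filter_top_k_top_p(logits, top_k=10, top_p=0.9, temperature=1.2)
print(sorted(narrow))  # [0]
print(sorted(wide))    # [0, 1, 2]
```

In this sketch, a `top_p` of 0.1 keeps a single token whenever the most probable token already holds at least 10% of the mass, so sampling becomes close to greedy; a larger `top_p` leaves more candidates to sample from.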
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
127.0.0.1 slots=2 |
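The single-line hostfile above appears to follow the `hostname slots=N` convention used by DeepSpeed-style launchers, where `slots` usually means the number of worker processes (typically GPUs) on that host. The parser below is a hypothetical helper written for illustration only; it is not part of FlagAI or of any launcher.

```python
def parse_hostfile(text):
    """Map each host to its slot count from 'host slots=N' lines."""
    hosts = {}
    for raw in text.splitlines():
        line = raw.strip()
        # Skip blank lines and comments.
        if not line or line.startswith("#"):
            continue
        host, slots = line.split()
        key, _, value = slots.partition("=")
        if key != "slots":
            raise ValueError(f"unexpected field: {slots!r}")
        hosts[host] = int(value)
    return hosts


hosts = parse_hostfile("127.0.0.1 slots=2\n")
print(hosts)                 # {'127.0.0.1': 2}
print(sum(hosts.values()))   # 2
```

Read this way, the file describes one local machine running two worker processes.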
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright © 2022 BAAI. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License") | ||
import pandas as pd | ||
import os | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import Dataset | ||
from flagai.auto_model.auto_loader import AutoLoader | ||
from flagai.trainer import Trainer | ||
from tqdm import tqdm | ||
from rouge_score import rouge_scorer | ||
from torch import argmax | ||
import sacrebleu | ||
|
||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
trainer = Trainer( | ||
env_type="pytorch", | ||
experiment_name="ALM_seq2seq", | ||
batch_size=1, | ||
gradient_accumulation_steps=1, | ||
lr=1e-5, | ||
weight_decay=1e-5, | ||
epochs=1, | ||
log_interval=10, | ||
eval_interval=10, | ||
load_dir=None, | ||
pytorch_device=device, | ||
save_dir="checkpoints_alm_title_generation", | ||
save_interval=1000, | ||
num_checkpoints=1, | ||
) | ||
|
||
data_dir = './data/train.tsv' | ||
val_dir = './data/valid.tsv' | ||
|
||
|
||
|
||
maxlen = 256 | ||
auto_loader = AutoLoader("lm", | ||
model_name="ALM-1.0",) | ||
model = auto_loader.get_model() | ||
tokenizer = auto_loader.get_tokenizer() | ||
|
||
|
||
def read_file(path): | ||
    # Expects a tab-separated file with "source" and "target" columns. | ||
    src = [] | ||
    tgt = [] | ||
    df = pd.read_csv(path, sep="\t") | ||
    for idx, row in tqdm(df.iterrows()): | ||
        src.append(row["source"]) | ||
        tgt.append(row["target"]) | ||
    return src, tgt | ||
|
||
|
||
def bleu_metric(predictions, labels, meta=None, metric="rouge-1", duplicate_rate=0.7, dataset='cnn_dm'): | ||
    # Decode reference token ids back to text. | ||
    ref_list = [] | ||
    for i in labels: | ||
        i = i.tolist() | ||
        ref = tokenizer.DecodeIds(i) | ||
        ref_list.append(ref) | ||
    pred_list = [] | ||
|
||
    for prediction in predictions: | ||
        prediction = prediction.tolist() | ||
        prediction = tokenizer.DecodeIds(prediction) | ||
        pred_list.append(prediction) | ||
|
||
    # corpus_bleu takes the hypotheses and a list of reference streams. | ||
    bleu_results = sacrebleu.corpus_bleu(pred_list, [ref_list]) | ||
    bleu_score = bleu_results.score | ||
    return bleu_score | ||
|
||
def rouge_metric(predictions, labels, meta=None, metric="rouge-1", duplicate_rate=0.7, dataset='cnn_dm'): | ||
    metric_dict = {"rouge-1": "rouge1", "rouge-2": "rouge2", "rouge-l": "rougeLsum"} | ||
    ref_list = [] | ||
    for i in labels: | ||
        i = i.tolist() | ||
        ref = tokenizer.DecodeIds(i) | ||
        ref_list.append(ref) | ||
    pred_list = [] | ||
    for prediction in predictions: | ||
        prediction = prediction.tolist() | ||
        prediction = tokenizer.DecodeIds(prediction) | ||
        pred_list.append(prediction) | ||
    scorer = rouge_scorer.RougeScorer([metric_dict[metric]], use_stemmer=True) | ||
    # RougeScorer.score expects (target, prediction), in that order. | ||
    scores = [scorer.score(ref, pred) for pred, ref in zip(pred_list, ref_list)] | ||
    scores = [score[metric_dict[metric]].fmeasure * 100 for score in scores] | ||
    scores = sum(scores) / len(scores) | ||
    return scores | ||
|
||
|
||
|
||
class ALMSeq2seqDataset(Dataset): | ||
|
||
    def __init__(self, | ||
                 sents_src, | ||
                 sents_tgt, | ||
                 tokenizer, | ||
                 max_src_length=512, | ||
                 max_tgt_length=512): | ||
        super(ALMSeq2seqDataset, self).__init__() | ||
        self.sents_src = sents_src | ||
        self.sents_tgt = sents_tgt | ||
        self.tokenizer = tokenizer | ||
        self.max_src_length = max_src_length | ||
        self.max_tgt_length = max_tgt_length | ||
        self.no_block_position = False | ||
|
||
    def __getitem__(self, i): | ||
        source_text = self.sents_src[i] | ||
        target_text = self.sents_tgt[i] | ||
        data = self.tokenizer.encode_plus(source_text=source_text, target_text=target_text, max_length=512) | ||
|
||
        return data | ||
|
||
    def __len__(self): | ||
|
||
        return len(self.sents_src) | ||
|
||
|
||
class ALMCollateFN():  # pads every sample in a batch to the longest one | ||
|
||
    def __init__(self, pad_id): | ||
        self.pad_id = pad_id | ||
|
||
    def pad_token(self, tokens, max_length): | ||
        pad_len = max_length - len(tokens) | ||
        tokens += [self.pad_id] * pad_len | ||
        return tokens | ||
|
||
    def pad_position_ids(self, position_ids, max_length): | ||
        pad_len = max_length - len(position_ids[0]) | ||
        position_ids[0] += [len(position_ids[0]) + x for x in range(pad_len)] | ||
        position_ids[1] += [1] * pad_len | ||
        return position_ids | ||
|
||
    def pad_loss_mask(self, loss_mask, max_length): | ||
        pad_len = max_length - len(loss_mask) | ||
        loss_mask += [0] * pad_len | ||
        return loss_mask | ||
|
||
    def __call__(self, batch): | ||
        input_ids = [data["input_ids"] for data in batch] | ||
        target_ids = [data["target_ids"] for data in batch] | ||
        position_ids = [data["position_ids"] for data in batch] | ||
        attention_mask = [data['attention_mask'] for data in batch] | ||
        loss_mask = [data['loss_mask'] for data in batch] | ||
|
||
        max_length = max([len(t) for t in input_ids]) | ||
        for i in range(len(input_ids)): | ||
            input_ids[i] = self.pad_token(input_ids[i], max_length) | ||
            target_ids[i] = self.pad_token(target_ids[i], max_length) | ||
            position_ids[i] = self.pad_position_ids(position_ids[i], | ||
                                                    max_length) | ||
            loss_mask[i] = self.pad_loss_mask(loss_mask[i], max_length) | ||
        return { | ||
            'input_ids': torch.LongTensor(input_ids), | ||
            'labels': torch.LongTensor(target_ids), | ||
            'position_ids': torch.LongTensor(position_ids), | ||
            'attention_mask': torch.LongTensor(attention_mask), | ||
            'loss_mask': torch.LongTensor(loss_mask) | ||
        } | ||
|
||
|
||
# Training and validation data come from separate files, so no split is needed. | ||
train_src, train_tgt = read_file(data_dir) | ||
val_src, val_tgt = read_file(val_dir) | ||
|
||
my_collate_fn = ALMCollateFN( | ||
    pad_id=tokenizer.get_command_id('pad')) | ||
|
||
train_dataset = ALMSeq2seqDataset(train_src, | ||
train_tgt, | ||
tokenizer=tokenizer) | ||
val_dataset = ALMSeq2seqDataset(val_src, | ||
val_tgt, | ||
tokenizer=tokenizer) | ||
trainer.train(model, | ||
train_dataset=train_dataset, | ||
valid_dataset=val_dataset, | ||
metric_methods=[['rouge_scorer', rouge_metric], ['bleu', bleu_metric]], | ||
collate_fn=my_collate_fn) |
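To make the batch padding in `ALMCollateFN` easier to follow, here is a standard-library sketch of the same padding rules applied to a toy batch. The pad id, the sample values, and the helper names are made up for illustration. Unlike the class in the diff, these helpers return new lists instead of extending the inputs in place, and `target_ids` and `attention_mask` are left out for brevity.

```python
PAD_ID = 0  # hypothetical pad id; train.py reads it via tokenizer.get_command_id('pad')


def pad_tokens(tokens, max_length, pad_id=PAD_ID):
    # Token ids are right-padded with the pad id.
    return tokens + [pad_id] * (max_length - len(tokens))


def pad_position_ids(position_ids, max_length):
    # Two rows: absolute positions keep counting upward,
    # block positions are filled with 1.
    positions, block_positions = position_ids
    pad_len = max_length - len(positions)
    start = len(positions)
    return [positions + [start + x for x in range(pad_len)],
            block_positions + [1] * pad_len]


def pad_loss_mask(loss_mask, max_length):
    # Padded positions never contribute to the loss.
    return loss_mask + [0] * (max_length - len(loss_mask))


def collate(batch):
    max_length = max(len(item["input_ids"]) for item in batch)
    return {
        "input_ids": [pad_tokens(item["input_ids"], max_length) for item in batch],
        "position_ids": [pad_position_ids(item["position_ids"], max_length) for item in batch],
        "loss_mask": [pad_loss_mask(item["loss_mask"], max_length) for item in batch],
    }


# Toy batch: values are illustrative, not real tokenizer output.
batch = [
    {"input_ids": [5, 6, 7],
     "position_ids": [[0, 1, 2], [0, 0, 1]],
     "loss_mask": [0, 0, 1]},
    {"input_ids": [8, 9, 10, 11, 12],
     "position_ids": [[0, 1, 2, 3, 4], [0, 0, 0, 1, 2]],
     "loss_mask": [0, 0, 0, 1, 1]},
]

padded = collate(batch)
print(padded["input_ids"][0])     # [5, 6, 7, 0, 0]
print(padded["position_ids"][0])  # [[0, 1, 2, 3, 4], [0, 0, 1, 1, 1]]
print(padded["loss_mask"][0])     # [0, 0, 1, 0, 0]
```

After padding, every field in the batch has the same length, which is what allows the collate function in the diff to wrap each one in a `torch.LongTensor`.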