Merge pull request airaria#47 from airaria/update_exmaple
update mnli_exmaple and bert-emd
airaria authored Mar 1, 2021
2 parents fb009a6 + d6e8515 commit a9f1d81
Showing 44 changed files with 1,143 additions and 10,778 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -28,6 +28,18 @@ Check our paper through [ACL Anthology](https://www.aclweb.org/anthology/2020.ac

## News

**Mar 1, 2021**

* **BERT-EMD and custom distiller**

* We added an experiment with [BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) in the [MNLI example](examples/mnli_example/). BERT-EMD allows each intermediate student layer to learn adaptively from any intermediate teacher layer by optimizing the Earth Mover's Distance, so there is no need to specify the matching scheme.
* We have written a new [EMDDistiller](examples/mnli_example/distiller_emd.py) to perform BERT-EMD. It also demonstrates how to write a custom distiller.

* **Updated MNLI example**

* We removed pytorch_pretrained_bert and now use the transformers library in all the MNLI examples.


**Nov 11, 2020**

* **Updated to 0.2.1**:
12 changes: 12 additions & 0 deletions README_ZH.md
@@ -27,6 +27,18 @@

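## News

**Mar 1, 2021**

* **BERT-EMD example and custom distiller**

* An implementation of [BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) has been added to the [MNLI example](examples/mnli_example). BERT-EMD optimizes the Earth Mover's Distance between intermediate layers to adaptively adjust the matching between teacher and student intermediate layers, with no need to specify it manually.
* BERT-EMD is implemented as a custom distiller ([EMDDistiller](examples/mnli_example/distiller_emd.py)), which can serve as a reference for writing custom distillers.

* **MNLI example update**

* Updated the distillation example on the MNLI task; the new code no longer depends on pytorch_pretrained_bert and uses transformers instead.


**Nov 11, 2020**

* **Updated to version 0.2.1**: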
40 changes: 33 additions & 7 deletions examples/mnli_example/README.md
@@ -1,18 +1,44 @@
[**中文说明**](README_ZH.md) | [**English**](README.md)

This example demonstrates distillation on the MNLI task.
This example demonstrates distillation on the MNLI task and **how to write a new distiller**.

* run_mnli_train.sh : trains a teacher model (bert-base-cased) on MNLI.
* run_mnli_train.sh : trains a teacher model (bert-base) on MNLI.
* run_mnli_distill_T4tiny.sh : distills the teacher to T4tiny.
* run_mnli_distill_multiteacher.sh : runs multi-teacher distillation,distilling several teacher models into a student model.
* run_mnli_distill_T4tiny_emd.sh : distills the teacher to T4tiny with many-to-many intermediate matches using EMD, so there is no need to specify the matching scheme. This example also demonstrates how to write a custom distiller (see below for details).
* run_mnli_distill_multiteacher.sh : runs multi-teacher distillation, distilling several teacher models into a student model.

Set the following variables in the shell scripts before running:
Examples have been tested on **PyTorch==1.2.0, transformers==3.0.2**.

## Run

1. Set the following variables in the bash scripts before running:

* BERT_DIR : the directory where BERT-base-cased is stored, including vocab.txt, pytorch_model.bin, bert_config.json
* OUTPUT_ROOT_DIR : the directory that stores logs and trained model weights
* DATA_ROOT_DIR : the directory that contains the MNLI dataset:
* \$\{DATA_ROOT_DIR\}/MNLI/train.tsv
* \$\{DATA_ROOT_DIR\}/MNLI/dev_matched.tsv
* \$\{DATA_ROOT_DIR\}/MNLI/dev_mismatched.tsv
* The trained teacher weights file *trained_teacher_model* has to be specified when running run_mnli_distill_T4tiny.sh
* Multiple teacher weights files *trained_teacher_model_1, trained_teacher_model_2, trained_teacher_model_3* have to be specified when running run_mnli_distill_multiteacher.sh

2. Set the path to BERT:
* If you are running run_mnli_train.sh: open jsons/TrainBertTeacher.json and set "vocab_file", "config_file" and "checkpoint" under the key "student".
* If you are running run_mnli_distill_T4tiny.sh or run_mnli_distill_T4tiny_emd.sh: open jsons/DistillBertToTiny.json and set "vocab_file", "config_file" and "checkpoint" under the key "teachers".
* If you are running run_mnli_distill_multiteacher.sh: open jsons/DistillMultiBert.json and set every "vocab_file", "config_file" and "checkpoint" under the key "teachers". You can also add more teachers to the json.

3. Run the bash script and have fun.
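
Before launching a script, it can help to check that the paths from steps 1 and 2 actually exist. The following is a minimal sketch rather than part of the example code: the helper names (`missing_mnli_files`, `unset_model_paths`, `load_config`) are hypothetical, and the JSON layout (a dict under "student", a list of dicts under "teachers") is an assumption inferred from the steps above.

```python
import json
import os

# Keys named in step 2; the surrounding JSON layout is an assumption.
PATH_KEYS = ("vocab_file", "config_file", "checkpoint")
# Files listed under DATA_ROOT_DIR in step 1.
MNLI_FILES = ("train.tsv", "dev_matched.tsv", "dev_mismatched.tsv")


def missing_mnli_files(data_root_dir):
    """Return the expected MNLI files that are absent under DATA_ROOT_DIR/MNLI."""
    mnli_dir = os.path.join(data_root_dir, "MNLI")
    return [name for name in MNLI_FILES
            if not os.path.isfile(os.path.join(mnli_dir, name))]


def unset_model_paths(config, key):
    """List "<key>[i].<field>" entries that are empty or point to a missing path.

    config[key] may be a single dict or a list of dicts (assumed here so the
    same helper covers both "student" and "teachers").
    """
    entries = config.get(key, [])
    if isinstance(entries, dict):
        entries = [entries]
    problems = []
    for i, entry in enumerate(entries):
        for field in PATH_KEYS:
            path = entry.get(field)
            if not path or not os.path.exists(path):
                problems.append(f"{key}[{i}].{field}")
    return problems


def load_config(json_path):
    """Read one of the jsons/*.json files into a dict."""
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
```

For instance, `unset_model_paths(load_config("jsons/DistillBertToTiny.json"), "teachers")` would return an empty list once every teacher path is set, under the layout assumed above.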

## BERT-EMD and custom distiller
[BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) allows each intermediate student layer to learn adaptively from any intermediate teacher layer by optimizing the Earth Mover's Distance, so there is no need to specify the matching scheme.

Based on the [original implementation](https://github.com/lxk00/BERT-EMD), we have written a new distiller (EMDDistiller) that implements a simplified version of BERT-EMD (it ignores the mappings between attentions). The code of the algorithm is in distiller_emd.py. EMDDistiller is used much like the other distillers:
```python
from distiller_emd import EMDDistiller

# Arguments are elided here; see main.emd.py for the full call.
distiller = EMDDistiller(...)
with distiller:
    distiller.train(...)
```
See main.emd.py for detailed usage.

EMDDistiller requires the pyemd package:
```bash
pip install pyemd
```
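
To give a feel for the many-to-many matching described in this section, here is a toy sketch in plain Python. It does not reproduce distiller_emd.py: the function names and the tiny layer vectors are made up for illustration, and the flow matrix is hand-picked, whereas in the actual method it would come from an EMD solver such as pyemd. The sketch only shows how a transport plan between student and teacher layers turns pairwise layer distances into one loss value.

```python
def mse(a, b):
    """Mean squared error between two equally sized vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distance_matrix(student_layers, teacher_layers):
    """distance[i][j] = MSE between student layer i and teacher layer j."""
    return [[mse(s, t) for t in teacher_layers] for s in student_layers]


def flow_weighted_loss(flow, distance):
    """Sum of flow[i][j] * distance[i][j] over all (student, teacher) pairs."""
    return sum(f * d
               for flow_row, dist_row in zip(flow, distance)
               for f, d in zip(flow_row, dist_row))


# Toy data: 2 student layers and 4 teacher layers, 2 values per layer.
student = [[0.0, 0.0], [1.0, 1.0]]
teacher = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]

# Hand-picked transport plan (an assumption, not a solver output):
# each student layer carries weight 0.5, each teacher layer 0.25.
flow = [[0.25, 0.25, 0.0, 0.0],
        [0.0, 0.0, 0.25, 0.25]]

loss = flow_weighted_loss(flow, distance_matrix(student, teacher))
print(loss)
```

Here every student layer sends part of its weight to two teacher layers, which is the sense in which no fixed one-to-one matching scheme has to be specified.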
44 changes: 36 additions & 8 deletions examples/mnli_example/README_ZH.md
@@ -1,18 +1,46 @@
[**中文说明**](README_ZH.md) | [**English**](README.md)

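This example demonstrates distillation on the MNLI sentence-pair classification task. Distillation on the other GLUE tasks is similar.
This example demonstrates distillation on the MNLI sentence-pair classification task and also provides an example of a **custom distiller**.

* run_mnli_train.sh : trains the teacher model (bert-base-cased) on the MNLI data
* run_mnli_distill_T4tiny.sh : distills the teacher model to T4Tiny on MNLI
* run_mnli_distill_multiteacher.sh : runs multi-teacher distillation, compressing several teacher models into one student model
* run_mnli_train.sh : trains the teacher model (bert-base) on the MNLI data.
* run_mnli_distill_T4tiny.sh : distills the teacher model to T4Tiny on MNLI.
* run_mnli_distill_T4tiny_emd.sh : uses the EMD method to compute the matching between hidden layers automatically, with no need to specify it manually. This example also shows how to write a custom distiller (see below for details).
* run_mnli_distill_multiteacher.sh : multi-teacher distillation, compressing several teacher models into one student model.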

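Tested on **PyTorch==1.2.0, transformers==3.0.2**.

## Run

1. Before running any of the scripts above, set the corresponding variables in the sh file according to your environment:

Before running the scripts, set the corresponding variables according to your environment: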

* BERT_DIR : the directory containing the BERT-base-cased model, including vocab.txt, pytorch_model.bin, bert_config.json
* OUTPUT_ROOT_DIR : stores the trained models and logs
* DATA_ROOT_DIR : contains the MNLI dataset:
* \$\{DATA_ROOT_DIR\}/MNLI/train.tsv
* \$\{DATA_ROOT_DIR\}/MNLI/dev_matched.tsv
* \$\{DATA_ROOT_DIR\}/MNLI/dev_mismatched.tsv
* If running run_mnli_distill_T4tiny.sh, you also need to specify the trained teacher weights file trained_teacher_model
* If running run_mnli_distill_multiteacher.sh, you need to specify several trained teacher weights files: trained_teacher_model_1, trained_teacher_model_2, trained_teacher_model_3

2. Set the path to the BERT model:
* If running run_mnli_train.sh, edit the "vocab_file", "config_file" and "checkpoint" paths under the key "student" in jsons/TrainBertTeacher.json
* If running run_mnli_distill_T4tiny.sh or run_mnli_distill_T4tiny_emd.sh, edit the "vocab_file", "config_file" and "checkpoint" paths under the key "teachers" in jsons/DistillBertToTiny.json
* If running run_mnli_distill_multiteacher.sh, edit every "vocab_file", "config_file" and "checkpoint" path under the key "teachers" in jsons/DistillMultiBert.json. You can add more teachers yourself.

3. Once everything is set, run the sh file to start training.

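## BERT-EMD and custom distiller
[BERT-EMD](https://www.aclweb.org/anthology/2020.emnlp-main.242/) optimizes the Earth Mover's Distance between intermediate layers to adaptively adjust the matching between teacher and student intermediate layers.

Following its [original implementation](https://github.com/lxk00/BERT-EMD), we implemented a simplified version, EMDDistiller, in the form of a distiller (it ignores the mapping between attentions).
The BERT-EMD code is in distiller_emd.py. EMDDistiller is used much like the other distillers:
```python
from distiller_emd import EMDDistiller

# Arguments are elided here; see main.emd.py for the full call.
distiller = EMDDistiller(...)
with distiller:
    distiller.train(...)
```
See main.emd.py for detailed usage.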

EMDDistiller requires the pyemd package:
```bash
pip install pyemd
```
42 changes: 7 additions & 35 deletions examples/mnli_example/config.py
@@ -7,19 +7,12 @@ def parse(opt=None):

## Required parameters

parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints will be written.")

## Other parameters
parser.add_argument("--data_dir", default=None, type=str)
parser.add_argument("--do_lower_case", action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--max_seq_length", default=416, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--max_seq_length", default=128, type=int)
parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.")
parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
@@ -30,9 +23,6 @@ def parse(opt=None):
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
"of training.")
parser.add_argument("--verbose_logging", default=False, action='store_true',
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
parser.add_argument("--no_cuda",
default=False,
action='store_true',
@@ -51,47 +41,29 @@ def parse(opt=None):
help="Whether to use 16-bit float precision instead of 32-bit")

parser.add_argument('--random_seed',type=int,default=10236797)
parser.add_argument('--load_model_type',type=str,default='bert',choices=['bert','all','none'])
parser.add_argument('--weight_decay_rate',type=float,default=0.01)
parser.add_argument('--do_eval',action='store_true')
parser.add_argument('--PRINT_EVERY',type=int,default=200)
parser.add_argument('--weight',type=float,default=1.0)
parser.add_argument('--ckpt_frequency',type=int,default=2)

parser.add_argument('--tuned_checkpoint_T',type=str,default=None)
parser.add_argument('--tuned_checkpoint_Ts',nargs='*',type=str)
parser.add_argument("--temperature", default=1, type=float)

parser.add_argument('--tuned_checkpoint_S',type=str,default=None)
parser.add_argument("--init_checkpoint_S", default=None, type=str)
parser.add_argument("--bert_config_file_T", default=None, type=str, required=True)
parser.add_argument("--bert_config_file_S", default=None, type=str, required=True)
parser.add_argument("--temperature", default=1, type=float, required=False)
parser.add_argument("--teacher_cached",action='store_true')

parser.add_argument('--s_opt1',type=float,default=1.0, help="release_start / step1 / ratio")
parser.add_argument('--s_opt2',type=float,default=0.0, help="release_level / step2")
parser.add_argument('--s_opt3',type=float,default=1.0, help="not used / decay rate")
parser.add_argument('--schedule',type=str,default='warmup_linear_release')

parser.add_argument('--no_inputs_mask',action='store_true')
parser.add_argument('--no_logits', action='store_true')
parser.add_argument('--output_att_score',default='true',choices=['true','false'])
parser.add_argument('--output_att_sum', default='false',choices=['true','false'])
parser.add_argument('--output_encoded_layers' ,default='true',choices=['true','false'])
parser.add_argument('--output_attention_layers',default='true',choices=['true','false'])
parser.add_argument('--matches',nargs='*',type=str)
parser.add_argument('--task_name',type=str,choices=list(processors.keys()))
parser.add_argument('--aux_task_name',type=str,choices=list(processors.keys()),default=None)
parser.add_argument('--aux_data_dir', type=str)

parser.add_argument('--only_load_embedding',action='store_true')
parser.add_argument('--matches',nargs='*',type=str)
parser.add_argument('--model_config_json',type=str)
parser.add_argument('--do_test',action='store_true')


global args
if opt is None:
args = parser.parse_args()
else:
args = parser.parse_args(opt)


if __name__ == '__main__':
print (args)
parse(['--SAVE_DIR','test'])
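
The `parse(opt)` pattern above, where an explicit argument list can stand in for the command line, can be tried in isolation. The following is a cut-down sketch, not the full config.py: it keeps only a handful of the arguments visible in this diff, returns the namespace instead of setting a module-level global (a choice made for this sketch), and uses placeholder values for `--matches` and `--model_config_json`.

```python
import argparse


def parse(opt=None):
    """Parse `opt` (a list of strings) or, when it is None, sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--max_seq_length", default=128, type=int)
    parser.add_argument("--temperature", default=1, type=float)
    parser.add_argument("--matches", nargs="*", type=str)
    parser.add_argument("--model_config_json", type=str)
    parser.add_argument("--do_test", action="store_true")
    # Returning the namespace is a simplification for this sketch; the
    # original stores it in a global named `args`.
    return parser.parse_args(opt)


# "match_a" and "match_b" are placeholders, not real match names.
args = parse(["--output_dir", "out",
              "--model_config_json", "jsons/DistillBertToTiny.json",
              "--matches", "match_a", "match_b"])
print(args.max_seq_length, args.matches, args.do_test)
```

Passing a list makes the parser easy to exercise from a notebook or a test, since nothing depends on how the process was started.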