
Averyyy/distillEBWM

This branch is 3 commits ahead of jongwooko/distillm:master.

Latest commit: 4a9ace9 (Jan 30, 2025), 16 commits

DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)

Official PyTorch implementation of DistiLLM, as presented in our paper:

DistiLLM: Towards Streamlined Distillation for Large Language Models
Jongwoo Ko, Sungnyun Kim, Tianyi Chen, Se-Young Yun
KAIST AI and Microsoft

🚀 Updates

  • (24.08.12) Removed the dependency on the outdated local copy of Transformers, so you can now work with a wide range of recent LLMs!
  • (24.05.01) Our paper has been accepted to ICML 2024. We welcome any discussion and will reflect it in the camera-ready version. Looking forward to seeing you in Vienna!
  • (24.03.13) Released LoRA checkpoints for OpenLLaMA2-3B.

Environment

bash install.sh

Following MiniLLM, our code is based on this commit of HuggingFace Transformers.

Data

Resources

  • The training/evaluation instruction-response data before processing can be downloaded from this link.
  • The plain-text corpus D_PT can be downloaded from the HuggingFace datasets repository.

Data Processing

Get the plain-text corpus D_PT:

python3 tools/get_openwebtext.py

This script replaces each run of consecutive newlines (\n) in a document with the special token "<@x(x!>" and writes each OpenWebText document on a single line, which is convenient for parallel processing. data/openwebtext/data.txt gives an example of the resulting format. You can follow this format to prepare corpora other than OpenWebText.
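A minimal sketch of the one-document-per-line format described above. It assumes that every run of consecutive newlines collapses into a single "<@x(x!>" token; the function names are hypothetical, and tools/get_openwebtext.py may differ in details such as how it loads the dataset.

```python
import re

# Separator token named in this README; the helpers below are hypothetical.
NEWLINE_TOKEN = "<@x(x!>"


def to_single_line(document: str) -> str:
    """Collapse each run of consecutive newlines into one separator token."""
    return re.sub(r"\n+", NEWLINE_TOKEN, document)


def write_corpus(documents, path):
    """Write one document per line so parallel workers can split the file by lines."""
    with open(path, "w", encoding="utf-8") as f:
        for doc in documents:
            f.write(to_single_line(doc) + "\n")


def from_single_line(line: str) -> str:
    """Approximate inverse: each separator becomes a single newline.

    Runs of several newlines were collapsed, so blank lines are not restored exactly.
    """
    return line.rstrip("\n").replace(NEWLINE_TOKEN, "\n")
```

Because every document occupies exactly one line, a file can be sharded across workers by line ranges without any document being split.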

Tokenize the data and store it in binary files:

bash scripts/gpt2/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/gpt2/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data

bash scripts/opt/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/opt/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data

bash scripts/llama/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/llama/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data
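To give an intuition for what storing tokenized data "in binary files" buys, the sketch below writes all documents as one flat array of token ids plus an offset index, so any single document can be read back without loading the whole corpus. The file layout, the `.bin`/`.idx` suffixes, and both functions are illustrative assumptions; the processing scripts above may use a different on-disk format, so use them (not this sketch) to produce training data.

```python
import array
import struct

# Hypothetical, simplified layout (an assumption, not the repo's actual format):
#   <prefix>.bin : all token ids concatenated, stored as uint16
#   <prefix>.idx : little-endian int64 start offsets, one per document, plus a final end offset


def write_binary(token_docs, prefix):
    # "H" = unsigned 16-bit: enough for vocabularies below 65,536 entries (GPT-2 has 50,257).
    ids = array.array("H")
    offsets = [0]
    for doc in token_docs:
        ids.extend(doc)
        offsets.append(len(ids))
    with open(prefix + ".bin", "wb") as f:
        ids.tofile(f)
    with open(prefix + ".idx", "wb") as f:
        f.write(struct.pack(f"<{len(offsets)}q", *offsets))


def read_document(prefix, i):
    # Read the start/end offsets of document i, then only that slice of the token file.
    with open(prefix + ".idx", "rb") as f:
        f.seek(8 * i)
        start, end = struct.unpack("<2q", f.read(16))
    ids = array.array("H")
    with open(prefix + ".bin", "rb") as f:
        f.seek(start * ids.itemsize)
        ids.fromfile(f, end - start)
    return ids.tolist()
```

The offset index is what makes random access cheap: reading document i touches 16 bytes of the index and only that document's slice of the token file.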

Base Pre-trained Models

To run fine-tuning or the standard KD baselines, you need to download the model checkpoints from the HuggingFace Model Hub and put them in checkpoints/. For example, for gpt2-large, you can download the model from this link and put it in checkpoints/gpt2-large.

Alternatively, you can change the CKPT variable in each script to the corresponding model name so that Transformers downloads the base model automatically. For example, setting CKPT="gpt2-large" in scripts/gpt2/sft/sft_large.sh makes Transformers download the gpt2-large base model from the HuggingFace Model Hub.
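The two options above (a local copy under checkpoints/ versus a Hub model name) can be illustrated with a small helper. `resolve_checkpoint` is hypothetical and not part of this repo; the scripts may resolve CKPT differently. It relies only on the fact that Transformers' `from_pretrained()` accepts either a local directory or a Hub model id.

```python
import os


def resolve_checkpoint(ckpt: str, ckpt_dir: str = "checkpoints") -> str:
    """Hypothetical helper: prefer a local checkpoint directory, else fall back to a Hub id.

    The result can be passed to e.g. transformers.AutoModelForCausalLM.from_pretrained(),
    which loads a local directory directly and downloads (and caches) a Hub id.
    """
    local = os.path.join(ckpt_dir, ckpt)
    return local if os.path.isdir(local) else ckpt
```

With checkpoints/gpt2-large present, `resolve_checkpoint("gpt2-large")` returns that directory; otherwise it returns `"gpt2-large"` and the download happens on first load.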

Train

We provide example commands for GPT-2 models. Similar scripts for other model families can be found in scripts/opt and scripts/openllama2. All our experiments were conducted on 4 × A100 (40GB) GPUs; fewer GPUs suffice for smaller models.

Baselines

The final checkpoints are selected based on ROUGE-L score.
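To make the ROUGE-L selection criterion concrete, here is a minimal sketch: ROUGE-L F1 from the longest common subsequence (LCS) of tokens, averaged over a validation set, with the checkpoint having the highest mean kept. It assumes lowercased whitespace tokenization and the F1 form (harmonic mean of LCS precision and recall); the repo's evaluation code may tokenize differently, so absolute scores can differ. Checkpoint names in the usage note are hypothetical.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (O(len(a) * len(b)) DP)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction: str, reference: str) -> float:
    """ROUGE-L F1 with lowercased whitespace tokenization (a simplification)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


def select_checkpoint(generations, references):
    """Return the checkpoint whose generations have the highest mean ROUGE-L F1.

    generations: dict mapping a checkpoint name to its responses, aligned with references.
    """
    def mean_score(name):
        preds = generations[name]
        return sum(rouge_l_f1(p, r) for p, r in zip(preds, references)) / len(references)

    return max(generations, key=mean_score)
```

For example, "the cat sat on mat" against "the cat on the mat" shares the 4-token subsequence "the cat on mat", giving precision = recall = 0.8 and ROUGE-L F1 = 0.8.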

Fine-tune the teacher models

bash scripts/gpt2/sft/sft_xlarge.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

SFT Baselines

bash scripts/gpt2/sft/sft_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

KD Baselines

bash scripts/gpt2/kd/kd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

SeqKD Baselines

Generate and process responses with the teacher:

bash scripts/gpt2/tools/generate_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/tools/process_pseudo_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
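Conceptually, the two steps above build a sequence-level KD (SeqKD) dataset: each training prompt is paired with the teacher's generated response, and the student is then fine-tuned on those pairs instead of the original gold responses. The sketch below shows only that data transformation; `teacher_generate` stands in for whatever decoding the generation script performs, and the `"prompt"`/`"output"` keys are hypothetical field names.

```python
from typing import Callable, Dict, List


def build_seqkd_dataset(
    prompts: List[str],
    teacher_generate: Callable[[str], str],
) -> List[Dict[str, str]]:
    """Pair each prompt with the teacher's response (the SeqKD pseudo-targets)."""
    dataset = []
    for prompt in prompts:
        response = teacher_generate(prompt)
        # Design choice in this sketch: drop empty generations rather than train on them.
        if response.strip():
            dataset.append({"prompt": prompt, "output": response})
    return dataset
```

In practice `teacher_generate` would wrap the fine-tuned teacher model; a stub function is enough to check the data flow.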

Fine-tune the model with SeqKD:

bash scripts/gpt2/seqkd/seqkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Student Initialization

The final checkpoints are selected based on validation loss.

bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
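Selecting the initialization checkpoint by validation loss amounts to computing the average negative log-likelihood of the validation targets for each saved checkpoint and keeping the lowest. The sketch below assumes per-token log-probabilities have already been computed and averages over all target tokens (token-weighted); whether the repo averages per token or per example, and which tokens it masks, are assumptions here. Checkpoint names are hypothetical.

```python
from typing import Dict, List


def mean_token_nll(token_logprobs: List[List[float]]) -> float:
    """Token-weighted average negative log-likelihood over a validation set.

    token_logprobs[i][t] is log p(y_t | y_<t, x_i) for the t-th target token of example i.
    """
    total = sum(-lp for seq in token_logprobs for lp in seq)
    count = sum(len(seq) for seq in token_logprobs)
    return total / count


def select_by_val_loss(val_logprobs: Dict[str, List[List[float]]]) -> str:
    """Return the checkpoint name with the lowest mean validation NLL."""
    return min(val_logprobs, key=lambda name: mean_token_nll(val_logprobs[name]))
```

Token weighting means longer responses contribute proportionally more to the loss than short ones.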

ImitKD Baselines

bash scripts/gpt2/imitkd/imitkd_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

MiniLLM Baselines

bash scripts/gpt2/minillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

GKD Baselines

bash scripts/gpt2/gkd/gkd_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

DistiLLM

First, initialize the students (the same step as Student Initialization above). The final checkpoints are selected based on validation loss.

bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Then, train the students with DistiLLM. The final checkpoints are selected based on ROUGE-L score.

bash scripts/gpt2/distillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Run Evaluation

bash scripts/gpt2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
bash scripts/opt/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 
bash scripts/openllama2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 

Results

DistiLLM outperforms other KD baselines in terms of both generation performance and training speed for various model families such as GPT-2, OPT, and OpenLLaMA.

Checkpoints (OpenLLaMA-3B)

We share the LoRA weights for OpenLLaMA-3B on Google Drive.

Acknowledgement

Our code is based on the code of MiniLLM: Knowledge Distillation of Large Language Models (ICLR 2024).


BibTeX

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{kodistillm,
  title={DistiLLM: Towards Streamlined Distillation for Large Language Models},
  author={Ko, Jongwoo and Kim, Sungnyun and Chen, Tianyi and Yun, Se-Young},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024}
}

Contact
