lixw668/EasyLM (forked from young-geng/EasyLM)

Easy-to-use model-parallel large language models in JAX/Flax with pjit support on Cloud TPU Pods.

EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax. EasyLM can scale LLM training to hundreds of TPU/GPU accelerators by leveraging JAX's pjit functionality.

Building on top of Hugging Face's transformers and datasets libraries, this repo provides an easy-to-use and easy-to-customize codebase for training large language models without the complexity found in many other frameworks.

EasyLM is built with JAX/Flax. By leveraging JAX's pjit utility, EasyLM can train large models that do not fit on a single accelerator by sharding the model weights and training data across multiple accelerators. Currently, EasyLM supports multi-TPU/GPU training on a single host as well as multi-host training on Google Cloud TPU Pods.
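To make the "doesn't fit on a single accelerator" argument concrete, here is a back-of-envelope sketch in plain Python. It is not EasyLM code: it assumes parameter state is split evenly across devices and counts a hypothetical 16 bytes per parameter (fp32 weights, fp32 gradients, and two fp32 Adam moment buffers), ignoring activations and communication buffers. The helper names `bytes_per_device` and `fits` and the 16 GB per-device budget are illustrative assumptions.

```python
# Back-of-envelope estimate of per-device memory under even parameter sharding.
# Hypothetical accounting (not EasyLM's): 16 bytes per parameter for
# fp32 weights + fp32 gradients + two fp32 Adam moments; activations ignored.

BYTES_PER_PARAM_ADAM_FP32 = 4 + 4 + 4 + 4  # weights, grads, Adam m, Adam v


def bytes_per_device(num_params: int, bytes_per_param: int, num_devices: int) -> float:
    """Memory each device holds when parameter state is split evenly."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    return num_params * bytes_per_param / num_devices


def fits(num_params: int, bytes_per_param: int, num_devices: int,
         device_memory_bytes: float) -> bool:
    """True if the evenly sharded parameter state fits in each device's memory."""
    return bytes_per_device(num_params, bytes_per_param, num_devices) <= device_memory_bytes


if __name__ == "__main__":
    params_7b = 7_000_000_000
    hbm = 16e9  # hypothetical 16 GB of accelerator memory per device
    for n in (1, 8):
        per_dev = bytes_per_device(params_7b, BYTES_PER_PARAM_ADAM_FP32, n)
        ok = fits(params_7b, BYTES_PER_PARAM_ADAM_FP32, n, hbm)
        print(f"{n} device(s): {per_dev / 1e9:.1f} GB per device, fits={ok}")
```

Under these assumptions a 7B-parameter model needs about 112 GB of parameter state on one device, while sharding it over 8 devices brings that to about 14 GB each, which fits the hypothetical 16 GB budget.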

Currently, the following models are supported:

  • LLaMA
  • GPT-J
  • RoBERTa

Koala

Koala is our new chatbot fine-tuned on top of LLaMA. If you are interested in our Koala chatbot, you can check out the blog post and the documentation for running it locally.

Installation

The installation method differs between GPU hosts and Cloud TPU hosts. The first step is to clone the repository from GitHub and add it to your PYTHONPATH.

git clone https://github.com/young-geng/EasyLM.git
cd EasyLM
export PYTHONPATH="${PWD}:$PYTHONPATH"

Installing on GPU Host

The GPU environment can be installed via Anaconda.

conda env create -f scripts/gpu_environment.yml
conda activate EasyLM

Installing on Cloud TPU Host

The TPU host VM comes with Python and pip pre-installed. Simply run the following script to set up the TPU host.

./scripts/tpu_vm_setup.sh

The EasyLM documentation can be found in the docs directory.

Credits

  • The LLaMA implementation is from JAX_llama
  • The JAX/Flax GPT-J and RoBERTa implementations are from transformers
  • Most of the JAX utilities are from mlxu
  • The codebase is heavily inspired by JAXSeq

Languages

  • Python 97.6%
  • Shell 2.4%