trlX allows you to fine-tune 🤗 Hugging Face supported language models (gpt2, gpt-j, gpt-neo and gpt-neox based) of up to 20B parameters with reinforcement learning, using either a provided reward function or a reward-labeled dataset. Two algorithms are implemented: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).
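For readers new to PPO, here is a minimal sketch of its standard clipped surrogate objective in plain Python. It is an illustration of the textbook formulation, not trlX's implementation. A full PPO trainer typically adds further terms, such as a value-function loss. The function name and its list-based inputs are assumptions made for this sketch.

```python
import math
from typing import Sequence


def ppo_clipped_objective(
    logp_new: Sequence[float],
    logp_old: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Mean PPO clipped surrogate objective (to be maximized).

    Hypothetical helper: takes per-action log-probabilities under the new
    and old policies plus advantage estimates, one entry per action.
    """
    if not advantages:
        raise ValueError("need at least one action")
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)
```

When the new and old policies agree, the ratio is 1 and the objective is simply the mean advantage. Clipping only limits how much a single update can gain from moving the ratio outside [1 - clip_eps, 1 + clip_eps].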
You can read more about trlX in our documentation.
To install trlX from source:
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for CUDA 11.6
pip install -e .
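You can train a model using a reward function or a reward-labeled dataset.
Using a reward function, which takes a list of generated samples and returns one reward per sample (here, the number of times 'cats' appears):
import trlx
trainer = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])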
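The reward function does not have to be a lambda. Any callable with the same contract should work: it takes a list of generated strings and returns one number per string, in the same order. The sketch below is a hypothetical, case-insensitive variant of the 'cats' reward. The name cats_reward is an assumption for illustration.

```python
from typing import List


def cats_reward(samples: List[str]) -> List[float]:
    """Hypothetical reward: count 'cats' in each sample, ignoring case.

    Returns one float per input sample, preserving order, which matches
    the shape of the lambda passed as reward_fn above.
    """
    return [float(sample.lower().count("cats")) for sample in samples]
```

It would then be passed the same way as the lambda: trainer = trlx.train('gpt2', reward_fn=cats_reward).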
Using a reward-labeled dataset, given as a pair of (samples, rewards):
trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
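Labeled data often arrives as (text, reward) records rather than as two parallel tuples. The sketch below converts such records into the [samples, rewards] layout used in the dataset argument of the call above. The helper name to_reward_dataset is hypothetical and is not part of trlX.

```python
from typing import Iterable, List, Tuple


def to_reward_dataset(records: Iterable[Tuple[str, float]]) -> List[tuple]:
    """Hypothetical helper: turn (sample, reward) records into [samples, rewards].

    Both output tuples share the same order, so samples[i] is labeled
    with rewards[i].
    """
    records = list(records)
    samples = tuple(text for text, _ in records)
    rewards = tuple(float(reward) for _, reward in records)
    return [samples, rewards]
```

For example, to_reward_dataset([('dolphins', 1.0), ('geese', 100.0)]) yields the same value as the dataset argument in the gpt-j call above.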
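The returned trainer wraps the underlying autoregressive model, so you can sample from it with generate (here tokenizer is the trained model's 🤗 tokenizer, for example AutoTokenizer.from_pretrained('gpt2')):
trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)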
To launch distributed training, configure 🤗 Accelerate and launch a script with it:
accelerate config # choose the DeepSpeed option
accelerate launch examples/simulacra.py
To launch a hyperparameter sweep over an example script, pass it a sweep config:
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
For more usage examples, see the examples directory.
For development, check out the contribution guidelines and read our docs.
Many thanks to Leandro von Werra for contributing trl, the library that initially inspired this repo.