Refactor README (huggingface#460)
* v1

* update

* link

* nits
younesbelkada authored Jul 3, 2023
1 parent 0fe603e commit aa9770c
Showing 1 changed file with 81 additions and 8 deletions.
README.md
@@ -3,18 +3,38 @@
</div>

# TRL - Transformer Reinforcement Learning
> Train transformer language models with reinforcement learning.
> Full-stack library to train transformer language models with reinforcement learning.
<p align="center">
<a href="https://github.com/lvwerra/trl/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/lvwerra/trl.svg?color=blue">
</a>
<a href="https://huggingface.co/docs/trl/index">
<img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
</a>
<a href="https://github.com/lvwerra/trl/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/lvwerra/trl.svg">
</a>
</p>


## What is it?
With `trl` you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face, so pre-trained language models can be loaded directly via `transformers`. At this point, most decoder and encoder-decoder architectures are supported.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>

`trl` is a full-stack library providing a set of tools to train transformer language models with reinforcement learning, from the Supervised Fine-tuning (SFT) step through Reward Modeling (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face, so pre-trained language models can be loaded directly via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for code snippets and instructions on how to run these tools.

**Highlights:**
- `PPOTrainer`: A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- `AutoModelForCausalLMWithValueHead` & `AutoModelForSeq2SeqLMWithValueHead`: A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- Example: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier.

## How it works
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token, which can be used as a value function in reinforcement learning (a short sketch follows this list).
- [Examples](https://github.com/lvwerra/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, run full RLHF using adapters only, train GPT-J to be less toxic, the [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
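
A minimal sketch of what the value head adds, assuming this version's `AutoModelForCausalLMWithValueHead` (whose forward pass returns a `(logits, loss, value)` tuple):

```python
# sketch: a value-head model returns one scalar value estimate per token
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("The movie was", return_tensors="pt")
logits, _, values = model(**inputs)

print(logits.shape)  # (batch, seq_len, vocab_size): the usual LM head
print(values.shape)  # (batch, seq_len): one value estimate per token
```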

## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:

1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
3. **Optimization**: The query/response pairs are used to calculate the log-probabilities of their tokens under both the trained model and a reference model (usually the pre-trained model before fine-tuning). The KL-divergence between the two serves as an additional reward signal that keeps generated responses from drifting too far from the reference language model; the active model is then trained with PPO. A sketch of this KL shaping follows the list.
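
To make the optimization step concrete, here is a small illustrative sketch of how a scalar reward and a per-token KL penalty can be combined; this mirrors the idea rather than TRL's internal code, and `kl_coef` is an assumed name:

```python
import torch

def kl_shaped_rewards(reward, logprobs, ref_logprobs, kl_coef=0.2):
    # logprobs / ref_logprobs: log-probabilities of the response tokens under
    # the trained model and the frozen reference model
    kl = logprobs - ref_logprobs   # per-token KL estimate
    shaped = -kl_coef * kl         # penalize drifting away from the reference
    shaped[-1] += reward           # credit the scalar reward on the final token
    return shaped

# example: a 4-token response whose query/response pair earned a reward of 1.0
per_token_rewards = kl_shaped_rewards(
    reward=1.0,
    logprobs=torch.tensor([-0.9, -1.2, -0.3, -0.5]),
    ref_logprobs=torch.tensor([-1.0, -1.0, -0.4, -0.6]),
)
```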
@@ -52,8 +72,59 @@ pip install -e .

## How to use

### Example
This is a basic example of how to use the library. Based on a query, the language model creates a response, which is then evaluated. The evaluation could be a human in the loop or another model's output.
### `SFTTrainer`

This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer that makes it easy to fine-tune language models or adapters on a custom dataset.

```python
# imports
from datasets import load_dataset
from trl import SFTTrainer

# load the dataset to fine-tune on (here: IMDB movie reviews)
dataset = load_dataset("imdb", split="train")

# create the trainer; the model is loaded from the Hub by name
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",  # dataset column that holds the raw text
    max_seq_length=512,
)

# train
trainer.train()
```
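
The first positional argument can also be an already-instantiated model instead of a Hub id. A hypothetical variant of the call above, assuming the `packing` flag documented for this version's `SFTTrainer` (it concatenates short examples into full-length sequences):

```python
# variant sketch: pass a pre-loaded model and pack short examples together
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,  # assumed flag: concatenate examples up to max_seq_length
)
trainer.train()
```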

### `RewardTrainer`

This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer that makes it easy to fine-tune reward models or adapters on a custom preference dataset.

```python
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and tokenizer; the dataset needs to be in a specific
# preference format (see the sketch after this snippet)
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
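
The `...` above elides the dataset preparation. A hypothetical sketch of what it could look like, assuming (from this version's docs) that each row must carry tokenized `chosen` and `rejected` completions:

```python
# hypothetical preference-data preparation; column names are assumptions
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

pairs = {
    "chosen": ["The movie was brilliant and moving."],
    "rejected": ["The movie was a movie."],
}

def tokenize_pair(row):
    chosen = tokenizer(row["chosen"], truncation=True)
    rejected = tokenizer(row["rejected"], truncation=True)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }

dataset = Dataset.from_dict(pairs).map(tokenize_pair)
```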

### `PPOTrainer`

This is a basic example of how to use the `PPOTrainer` from the library. Based on a query, the language model creates a response, which is then evaluated. The evaluation could be a human in the loop or another model's output.

```python
# imports
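# the rest of this snippet is a sketch assuming the helpers exported by
# this version of `trl` (PPOConfig, AutoModelForCausalLMWithValueHead,
# create_reference_model and trl.core.respond_to_batch)
import torch
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models: a GPT-2 policy with a value head, plus a frozen reference copy
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# initialize the trainer
ppo_config = PPOConfig(batch_size=1)
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# encode a query and generate a response
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
response_tensor = respond_to_batch(model, query_tensor)

# define a reward for the response
# (this could be any scalar, e.g. human feedback or another model's output)
reward = [torch.tensor(1.0)]

# train the model for one PPO step on the (query, response, reward) triplet
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)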
```

@@ -99,6 +170,8 @@ For a detailed example check out the example python script `examples/sentiment/s
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>

Have a look at more examples inside the [`examples/`](https://github.com/lvwerra/trl/tree/main/examples) folder.

## References

### Proximal Policy Optimisation
