Skip to content

Commit

Permalink
Add Alpaca (CarperAI#400)
Browse files Browse the repository at this point in the history
* Add loss masking to SFT (both from prompt & padding tokens)
* Add Alpaca example
  • Loading branch information
cat-state authored Mar 29, 2023
1 parent 114998b commit f63f4bf
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 3 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('
trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0])
```

#### Using a prompt-completion dataset

```python
trainer = trlx.train('gpt2', samples=[['Question: 1 + 2 Answer:', '3'], ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']])
```

#### Trainers provide a wrapper over their underlying model

```python
Expand Down
11 changes: 11 additions & 0 deletions examples/alpaca/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## Alpaca

Finetune a model on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca)
```bash
python sft_alpaca.py --model_name EleutherAI/gpt-j-6B --dataset tatsu-lab/alpaca
```

Finetune a model on [Alpaca-Cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned)
```bash
python sft_alpaca.py --model_name EleutherAI/gpt-j-6B --dataset yahma/alpaca-cleaned
```
102 changes: 102 additions & 0 deletions examples/alpaca/sft_alpaca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
import os
from argparse import ArgumentParser
from typing import Dict, List

from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_sft_config


def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def preprocess(instruction: str, input: str, output: str):
"""Build Alpaca prompt and output from instruction and input/output examples"""
if input:
prefix = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request."
)
prompt = f"{prefix}\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
return [prompt, output]
else:
prefix = (
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
)
prompt = f"{prefix}\n\n### Instruction:\n{instruction}\n\n### Response:\n"
return [prompt, output]


def main(hparams={}, model_name="EleutherAI/gpt-j-6B", dataset="tatsu-lab/alpaca"):
config = default_sft_config()
config = config.evolve(
train=dict(
total_steps=2400,
batch_size=4,
seq_length=1024,
),
model=dict(
model_path=model_name,
),
tokenizer=dict(
tokenizer_path=model_name,
),
optimizer=dict(kwargs=dict(lr=2e-5)),
scheduler=dict(kwargs=dict(eta_min=2e-5)),
method=dict(
gen_kwargs=dict(
max_new_tokens=256,
)
),
)

# Merge sweep config with default config if given
config = TRLConfig.update(config.to_dict(), hparams)

# alpaca = load_dataset("tatsu-lab/alpaca", split="train")
alpaca = load_dataset(dataset, split="train")
alpaca = [preprocess(x["instruction"], x["input"], x["output"]) for x in alpaca]

sentiment_fn = pipeline(
"sentiment-analysis",
"lvwerra/distilbert-imdb",
top_k=2,
truncation=True,
batch_size=256,
device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
)

def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> Dict[str, List[float]]:
sentiments = list(map(get_positive_score, sentiment_fn(outputs)))
return {"sentiments": sentiments}

imdb = load_dataset("imdb", split="test")
bad_reviews = imdb.filter(lambda sample: sample["label"] == 0).select(range(256))
zs_rewrite = [preprocess("Rewrite the input into a positive review.", x["text"][:1024], "")[0] for x in bad_reviews]

trainer = trlx.train(
samples=alpaca,
eval_prompts=zs_rewrite,
metric_fn=metric_fn,
config=config,
)

slug = f"{model_name.split('/')[-1]}-{dataset.split('/')[-1]}"
trainer.save_pretrained(f"{slug}-sft")


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("override_hparams", type=str, default="{}", nargs="?")
parser.add_argument("--model_name", type=str, default="EleutherAI/gpt-j-6B")
parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca")

args = parser.parse_args()
hparams = json.loads(args.override_hparams)

main(hparams, args.model_name, args.dataset)
28 changes: 28 additions & 0 deletions trlx/pipeline/offline_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,34 @@ def tokenize_dialogue( # noqa: C901
return truncated


class DialogStore(BaseRolloutStore):
def __init__(self, dialogs: List[List[DialogMessage]], tokenizer: PreTrainedTokenizer):
super().__init__()
self.tokenizer = tokenizer
attention_masks = [torch.ones(sum(len(m.tokens) for m in d), dtype=torch.bool) for d in dialogs]
input_ids = [torch.tensor([t for m in d for t in m.tokens], dtype=torch.long) for d in dialogs]
# -100 is the ignore index for CrossEntropyLoss
labels = [
torch.tensor([t if m.is_output else -100 for m in d for t in m.tokens], dtype=torch.long) for d in dialogs
]
self.history = [
dict(input_ids=i, attention_mask=a, labels=l) for i, a, l in zip(input_ids, attention_masks, labels)
]

def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
hf_collate_fn = DataCollatorWithPadding(self.tokenizer)

def collate_fn(elems: Iterable[dict]):
batch = hf_collate_fn(
{"input_ids": [e["input_ids"] for e in elems], "attention_mask": [e["attention_mask"] for e in elems]}
)
labels = hf_collate_fn([{"input_ids": e["labels"]} for e in elems])["input_ids"]
batch["labels"] = labels
return batch

return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)


@register_datapipeline
class PromptPipeline(BasePipeline):
"""
Expand Down
20 changes: 19 additions & 1 deletion trlx/trainer/accelerate_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from trlx.data.configs import TRLConfig
from trlx.data.method_configs import MethodConfig, register_method
from trlx.pipeline.offline_pipeline import (
DialogStore,
PromptPipeline,
tokenize_dialogue,
)
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer

Expand Down Expand Up @@ -36,7 +41,13 @@ def get_arch(self, config):
return AutoModelForCausalLM.from_pretrained(config.model.model_path)

def loss(self, batch):
loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=batch.input_ids).loss
if "labels" in batch:
labels = batch.labels.clone()
else:
labels = batch.input_ids.clone()
labels[~batch.attention_mask.bool()] = -100

loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss
stats = {"loss": loss}

return loss, stats
Expand All @@ -55,3 +66,10 @@ def prepare_learning(self):
self.n_updates_per_batch = 1
self.total_steps = self.config.train.epochs * len(train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)

def make_experience(self, samples, seq_length):
if isinstance(samples[0], str):
self.store = PromptPipeline(samples, seq_length, self.tokenizer)
else:
dialogs = [tokenize_dialogue(d, self.tokenizer, seq_length) for d in samples]
self.store = DialogStore(dialogs, self.tokenizer)
3 changes: 1 addition & 2 deletions trlx/trlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def train( # noqa: C901
if rewards is not None:
trainer.make_experience(samples, rewards, config.train.seq_length)
else:
trainer.store = get_pipeline(config.train.pipeline)(samples, max_prompt_length, trainer.tokenizer)

trainer.make_experience(samples, config.train.seq_length)
else:
raise ValueError("Either `samples` or `reward_fn` should be given for training")

Expand Down

0 comments on commit f63f4bf

Please sign in to comment.