[feat] Add LLaMa Model support for PPO (CarperAI#375)
* init llama support

* add sentiments example

* update sentiment config

* fix style

* add reference

---------

Co-authored-by: Duy Phung <duyphung@cw-prod-login-0.cw-prod-login.tenant-stabilitytraining-704a100.svc.tenant.chi.local>
Co-authored-by: Duy Phung <duyphung@cw-prod-a100-cu117-49.cw-prod-compute.tenant-stabilitytraining-704a100.svc.tenant.chi.local>
3 people authored Mar 26, 2023
1 parent b0c4ea9 commit 086a905
Showing 4 changed files with 249 additions and 0 deletions.
111 changes: 111 additions & 0 deletions examples/ppo_sentiments_llama.py
@@ -0,0 +1,111 @@
# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import (
    ModelConfig,
    OptimizerConfig,
    PPOConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainConfig,
    TRLConfig,
)


def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def llama_config():
    return TRLConfig(
        train=TrainConfig(
            seq_length=1024,
            epochs=100,
            total_steps=10000,
            batch_size=32,
            checkpoint_interval=10000,
            eval_interval=100,
            pipeline="PromptPipeline",
            trainer="AcceleratePPOTrainer",
            save_best=False,
        ),
        model=ModelConfig(model_path="decapoda-research/llama-7b-hf", num_layers_unfrozen=2),
        tokenizer=TokenizerConfig(tokenizer_path="decapoda-research/llama-7b-hf", truncation_side="right"),
        optimizer=OptimizerConfig(
            name="adamw", kwargs=dict(lr=1.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
        ),
        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-5)),
        method=PPOConfig(
            name="PPOConfig",
            num_rollouts=128,
            chunk_size=128,
            ppo_epochs=4,
            init_kl_coef=0.05,
            target=6,
            horizon=10000,
            gamma=1,
            lam=0.95,
            cliprange=0.2,
            cliprange_value=0.2,
            vf_coef=1,
            scale_reward="ignored",
            ref_mean=None,
            ref_std=None,
            cliprange_reward=10,
            gen_kwargs=dict(
                max_new_tokens=40,
                top_k=0,
                top_p=1.0,
                do_sample=True,
            ),
        ),
    )


def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(llama_config().to_dict(), hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments

    # Take a few words off of movie reviews as prompts
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 64,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
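
The reward path in the new example can be traced without loading any model. A minimal sketch, assuming the output shape of transformers' sentiment pipeline with top_k=2; the scores below are made up for illustration, not captured output:

# Illustrative sketch (not part of the commit): how get_positive_score turns each
# per-sample pipeline output into the scalar reward consumed by PPO.
from typing import List


def get_positive_score(scores):
    "Extract value associated with a positive sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


# With top_k=2, the sentiment pipeline returns one list of {"label", "score"}
# dicts per input sample, for example:
fake_pipeline_output = [
    [{"label": "POSITIVE", "score": 0.93}, {"label": "NEGATIVE", "score": 0.07}],
    [{"label": "NEGATIVE", "score": 0.81}, {"label": "POSITIVE", "score": 0.19}],
]

rewards: List[float] = [get_positive_score(scores) for scores in fake_pipeline_output]
print(rewards)  # [0.93, 0.19]
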
135 changes: 135 additions & 0 deletions trlx/models/modeling_ppo.py
@@ -796,6 +796,138 @@ def forward( # noqa: max-complexity
        )


class LlamaModelBranch(ModelBranch):
    def _make_causal_mask(self, input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
        """
        Make causal mask used for bi-directional self-attention.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
        mask_cond = torch.arange(mask.size(-1))
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_states, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = self._make_causal_mask(
                input_shape, hidden_states.dtype, past_key_values_length=past_key_values_length
            ).to(hidden_states.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = self._expand_mask(attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]).to(
                hidden_states.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask
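
    # Illustrative note (not part of the committed file): for batch_size=2 and
    # seq_length=3 with no cache, _make_causal_mask returns a [2, 1, 3, 3] tensor
    # holding dtype-min strictly above the diagonal (future positions) and 0 elsewhere;
    # _expand_mask lifts a [2, 3] padding mask to the same [2, 1, 3, 3] shape, so their
    # sum lets a query attend to a key position only if it is both non-future and
    # non-padded.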

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_shape: Tuple[int, int],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Reference:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, seq_length = hidden_states.shape[:2]
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.decoder_blocks):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.final_norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        lm_logits = self.lm_head(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            outputs = (lm_logits,) + (None,) + (None,)
            return outputs

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


# Seq2Seq architectures


@@ -1105,13 +1237,16 @@ def hf_get_branch_class(
    ]
    opt_branch_supported_archs = ["OPTForCausalLM"]
    bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
    llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
    arch = config.architectures[0]
    if arch in gpt_branch_supported_archs:
        return GPTModelBranch
    elif arch in opt_branch_supported_archs:
        return OPTModelBranch
    elif arch in bloom_branch_supported_archs:
        return BloomModelBranch
    elif arch in llama_branch_supported_archs:
        return LlamaModelBranch
    else:
        all_supported_archs = sum(
            [
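
The new dispatch entry can be exercised in isolation. A minimal sketch, assuming trlx with this commit applied and a transformers release that ships LlamaConfig; the architectures field is set by hand because hf_get_branch_class inspects only that attribute:

# Illustrative sketch (not part of the commit): a LLaMA architecture string now
# resolves to the new branch class.
from transformers import LlamaConfig

from trlx.models.modeling_ppo import LlamaModelBranch, hf_get_branch_class

config = LlamaConfig()  # default hyperparameters are irrelevant to the dispatch
config.architectures = ["LlamaForCausalLM"]  # one of the two strings accepted above

assert hf_get_branch_class(config) is LlamaModelBranch

The selected branch holds frozen copies of the top num_layers_unfrozen decoder blocks and produces the reference model's logits, which is why its forward above consumes hidden_states rather than token ids.
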
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_base_trainer.py
@@ -65,6 +65,7 @@ def __init__(self, config, **kwargs): # noqa: C901
        self.tokenizer.sep_token = "<sep>"
        if config.model.model_arch_type != "seq2seq":
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0]
        if not isinstance(config.model.model_path, str):
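
The single added line pins the pad token at the id level as well as the string level. A minimal sketch of the intended end state, using GPT-2's tokenizer purely as a small stand-in since it also ships without a pad token (the commit itself targets decoder-only models such as LLaMA):

# Illustrative sketch (not part of the commit): fall back to EOS for padding when a
# decoder-only model's tokenizer defines no pad token of its own.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
assert tokenizer.pad_token is None  # no pad token out of the box

tokenizer.pad_token = tokenizer.eos_token
# Mirrors the commit: make the id explicit too, rather than relying on the
# string assignment alone to refresh pad_token_id.
tokenizer.pad_token_id = tokenizer.eos_token_id

assert tokenizer.pad_token == "<|endoftext|>"
assert tokenizer.pad_token_id == tokenizer.eos_token_id == 50256
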
2 changes: 2 additions & 0 deletions trlx/utils/modeling.py
@@ -125,6 +125,7 @@ def hf_get_decoder_final_norm(model: nn.Module) -> float:
    norm_attrs = (
        "transformer.ln_f",
        "model.decoder.final_layer_norm",
        "model.norm",
        "decoder.final_layer_norm",
        "gpt_neox.final_layer_norm",
    )
@@ -142,6 +143,7 @@ def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]:
    hidden_layers_attrs = (
        "h",
        "layers",
        "model.layers",
        "decoder.layers",
        "transformer.h",
        "model.decoder.layers",
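
Both tuples are candidate attribute paths: the surrounding helpers try them in order and return the first one that resolves on the model, which is how "model.norm" and "model.layers" reach the norm and decoder blocks that LlamaForCausalLM nests under its inner LlamaModel. A minimal re-implementation of that lookup; the helper names below (rgetattr, findattr) are illustrative assumptions, not necessarily what trlx/utils/modeling.py uses:

# Illustrative sketch (not part of the commit): resolving a dotted candidate such as
# "model.layers" against an object that, like LlamaForCausalLM, keeps its decoder
# blocks one attribute level down.
from functools import reduce


def rgetattr(obj, path):
    """getattr that follows dots: rgetattr(m, "model.layers") == m.model.layers."""
    return reduce(getattr, path.split("."), obj)


def findattr(obj, attrs):
    """Return the value of the first candidate path that resolves on obj."""
    for attr in attrs:
        try:
            return rgetattr(obj, attr)
        except AttributeError:
            continue
    raise ValueError(f"Could not find any of {attrs} on {type(obj).__name__}")


class _InnerModel:  # stands in for transformers' LlamaModel
    layers = ["block0", "block1"]


class _FakeLlamaForCausalLM:
    model = _InnerModel()


print(findattr(_FakeLlamaForCausalLM(), ("h", "layers", "model.layers")))
# -> ['block0', 'block1']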
