Update ILQL details (CarperAI#156)
* feat(offline_orchestrator): delegate splitting & enable interleaving

* feat(ilql_models): weight and tie awac to actions & add tensor stats

* feat(ilql_randomwalks): change starting state from bos to first node

as is already the case in ppo_randomwalks

* feat(ilql_sentiments): keep first sentences only

* fix(offline_orchestrator): force returns dtype

* chore(configs): plug with beta

* refactor(*): satisfy isort

* fix(ilql_models): proper masking

* feat(offline_orchestrator): add variable truncation

* chore(examples): revert to old behaviour

* refactor(*): style check

* feat(offline_orchestrator): convert samples from tuples to lists

* refactor(offline_orchestrator): printing of dataset statistics

* refactor(offline_orchestrator): simplify if statement

* chore(trlx): document `dataset` argument

* refactor(offline_orchestrator): remove duplication

* chore(trlx): update docs
maxreciprocate authored Jan 11, 2023
1 parent 24b041a commit 0cb8438
Showing 8 changed files with 165 additions and 78 deletions.
2 changes: 1 addition & 1 deletion configs/ilql_config.yml
@@ -10,7 +10,6 @@ train:
pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
trainer: "AccelerateILQLTrainer"

seed: 1000

model:
@@ -39,6 +38,7 @@ method:
cql_scale: 0.1
awac_scale: 1
alpha: 0.001
beta: 0
steps_for_target_q_sync: 5
two_qs: true
gen_kwargs:
1 change: 0 additions & 1 deletion examples/ilql_sentiments.py
@@ -36,7 +36,6 @@ def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
imdb = load_dataset("imdb", split="train+test")

trlx.train(
"gpt2",
dataset=(imdb["text"], imdb["label"]),
eval_prompts=["I don't know much about Hungarian underground"] * 64,
metric_fn=metric_fn,
1 change: 1 addition & 0 deletions examples/randomwalks/configs/ilql_randomwalks.yml
@@ -39,6 +39,7 @@ method:
cql_scale: 0.1
awac_scale: 1
alpha: 0.1
beta: 0
steps_for_target_q_sync: 5
two_qs: true
gen_kwargs:
2 changes: 2 additions & 0 deletions examples/randomwalks/ilql_randomwalks.py
@@ -16,6 +16,8 @@ def main(hparams={}):

metric_fn, eval_prompts, walks, _ = generate_random_walks(seed=config.train.seed)
rewards = metric_fn(walks)["optimality"]
# split each random walk into (starting state, rest of the walk)
walks = [[walk[:1], walk[1:]] for walk in walks]

trlx.train(
GPT2Config(n_layer=6, n_embd=144, vocab_size=23),
157 changes: 107 additions & 50 deletions trlx/orchestrator/offline_orchestrator.py
@@ -1,7 +1,50 @@
from typing import List, Union

import numpy as np
import torch

from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import ILQLRolloutStorage
from trlx.utils import print_rank_0


def tokenize_dialogue( # noqa: C901
dialogue: Union[str, List[str]], tokenizer, max_length=2048, truncation_side="left"
) -> List[int]:
"""
Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...)
"""
if isinstance(dialogue, str):
dialogue = [tokenizer.bos_token, dialogue]
elif isinstance(dialogue, tuple):
dialogue = list(dialogue)
dialogue[-1] += tokenizer.eos_token

out = []
ctx_length = max_length
if truncation_side == "left":
for phrase in reversed(dialogue):
tokens = tokenizer(phrase).input_ids[-ctx_length:]
ctx_length -= len(tokens)
out.insert(0, tokens)
if ctx_length == 0:
break

# in case of odd number of phrases (possibly due to truncation)
# since the first phrase always has to be a prompt, force it to be <bos>
if len(out) % 2 == 1:
if sum(map(len, out)) == max_length:
out[0].pop(0)
out.insert(0, [tokenizer.bos_token_id])

elif truncation_side == "right":
for phrase in dialogue:
tokens = tokenizer(phrase).input_ids[:ctx_length]
ctx_length -= len(tokens)
out.append(tokens)
if ctx_length == 0:
break
return out
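
For reference, a minimal sketch of how the new `tokenize_dialogue` helper behaves; the GPT-2 tokenizer and the sample strings below are assumptions for illustration, not part of this commit:

```python
# Illustrative sketch only; assumes a stock GPT-2 tokenizer.
from transformers import AutoTokenizer

from trlx.orchestrator.offline_orchestrator import tokenize_dialogue

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A plain string is shorthand for (bos_token, string); eos_token is appended to the last phrase.
phrases = tokenize_dialogue("the cake is a lie", tokenizer)
assert len(phrases) == 2  # [bos token ids, output token ids ending in eos]

# Interleaved samples alternate prompts and outputs; each phrase keeps its own token list,
# and the whole sample is truncated to `max_length` tokens (from the left by default).
phrases = tokenize_dialogue(["Q: 2+2=", " 4", " Q: 3+3=", " 6"], tokenizer)
assert len(phrases) == 4
```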


@register_orchestrator
@@ -10,67 +53,81 @@ class OfflineOrchestrator(Orchestrator):
Orchestrator that creates a static dataset for offline training
"""

def __init__(self, trainer, split_token=None):
def __init__(self, trainer):
self.trainer = trainer
self.split_token = split_token

def make_experience(self, samples, rewards):
def make_experience(self, samples, rewards, max_length=2048):
"""
Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the model
Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer
"""
if self.trainer.tokenizer:
input_ids = self.trainer.tokenize(samples)
else:
input_ids = samples

input_ids = list(map(torch.as_tensor, input_ids))

states_ixs, actions_ixs = [], []
dones = []
for s, s_tok in zip(samples, input_ids):
# split samples on (prompts, continuations) on a given substring `split_token`
if self.split_token:
prompt_str_len = s.index(self.split_token) + len(self.split_token)
prompt_tok_len = len(
self.trainer.tokenizer(s[:prompt_str_len]).input_ids
samples = [
tokenize_dialogue(
s, self.trainer.tokenizer, max_length, truncation_side="right"
)
# else assume that the prompt is a bos token
else:
prompt_tok_len = 1

# indices of continuations, to mask prompts in loss computation
a_ixs = torch.arange(prompt_tok_len - 1, len(s_tok) - 1)
# same continuations but for value computation, with the premise to eventually support interleaved dialog
s_ixs = torch.arange(prompt_tok_len - 1, len(s_tok))
# mask continuation's ending
terminals = torch.ones_like(s_ixs)
terminals[-1] = 0

actions_ixs.append(a_ixs)
states_ixs.append(s_ixs)
dones.append(terminals)
for s in samples
]

all_input_ids = []
all_actions_ixs = []
all_states_ixs = []
all_dones = []
for sample in samples:
length = 0
all_input_ids.append(torch.tensor(sum(sample, [])))
isoutput = False
actions_ixs = []
for phrase in sample:
if isoutput:
actions_ixs.append(
torch.arange(length - 1, length + len(phrase) - 1)
)

length += len(phrase)
isoutput = not isoutput

states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1)))
all_dones.append(torch.tensor([1] * (len(states_ixs) - 1) + [0], dtype=int))
all_actions_ixs.append(torch.hstack(actions_ixs))
all_states_ixs.append(states_ixs)

if self.trainer.tokenizer:
prompt = self.trainer.tokenizer.decode(input_ids[0][: states_ixs[0][1]])
response = self.trainer.tokenizer.decode(input_ids[0][states_ixs[0][1] :])
print("[Sample example]")
print("Prompt: ", prompt)
print("Response: ", response)

print(f"[Mean reward] {torch.Tensor(rewards).mean():.2f}")
print(
f"[Mean sample length] {torch.mean(torch.Tensor(list(map(len, input_ids)))):.2f}"
)
prompt = self.trainer.tokenizer.decode(
all_input_ids[0][: all_states_ixs[0][1]]
)
response = self.trainer.tokenizer.decode(
all_input_ids[0][all_states_ixs[0][1] :]
)
print_rank_0("[Sample example]")
print_rank_0("Prompt: ", prompt)
print_rank_0("Response: ", response)
print_rank_0("Reward: ", rewards[0])

returns = torch.as_tensor(rewards, dtype=torch.float)
returns = (returns - returns.mean()) / (returns.std() + 1e-30)
sample_lengths = np.array(list(map(len, all_input_ids)))
output_lengths = np.array(list(map(len, all_actions_ixs)))
prompt_lengths = sample_lengths - output_lengths
returns = torch.tensor(rewards, dtype=float)

def string_stats(name: str, xs: np.array):
return f"[Mean {name}] {xs.mean():.2f} ∈ [{min(xs)}, {max(xs)}]"

rewards = [torch.zeros(x.shape[0]) for x in actions_ixs]
for rs, G in zip(rewards, returns):
rs[-1] = G
print_rank_0(string_stats("prompt length", prompt_lengths))
print_rank_0(string_stats("output length", output_lengths))
print_rank_0(string_stats("sample length", sample_lengths))
print_rank_0(string_stats("return", returns))

returns = (returns - returns.mean()) / (returns.std() + 1e-30)
rewards = [torch.zeros(len(x)) for x in all_actions_ixs]
for rs, ret in zip(rewards, returns):
rs[-1] = ret

attention_mask = [torch.ones(x.shape[0], dtype=int) for x in input_ids]
attention_mask = [torch.ones(len(x), dtype=int) for x in all_input_ids]

self.trainer.store = ILQLRolloutStorage(
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
all_input_ids,
attention_mask,
rewards,
all_states_ixs,
all_actions_ixs,
all_dones,
)
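
A small worked example of the index bookkeeping above; the toy token ids are made up, but the loop mirrors the code in this hunk:

```python
# One sample with a 3-token prompt and a 2-token output: input_ids = [p0 p1 p2 o0 o1].
import torch

sample = [[10, 11, 12], [13, 14]]  # [prompt tokens, output tokens], ids are arbitrary
length = 0
actions_ixs = []
isoutput = False
for phrase in sample:
    if isoutput:
        # positions whose logits predict the output tokens (shifted left by one)
        actions_ixs.append(torch.arange(length - 1, length + len(phrase) - 1))
    length += len(phrase)
    isoutput = not isoutput

states_ixs = torch.hstack((*actions_ixs, torch.tensor(length - 1)))
dones = torch.tensor([1] * (len(states_ixs) - 1) + [0])

print(actions_ixs)  # [tensor([2, 3])]  -> Q/advantage indices over the output
print(states_ixs)   # tensor([2, 3, 4]) -> same indices plus the final state
print(dones)        # tensor([1, 1, 0]) -> only the last state is terminal
```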
58 changes: 39 additions & 19 deletions trlx/trainer/nn/ilql_models.py
@@ -16,7 +16,9 @@
from trlx.data.ilql_types import ILQLBatch
from trlx.data.method_configs import MethodConfig, register_method
from trlx.utils.modeling import (
flatten_dict,
freeze_bottom_causal_layers,
get_tensor_stats,
hf_get_causal_base_model,
hf_get_hidden_size,
hf_get_lm_head,
@@ -39,6 +41,7 @@ class ILQLConfig(MethodConfig):
cql_scale: float
awac_scale: float
alpha: float
beta: float
steps_for_target_q_sync: float
two_qs: bool
gen_kwargs: dict
@@ -48,18 +51,20 @@ def heads(self, hidden_size: int, vocab_size: int):

def loss(self, outputs, labels: ILQLBatch):
logits, (qs, target_qs, vs) = outputs
terminal_mask = labels.dones[:, :-1]
n_nonterminal = max(1, terminal_mask.sum())

actions = (
labels.input_ids[:, 1:]
.gather(dim=1, index=labels.actions_ixs)
.unsqueeze(-1)
)
bsize, ntokens, dsize = logits.shape
nactions = actions.shape[1]
bsize, _, dsize = logits.shape

Q = [q.gather(-1, actions).squeeze(-1) for q in qs]
targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs]
targetQ = reduce(torch.minimum, targetQs)
terminal_mask = labels.dones[:, :-1]
n_nonterminal = max(1, terminal_mask.sum())

# values of current states
V = vs[:, :-1].squeeze()
@@ -81,8 +86,6 @@ def loss(self, outputs, labels: ILQLBatch):
* terminal_mask
).sum() / n_nonterminal

nactions = qs[0].shape[1]

def cql_loss(q):
loss = F.cross_entropy(
q.reshape(-1, dsize), actions.reshape(-1), reduction="none"
@@ -93,24 +96,41 @@ def cql_loss(q):

loss_cql = sum(cql_loss(q) for q in qs)

loss_awac = (
F.cross_entropy(
logits[:, :-1, :].reshape(-1, dsize),
labels.input_ids[:, 1:].reshape(-1),
reduction="none",
).reshape(bsize, ntokens - 1)
* labels.attention_mask[:, 1:]
).sum() / labels.attention_mask[:, 1:].sum()
# select logits from continuations
action_logits = logits.gather(
dim=1, index=labels.actions_ixs.unsqueeze(-1).repeat(1, 1, dsize)
)
cross_entropy = F.cross_entropy(
action_logits.reshape(-1, dsize),
actions.reshape(-1),
reduction="none",
).reshape(bsize, nactions)

with torch.no_grad():
awac_weight = torch.exp(self.beta * (targetQ - V))

loss_awac = (
torch.sum(cross_entropy * awac_weight * terminal_mask) / n_nonterminal
)
loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac

stats = {
f"losses/{k}": v
for k, v in locals().items()
if k in ["loss", "loss_v", "loss_q", "loss_cql", "loss_awac"]
}
stats = dict(
losses=dict(
loss=loss.item(),
loss_q=loss_q.item(),
loss_v=loss_v.item(),
loss_cql=loss_cql.item(),
loss_awac=loss_awac.item(),
),
values=get_tensor_stats(V, terminal_mask, n_nonterminal),
qvalues={
str(ix): get_tensor_stats(Q[ix], terminal_mask, n_nonterminal)
for ix in range(len(Q))
},
awac_weight=get_tensor_stats(awac_weight, terminal_mask, n_nonterminal),
)

return loss, stats
return loss, flatten_dict(stats)


class ILQLHeads(nn.Module):
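As a reading aid for the reworked AWAC term: the new `beta` scales an exponentiated advantage weight applied to the per-token cross-entropy over output positions. The formula below is inferred from the code in this diff rather than stated in it:

$$
\mathcal{L}_{\mathrm{AWAC}} \;=\; \frac{1}{N_{\mathrm{nonterminal}}} \sum_{t} m_t \, \exp\!\big(\beta \,(Q_{\mathrm{target}}(s_t, a_t) - V(s_t))\big)\, \mathrm{CE}\big(\pi_\theta(\cdot \mid s_t),\, a_t\big)
$$

where $m_t$ is the terminal mask over output positions and the exponential weight is computed under `torch.no_grad()`. With the `beta: 0` added to the configs, the weight is identically 1 and the term reduces to plain cross-entropy over non-terminal output tokens.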
14 changes: 7 additions & 7 deletions trlx/trlx.py
@@ -14,7 +14,6 @@ def train(
eval_prompts: Optional[List[str]] = None,
metric_fn: Optional[Callable] = None,
config: Optional[TRLConfig] = None,
split_token: Optional[str] = None,
logit_mask: Optional[List[List[bool]]] = None,
):
"""
@@ -24,12 +23,15 @@
Args:
model_path (Optional[str]): Path to either huggingface checkpoint or a local directory
reward_fn (List[str] -> List[float]): Function to rate batches of generated samples
dataset (List[str], List[float]): Lists of samples and rewards
dataset (List[Union[str, List[str]]], List[float]):
Lists of samples and rewards for offline training. Samples consist of a variable number
of prompts (questions, environment states etc.) and outputs which are meant to be optimized.
Following form is expected (prompt_0: str, output_0: str, prompt_1: str, output_1: str ...).
Giving a single string `s` for the sample is a shorthand for (`tokenizer.bos_token`, `s`)
prompts (List[str]): Prompts to sample off from during online training
eval_prompts (List[str]): Prompts to periodically validate training on
metric_fn (Optional[Callable[List[str], List[float]]]): Function to compute statistics on validation samples
config (Optional[TRLConfig]): TRL configuration object to override default settings
split_token (Optional[str]): Split samples in the dataset on prompts and continuations
logit_mask (Optional[List]): Bigram masking matrix
"""
if reward_fn is not None:
@@ -95,10 +97,8 @@
eval_prompts, max_prompt_length, trainer.tokenizer
)

orch = get_orchestrator(config.train.orchestrator)(
trainer, split_token=split_token
)
orch.make_experience(samples, rewards)
orch = get_orchestrator(config.train.orchestrator)(trainer)
orch.make_experience(samples, rewards, config.train.seq_length)
trainer.add_eval_pipeline(eval_pipeline)

else:
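Given the `dataset` documentation above, a minimal offline-training sketch; the sample texts, rewards, and eval prompts are made up, and the config path is assumed to be the one shipped in this repository:

```python
# Minimal sketch of the documented offline `dataset` format (illustrative data only).
import trlx
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ilql_config.yml")

samples = [
    "a single string, shorthand for (bos_token, string)",
    ["Q: Is this movie good?", " Yes, absolutely."],              # (prompt_0, output_0)
    ["Q: Rate it.", " 5/5.", " Q: Would you rewatch?", " Yes."],  # interleaved dialogue
]
rewards = [0.1, 0.9, 0.8]

trlx.train(
    "gpt2",
    dataset=(samples, rewards),
    eval_prompts=["Q: Is this movie good?"] * 8,
    config=config,
)
```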
8 changes: 8 additions & 0 deletions trlx/utils/__init__.py
@@ -13,6 +13,14 @@
from torchtyping import TensorType


def print_rank_0(*message):
"""
Print only once from the main rank
"""
if os.environ.get("RANK", "0") == "0":
print(*message)


def set_seed(seed: int):
"""
Sets seeds across package dependencies for reproducibility.