Refactored policy classes and added data classes for outputs
rajcscw committed Oct 19, 2022
1 parent ac9eecf commit 6f96c4a
Showing 13 changed files with 1,169 additions and 175 deletions.
13 changes: 6 additions & 7 deletions rl4lms/algorithms/nlpo/nlpo.py
@@ -21,6 +21,7 @@
from rl4lms.algorithms.common.maskable.utils import get_action_masks, is_masking_supported
from rl4lms.algorithms.nlpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.policy.base_policy import EvaluateActionsOutput


class NLPO(OnPolicyAlgorithm):
@@ -340,11 +341,9 @@ def train(self) -> None:
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()

values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
action_masks=rollout_data.action_masks,
)
evaluation_output: EvaluateActionsOutput = self.policy.evaluate_actions(
rollout_data.observations, actions, action_masks=rollout_data.action_masks)
values, log_prob, entropy = evaluation_output.values, evaluation_output.log_prob, evaluation_output.entropy

values = values.flatten()
# Normalize advantage
@@ -357,9 +356,9 @@
ratio = th.exp(log_prob - rollout_data.old_log_prob)
if batch_ix == 0 and epoch == 0:
assert th.allclose(th.mean(ratio), th.tensor(
1.0), atol=1e-3), f"Ratio is {th.mean(ratio)}"
1.0), atol=1e-3), "Cannot reconstruct probability distribution. Please check your policy network implementation"

assert th.allclose(values, rollout_data.old_values, atol=1e-3)
assert th.allclose(values, rollout_data.old_values, atol=1e-3), "Cannot reconstruct values. Please check your value network implementation"

# clipped surrogate loss
policy_loss_1 = advantages * ratio
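Note: the NLPO training loop above now unpacks values, log_prob, and entropy from a single EvaluateActionsOutput object instead of a tuple. The dataclass itself is defined in rl4lms/envs/text_generation/policy/base_policy.py, which is not part of this diff; judging from the attributes accessed above, a minimal sketch might look like the following (an assumption, not the actual definition):

```python
# Hypothetical sketch of the EvaluateActionsOutput dataclass imported above.
# The real definition lives in rl4lms/envs/text_generation/policy/base_policy.py
# and may carry additional fields.
from dataclasses import dataclass

import torch


@dataclass
class EvaluateActionsOutput:
    values: torch.Tensor    # value estimates for the given observations
    log_prob: torch.Tensor  # log-probabilities of the evaluated actions
    entropy: torch.Tensor   # entropy of the action distribution
```

Returning a named output object rather than a positional tuple lets the policy grow new fields later without breaking every call site.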
8 changes: 5 additions & 3 deletions rl4lms/algorithms/ppo/ppo.py
@@ -11,6 +11,7 @@
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.policy.base_policy import EvaluateActionsOutput


class PPO(OnPolicyAlgorithm):
@@ -211,8 +212,9 @@ def train(self) -> None:
if self.use_sde:
self.policy.reset_noise(self.batch_size)

values, log_prob, entropy = self.policy.evaluate_actions(
evaluation_output: EvaluateActionsOutput = self.policy.evaluate_actions(
rollout_data.observations, actions)
values, log_prob, entropy = evaluation_output.values, evaluation_output.log_prob, evaluation_output.entropy
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
@@ -224,9 +226,9 @@
ratio = th.exp(log_prob - rollout_data.old_log_prob)
if batch_ix == 0 and epoch == 0:
assert th.allclose(th.mean(ratio), th.tensor(
1.0), atol=1e-3), f"Ratio is {th.mean(ratio)}"
1.0), atol=1e-3), "Cannot reconstruct probability distribution. Please check your policy network implementation"

assert th.allclose(values, rollout_data.old_values, atol=1e-3)
assert th.allclose(values, rollout_data.old_values, atol=1e-3), "Cannot reconstruct values. Please check your value network implementation"

# clipped surrogate loss
policy_loss_1 = advantages * ratio
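PPO gains the same change and the same first-minibatch sanity checks as NLPO: before any gradient step, re-evaluating the rollout actions must reproduce the stored log-probabilities (so the importance ratio is close to 1) as well as the stored value estimates. A standalone illustration of the check, using plain PyTorch rather than the RL4LMs classes:

```python
# Standalone illustration of the sanity check added in train() above.
# On the first minibatch of the first epoch no update has happened yet,
# so recomputed log-probs must match those stored during rollout.
import torch as th

old_log_prob = th.tensor([-1.20, -0.45, -2.31])  # stored during rollout collection
new_log_prob = th.tensor([-1.20, -0.45, -2.31])  # recomputed by evaluate_actions

ratio = th.exp(new_log_prob - old_log_prob)
assert th.allclose(th.mean(ratio), th.tensor(1.0), atol=1e-3), \
    "Cannot reconstruct probability distribution. Please check your policy network implementation"
```

If this assertion fires in practice, the training-time forward pass does not reproduce the distribution used during rollout generation, which usually points to a mismatch between the two code paths rather than a learning problem.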
276 changes: 178 additions & 98 deletions rl4lms/envs/text_generation/alg_wrappers.py

Large diffs are not rendered by default.

86 changes: 50 additions & 36 deletions rl4lms/envs/text_generation/evaluation_utils.py
@@ -1,34 +1,36 @@
from typing import Any, Dict, List

from stable_baselines3.common.policies import BasePolicy
from tqdm import tqdm
from transformers import AutoTokenizer

from rl4lms.data_pools.custom_text_generation_pools import Sample
from rl4lms.envs.text_generation.metric import BaseMetric
from rl4lms.envs.text_generation.logging_utils import Tracker
from typing import List, Dict, Any
from tqdm import tqdm
from rl4lms.envs.text_generation.metric import BaseMetric


def get_batch(samples: List[Sample], batch_size: int):
current_ix = 0
n_samples = len(samples)
while current_ix < n_samples:
current_batch = samples[current_ix:current_ix+batch_size]
current_batch = samples[current_ix : current_ix + batch_size]
yield current_batch
current_ix += batch_size


def evaluate_on_samples(policy: BasePolicy,
tokenizer: AutoTokenizer,
samples: List[Sample],
batch_size: int,
max_prompt_length: int,
metrics: List[BaseMetric],
epoch: int,
split_name: str,
tracker: Tracker = None,
dt_control_token: str = '',
gen_kwargs: Dict[str, Any] = None,
):
def evaluate_on_samples(
policy: BasePolicy,
tokenizer: AutoTokenizer,
samples: List[Sample],
batch_size: int,
max_prompt_length: int,
metrics: List[BaseMetric],
epoch: int,
split_name: str,
tracker: Tracker = None,
dt_control_token: str = "",
gen_kwargs: Dict[str, Any] = None,
):
# generate text by batch
all_generated_texts = []
all_ref_texts = []
@@ -37,8 +39,8 @@ def evaluate_on_samples(policy: BasePolicy,
n_samples = len(samples)
for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
batch_generated_texts = generate_text(
policy, tokenizer, batch, max_prompt_length,
dt_control_token, gen_kwargs)
policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
)
batch_ref_texts = [sample.references for sample in batch]
batch_prompt_texts = [sample.prompt_or_input_text for sample in batch]
batch_meta_infos = [sample.meta_data for sample in batch]
@@ -53,8 +55,13 @@ def evaluate_on_samples(policy: BasePolicy,
if metrics is not None:
for metric in metrics:
metric_dict = metric.compute(
all_prompt_texts, all_generated_texts, all_ref_texts,
all_meta_infos, policy.get_language_model(), split_name)
all_prompt_texts,
all_generated_texts,
all_ref_texts,
all_meta_infos,
policy.get_language_model(),
split_name,
)

for metric_key, (sample_scores, corpus_score) in metric_dict.items():
if sample_scores is None:
@@ -64,14 +71,20 @@ def evaluate_on_samples(policy: BasePolicy,

# aggregate sample metric scores
sample_predictions_dict = []
for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate(zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts)):
for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate(
zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts)
):
sample_prediction = {
"split_name": split_name,
"sample_id": sample.id,
"prompt_text": prompt_text,
"generated_text": generated_text,
"ref_text": "".join([f"<START-{ref_ix+1}>"+ref_text+f"<END-{ref_ix+1}>"
for ref_ix, ref_text in enumerate(ref_texts)]),
"ref_text": "".join(
[
f"<START-{ref_ix+1}>" + ref_text + f"<END-{ref_ix+1}>"
for ref_ix, ref_text in enumerate(ref_texts)
]
),
}
for metric_key, sample_scores in sample_scores_by_metric.items():
sample_prediction[metric_key] = sample_scores[ix]
@@ -84,17 +97,18 @@ def evaluate_on_samples(policy: BasePolicy,
tracker.log_metrics(epoch, split_name, corpus_level_metrics)


def generate_text(policy: BasePolicy,
tokenizer: AutoTokenizer,
samples: List[Sample],
max_prompt_length: int,
dt_control_token: str,
gen_kwargs: Dict[str, Any]
):
prompt_texts = [dt_control_token +
sample.prompt_or_input_text for sample in samples]
generated_texts = policy.generate(tokenizer,
prompt_texts,
max_prompt_length,
gen_kwargs=gen_kwargs)["gen_texts"]
def generate_text(
policy: BasePolicy,
tokenizer: AutoTokenizer,
samples: List[Sample],
max_prompt_length: int,
dt_control_token: str,
gen_kwargs: Dict[str, Any],
):
prompt_texts = [
dt_control_token + sample.prompt_or_input_text for sample in samples
]
generated_texts = policy.generate(
tokenizer, prompt_texts, max_prompt_length, gen_kwargs=gen_kwargs
).gen_texts
return generated_texts
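Besides black-style reformatting of get_batch, evaluate_on_samples, and generate_text, the behavioural change in this file is at the end of generate_text: policy.generate no longer returns a dict indexed with "gen_texts" but an output object exposing a gen_texts attribute, consistent with the new output dataclasses. That class is not shown in this diff; a plausible sketch (the field name gen_texts is taken from the access above, everything else is hypothetical):

```python
# Hypothetical sketch of a generation-output dataclass consistent with the
# `.gen_texts` access above; the actual RL4LMs class may define more fields.
from dataclasses import dataclass, field
from typing import List


@dataclass
class GenerationOutputs:
    gen_texts: List[str] = field(default_factory=list)  # decoded text, one entry per prompt
```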
5 changes: 3 additions & 2 deletions rl4lms/envs/text_generation/policy.py
@@ -1031,9 +1031,10 @@ def forward_policy(self, obs: TensorDict,
# get log probs
dist = self._action_dist.proba_distribution(
action_logits=next_token_logits)
raw_log_probs = dist.log_prob(actions)
if action_masks is not None:
dist.apply_masking(action_masks)
log_prob = dist.log_prob(actions)
log_probs = dist.log_prob(actions)
entropy = dist.entropy()

# update the model kwargs for further generation
@@ -1042,7 +1043,7 @@
)
model_kwargs["decoder_attention_mask"] = torch.cat(
(decoder_attn_mask, torch.ones(batch_size, 1).to(decoder_attn_mask.device)), dim=-1)
return actions, log_prob, entropy, outputs, action_masks, model_kwargs
return actions, raw_log_probs, log_probs, entropy, outputs, action_masks, model_kwargs

def forward_value(self, obs: TensorDict,
model_kwargs: Optional[Dict[str, torch.tensor]] = None):
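In forward_policy, the log-probabilities are now taken twice: raw_log_probs from the distribution before apply_masking and log_probs afterwards, and both are returned. Masking removes invalid tokens and renormalises over the remaining ones, so for an allowed action the masked log-probability is at least as large as the raw one. A plain-PyTorch illustration of that relationship (not the RL4LMs/sb3-contrib implementation):

```python
# Plain-PyTorch illustration of raw vs. masked log-probabilities.
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])      # next-token logits
mask = torch.tensor([True, True, False, False])   # allowed tokens (e.g. NLPO's restricted action set)
action = torch.tensor(1)                          # sampled token id (an allowed one)

raw_log_probs = F.log_softmax(logits, dim=-1)               # before masking
masked_logits = logits.masked_fill(~mask, float("-inf"))
masked_log_probs = F.log_softmax(masked_logits, dim=-1)     # renormalised over allowed tokens

assert masked_log_probs[action] >= raw_log_probs[action]    # masking only removes competing mass
```

Returning both presumably lets downstream code (such as the algorithm wrappers changed in this commit) track the unmasked policy's likelihood separately from the masked distribution used for sampling.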
