Added reward function for dialog
rajcscw committed Nov 10, 2022
1 parent 21a6928 commit 19376cc
Showing 6 changed files with 783 additions and 482 deletions.
10 changes: 7 additions & 3 deletions rl4lms/envs/text_generation/alg_wrappers.py
@@ -63,6 +63,7 @@ def compute_batched_rewards(
     generated_texts = []
     is_dones = []
     indices = []
+    meta_infos = []
     for env_ix, transitions in enumerate(episode_wise_transitions):
         for trans_ix, transition in enumerate(transitions):
             done = transition.done
@@ -71,16 +72,19 @@ def compute_batched_rewards(
                 reference_texts.append(info["reference_text"])
                 generated_texts.append(info["output"])
                 is_dones.append(done)
+                meta_infos.append(info["meta_info"])
                 indices.append((env_ix, trans_ix))

     # compute rewards all at once
-    rewards = reward_fn(prompts, generated_texts, reference_texts, is_dones)
-    rewards = rewards.numpy().flatten()
+    rewards = reward_fn(prompts, generated_texts, reference_texts, is_dones, meta_infos)
+    # rewards = rewards.numpy().flatten()

     # override the rewards in transitions
     for (env_ix, trans_ix), reward in zip(indices, rewards):
         episode_wise_transitions[env_ix][trans_ix].task_reward = reward
-        episode_wise_transitions[env_ix][trans_ix].total_reward = reward + episode_wise_transitions[env_ix][trans_ix].kl_reward
+        episode_wise_transitions[env_ix][trans_ix].total_reward = (
+            reward + episode_wise_transitions[env_ix][trans_ix].kl_reward
+        )


 def wrap_onpolicy_alg(
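Note: with this change, compute_batched_rewards forwards meta_infos to the batched reward function and no longer flattens a tensor result, so the reward function is expected to return a plain sequence of per-sample floats. A minimal sketch of a batched, dialog-style reward compatible with that call is shown below; the class name DialogOverlapReward, the token-overlap scoring rule, and the assumed import path for BatchedRewardFunction are illustrative and not part of this commit.

from typing import Any, Dict, List

from rl4lms.envs.text_generation.reward import BatchedRewardFunction  # assumed import path


class DialogOverlapReward(BatchedRewardFunction):
    """Hypothetical batched reward: scores only finished episodes."""

    def __call__(
        self,
        prompt_texts: List[str],
        gen_texts: List[str],
        ref_texts: List[List[str]],
        dones: List[bool],
        meta_infos: List[Dict[str, Any]] = None,
    ) -> List[float]:
        rewards = []
        for gen, refs, done, meta in zip(gen_texts, ref_texts, dones, meta_infos):
            if not done:
                # intermediate steps get no task reward; only terminal steps are scored
                rewards.append(0.0)
                continue
            # toy scoring rule: unigram overlap with the first reference response
            # (a real dialog reward would likely use meta, e.g. an intent label)
            ref = refs[0] if isinstance(refs, (list, tuple)) else refs
            ref_tokens, gen_tokens = set(ref.split()), set(gen.split())
            rewards.append(len(ref_tokens & gen_tokens) / max(len(ref_tokens), 1))
        return rewards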
103 changes: 68 additions & 35 deletions rl4lms/envs/text_generation/env.py
@@ -13,15 +13,18 @@


 class TextGenEnv(Env):
-    def __init__(self, tokenizer: AutoTokenizer,
-                 reward_function: RewardFunction,
-                 samples: Tuple[List[Sample], float],
-                 max_episode_length: int = 512,
-                 priority_scale: float = 0.0,
-                 max_prompt_length: Optional[int] = None,
-                 terminate_on_eos: bool = False,
-                 context_start_token: Optional[int] = None,
-                 prompt_truncation_side: str = "left"):
+    def __init__(
+        self,
+        tokenizer: AutoTokenizer,
+        reward_function: RewardFunction,
+        samples: Tuple[List[Sample], float],
+        max_episode_length: int = 512,
+        priority_scale: float = 0.0,
+        max_prompt_length: Optional[int] = None,
+        terminate_on_eos: bool = False,
+        context_start_token: Optional[int] = None,
+        prompt_truncation_side: str = "left",
+    ):
         """
         A generic RL environment to generate textual sequences.
         For eg: text generation, summarization, machine translation, text simplification
@@ -39,31 +42,48 @@ def __init__(self, tokenizer: AutoTokenizer,
         self.tokenizer = tokenizer
         self.reward_function = reward_function
         self.max_steps = max_episode_length
-        self._max_text_length = max_prompt_length if max_prompt_length else tokenizer.model_max_length
+        self._max_text_length = (
+            max_prompt_length if max_prompt_length else tokenizer.model_max_length
+        )
         self._terminate_on_eos = terminate_on_eos
         self._context_start_token = context_start_token
         self._prompt_truncation_side = prompt_truncation_side
         super().__init__()

         # set the observation and action space here
         self._vocab_size = tokenizer.vocab_size
-        self.observation_space = DictSpace({
-            # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited
-            # while creating rollout buffers, observations are concatenated for each key
-            "prompt_or_input_encoded_pt": spaces.Box(low=0, high=self._vocab_size, shape=(self._max_text_length,)),
-            "prompt_or_input_attention_mask_pt": spaces.Box(low=0, high=1, shape=(self._max_text_length,)),
-            "context_encoded_pt": spaces.Box(low=0, high=self._vocab_size, shape=(self.max_steps,)),
-            "context_attention_mask_pt": spaces.Box(low=0, high=1, shape=(self.max_steps,)),
-            "input_encoded_pt": spaces.Box(low=0, high=self._vocab_size, shape=(self._max_text_length+self.max_steps,)),
-            "input_attention_mask_pt": spaces.Box(low=0, high=1, shape=(self._max_text_length+self.max_steps,)),
-        })
+        self.observation_space = DictSpace(
+            {
+                # we have to provide fixed sized inputs (padded) because sb3 support for DictObsersevation is limited
+                # while creating rollout buffers, observations are concatenated for each key
+                "prompt_or_input_encoded_pt": spaces.Box(
+                    low=0, high=self._vocab_size, shape=(self._max_text_length,)
+                ),
+                "prompt_or_input_attention_mask_pt": spaces.Box(
+                    low=0, high=1, shape=(self._max_text_length,)
+                ),
+                "context_encoded_pt": spaces.Box(
+                    low=0, high=self._vocab_size, shape=(self.max_steps,)
+                ),
+                "context_attention_mask_pt": spaces.Box(
+                    low=0, high=1, shape=(self.max_steps,)
+                ),
+                "input_encoded_pt": spaces.Box(
+                    low=0,
+                    high=self._vocab_size,
+                    shape=(self._max_text_length + self.max_steps,),
+                ),
+                "input_attention_mask_pt": spaces.Box(
+                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+                ),
+            }
+        )
         self.action_space = Discrete(n=self._vocab_size)
         # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency
-        if 't5' in self.tokenizer.name_or_path:
+        if "t5" in self.tokenizer.name_or_path:
             n = 32128
             self.action_space = Discrete(n=n)
-        self.sampler_for_replaying = PrioritySampler(
-            priority_scale=priority_scale)
+        self.sampler_for_replaying = PrioritySampler(priority_scale=priority_scale)
         for sample, weight in samples:
             self.sampler_for_replaying.add(sample, weight)

@@ -88,13 +108,23 @@ def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
         self.__current_obs = self.__current_obs.update(action, self.tokenizer)

         # decide if the episode is finished or not
-        done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or (self.__time_step ==
-                                                                                      self.max_steps)
+        done = (action == self.tokenizer.eos_token_id and self._terminate_on_eos) or (
+            self.__time_step == self.max_steps
+        )

         # compute reward
         if not isinstance(self.reward_function, BatchedRewardFunction):
-            reward = None if self.reward_function is None else self.reward_function(
-                previous_obs, action, self.__current_obs, done, self.__current_obs.meta_info)
+            reward = (
+                None
+                if self.reward_function is None
+                else self.reward_function(
+                    previous_obs,
+                    action,
+                    self.__current_obs,
+                    done,
+                    self.__current_obs.meta_info,
+                )
+            )
         else:
             reward = -inf  # will be overridden later

@@ -104,7 +134,8 @@ def step(self, action: int) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
             "action_history": self.__current_obs.action_history,
             "reference_text": self.__current_obs.target_or_reference_texts,
             "prompt_text": self.__current_obs.prompt_or_input_text,
-            "prev_output": previous_obs.context_text
+            "prev_output": previous_obs.context_text,
+            "meta_info": previous_obs.meta_info,
         }

         return self.__current_obs.to_dict(), reward, done, info
@@ -119,13 +150,15 @@ def reset(self, sample: Sample = None) -> Dict[str, torch.tensor]:
         self.__current_sample = sample

         # init the observation
-        self.__current_obs = Observation.init_from_sample(sample,
-                                                          self.tokenizer,
-                                                          self._max_text_length,
-                                                          self.max_steps,
-                                                          self._prompt_truncation_side,
-                                                          self._context_start_token,
-                                                          sample.meta_data)
+        self.__current_obs = Observation.init_from_sample(
+            sample,
+            self.tokenizer,
+            self._max_text_length,
+            self.max_steps,
+            self._prompt_truncation_side,
+            self._context_start_token,
+            sample.meta_data,
+        )

         # start the time step counter
         self.__time_step = 0
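Note: TextGenEnv.step now also passes meta_info through to per-step (non-batched) reward functions, and reset forwards sample.meta_data into the initial observation. The sketch below shows a per-step reward with the matching positional signature and one way the reformatted constructor might be called; the class name LengthBonusReward, the toy length-based score, the gpt2 tokenizer choice, and the assumed import locations of RewardFunction and Observation are illustrative, not part of this commit.

from typing import Any, Dict

from transformers import AutoTokenizer

from rl4lms.envs.text_generation.env import TextGenEnv
from rl4lms.envs.text_generation.observation import Observation  # assumed import path
from rl4lms.envs.text_generation.reward import RewardFunction  # assumed import path


class LengthBonusReward(RewardFunction):
    """Hypothetical per-step reward matching the call in TextGenEnv.step()."""

    def __call__(
        self,
        prev_observation: Observation,
        action: int,
        current_observation: Observation,
        done: bool,
        meta_info: Dict[str, Any] = None,
    ) -> float:
        if not done:
            return 0.0
        # toy terminal reward: slightly favor longer generated responses, capped at 1.0;
        # meta_info is now available here for task-specific signals (e.g. dialog intents)
        return min(len(current_observation.context_text.split()) / 20.0, 1.0)


# constructing the environment with the sketched reward; in a real run, `samples`
# would be (Sample, weight) pairs from a data pool rather than an empty list
tokenizer = AutoTokenizer.from_pretrained("gpt2")
env = TextGenEnv(
    tokenizer=tokenizer,
    reward_function=LengthBonusReward(),
    samples=[],
    max_episode_length=20,
    terminate_on_eos=True,
)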