Directional Stimulus Prompting (DSP) is a framework that uses a small, tuneable language model (LM) to guide a black-box, frozen large language model (LLM) toward desirable properties. Specifically, we train a policy LM to generate discrete tokens as a directional stimulus for each input, i.e., a hint or cue such as keywords of an article for summarization. The directional stimulus is then combined with the original input and fed into the LLM to steer its generation toward the desired target (an example is shown in Figure 1).
Figure 1: Comparison of our proposed Directional Stimulus Prompting (DSP) with the standard method of prompting an LLM such as GPT-3 on the summarization task. DSP uses a tuneable policy LM to generate the stimulus (highlighted in orange), keywords in this case, to guide the LLM toward generating the desired summary (highlighted in blue) with higher ROUGE scores or better alignment with other measures such as human preference.
The policy LM can be trained through (1) supervised fine-tuning (SFT) on annotated data and (2) reinforcement learning (RL) from offline and online rewards, which explores directional stimuli that better align the LLM with human preferences. The framework is flexibly applicable to various LMs and tasks. An illustration of the DSP framework is shown in Figure 2.
Paper Link: https://arxiv.org/abs/2302.11520
Figure 2: Overview of our proposed DSP framework, which learns a small policy LM to improve the frozen LLM's performance on specific downstream tasks. Given the input, the policy LM generates a stimulus to guide the LLM's generation, which is then evaluated with downstream performance measures or by human labelers. The evaluation scores are used as rewards to optimize the policy LM with RL. The parameters of the LLM are frozen, while the policy LM is tuneable.
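To make the idea concrete, the snippet below sketches one way a keyword stimulus could be spliced into the LLM prompt for summarization. It is a minimal illustration only: the function name and prompt wording are ours, while the actual few-shot templates used in the experiments live under ./prompts/.

```python
from typing import List

def build_dsp_prompt(article: str, keywords: List[str]) -> str:
    """Illustrative only: combine the directional stimulus (keywords from the
    policy LM) with the original article before querying the frozen LLM.
    The real few-shot templates used in our experiments are under ./prompts/."""
    hint = "; ".join(keywords)  # keywords are joined with ";" as in our configs
    return (
        f"Article: {article}\n"
        f"Keywords: {hint}\n"
        "Q: Summarize the article above in 2-3 sentences based on the keywords.\n"
        "A:"
    )

# The composed prompt is sent to the black-box LLM (e.g., through the OpenAI API);
# only the policy LM that proposed the keywords is ever updated.
```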
Currently, we test the framework on two benchmark tasks:
- Summarization
- Dialogue Generation
Our code is based on RL4LMs. Users can customize the dataset, metrics, and LLM-based reward function to train transformer-based policy LMs that guide LLMs toward the desired properties.
```bash
git clone https://github.com/leezekun/Directional-Stimulus-Prompting.git
cd Directional-Stimulus-Prompting
pip install -e .
```
We also provide a Dockerfile for development using Docker containers with all of the dependencies installed:

```bash
docker build . -t dsp
```
Optionally, the CoreNLP libraries are required for certain metric computations (e.g., SPICE), which can be downloaded via:

```bash
cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh
```
Set up your OpenAI API key so the code can call the API:

```bash
export OPENAI_API_KEY='XXXXXXXX'
```
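Since the LLM-based reward function calls the OpenAI API during training, the key must be visible to the training process. A quick, purely illustrative sanity check from Python:

```python
import os

# Make sure the key is visible to the training process before launching any
# script that queries the OpenAI API for rewards or evaluation.
assert os.environ.get("OPENAI_API_KEY"), "Please `export OPENAI_API_KEY=...` first."
```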
First, we perform supervised fine-tuning (SFT) of the policy LM on annotated data to provide a good initialization for the subsequent RL training. The code and data are placed in the sft4lms directory. We provide scripts to run SFT for the two tasks:

```bash
sh run_sft_cnndm.sh    # for the summarization task on the CNN/Daily Mail dataset
sh run_sft_multiwoz.sh # for the dialogue generation task on the MultiWOZ dataset
```
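For intuition, SFT simply teaches the policy LM (T5) to map the prefixed input to its target stimulus. A hypothetical CNN/Daily Mail example, assuming the prompt prefix and ";"-separated keyword format from the configs below, would look roughly like:

```python
# Hypothetical shape of a single SFT example for the summarization task:
# the policy LM (T5) learns to generate the keyword hint from the article.
sft_example = {
    "input":  "Extract the keywords: <article text>",      # prompt_prefix + article
    "target": "keyword one; keyword two; keyword three",   # ';'-separated stimulus
}
```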
This part is based on RL4LMs. A simple training API can be invoked via the train script, which trains a PPO, NLPO, or supervised model using a YAML config file.
We provide scripts for training the T5 policy LM on the summarization and dialogue generation tasks. You can run:

```bash
sh run_ppo_cnndm.sh
sh run_ppo_multiwoz.sh
```
The config files for the summarization and dialogue generation tasks can be found in scripts/training/task_configs/summarization_with_hint and scripts/training/task_configs/multiwoz_with_hint, respectively.
You can customize the configuration files as instructed in RL4LMs.
The config file contains hyper-parameter settings for the building blocks described below:
- Dataset/Task: Dataset containing samples with input prompts and reference sentences. Available datasets are found in the class DataPoolRegistry in the registry. (See how to create your own dataset below.) For our experiments, we customize the CNN/Daily Mail and MultiWOZ datasets, which are registered as cnn_daily_mail_with_hint and multiwoz_with_hint:

```yaml
datapool:
  id: cnn_daily_mail_with_hint
  args:
    prompt_prefix: "Extract the keywords: "
    n_train: 2000
    n_val: 500
    n_test: 500
    extraction_mode: "textrank"
    extraction_source: "all"
```

```yaml
datapool:
  id: multiwoz_with_hint
  args:
    version: "2.0"
    n_train: 80
    n_val: 100
    n_test: 1000
```
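The extraction_mode: "textrank" argument above controls how the keyword targets are derived for CNN/Daily Mail. As a rough illustration of TextRank-style extraction (using the third-party summa package, not the repository's own pipeline):

```python
# Rough illustration of TextRank keyword extraction (not the repo's exact code);
# requires `pip install summa`.
from summa import keywords as textrank_keywords

def extract_hint(text: str, top_k: int = 5) -> str:
    kws = textrank_keywords.keywords(text, words=top_k, split=True)
    return "; ".join(kws)  # ';'-separated hint, matching split_token in the configs
```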
- Reward Function: Reward function which computes token-level scores at each time step of the MDP. Available reward functions can be found in the class RewardFunctionRegistry. (See how to create your own reward function below.) We customize LLM-based reward functions, where the reward is measured on the LLM's generation guided by the stimulus produced by the trained policy LM:

```yaml
reward_fn:
  id: summarization_with_hint
  args:
    gpt3_model: 'gpt-3.5-turbo'
    interval: 0.5       # arguments for exponential backoff
    timeout: 20.0
    exp: 2.0
    patience: 10
    temperature: 0.7    # arguments for the LLM's inference
    max_tokens: 128
    num_seqs: 4
    top_p: 1.0
    stop_words: ["Article:", "Q:", "A:", "<|im_end|>"]
    selection_strategy: "choose_all"  # average over all generations sampled from the LLM
    prompt_prefix: "Extract the keywords: "
    prompt_path: "./prompts/cnn_fs.txt"
    hint_prompt_path: "./prompts/cnn_hint_fs.txt"
    gpt3_metric: "rouge-avg"  # metric on the generation of the LLM
    gpt3_coef: 10.
    use_baseline: False
    t5_coef: 0.
    t5_metric: "hint_hit"  # the customized metric on the keywords generated by the policy LM (T5)
    t5_pos_coef: 1.0
    t5_neg_coef: 0.25      # penalty for the policy LM (T5) if it generates a "wrong" keyword
    step_reward_coef: 1.0  # set to 0 to disable the step reward
    split_token: ";"       # we use ";" to split multiple keywords
    split_token_id: 117    # token id of ";" for T5
```
Note that we conducted our experiments using Codex, which OpenAI has no longer supported since March 23rd, 2023; however, you can apply for Codex model access or a research subsidy. You can also try other models by changing gpt3_model.
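Putting the pieces together, the reward mixes an LLM-side score (here, average ROUGE of the guided generations) with a keyword-level score for the policy LM, weighted by the coefficients above. The sketch below uses the rouge_score package and omits the OpenAI call and exponential backoff; it is a simplified illustration, not the registered summarization_with_hint implementation:

```python
# Simplified sketch of the LLM-based reward shaping (coefficients mirror the config above).
from typing import List

from rouge_score import rouge_scorer

_scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

def rouge_avg(reference: str, generation: str) -> float:
    scores = _scorer.score(reference, generation)
    return sum(s.fmeasure for s in scores.values()) / len(scores)

def shaped_reward(llm_summaries: List[str], reference: str, hint_hit_score: float,
                  gpt3_coef: float = 10.0, t5_coef: float = 0.0) -> float:
    # "choose_all": average the metric over all generations sampled from the LLM,
    # then mix in the keyword-level score for the policy LM.
    gpt3_score = sum(rouge_avg(reference, s) for s in llm_summaries) / len(llm_summaries)
    return gpt3_coef * gpt3_score + t5_coef * hint_hit_score
```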
- Environment: Configures a gym-style text generation environment which simulates MDP episodes. Rollouts are generated using train samples from the dataset, consisting of input and reference texts. Further, we wrap our env with SubProcVecEnv from stable-baselines, which processes n_envs episodes in parallel using multi-processing to compute step-wise rewards. Further configuration settings include:
  - max_episode_length: maximum length of the episode
  - max_prompt_length: maximum length of the input text to consider
  - terminate_on_eos: whether to terminate the episode as soon as the EOS action is performed
  - prompt_truncation_side: truncation side for the prompt text
  - context_start_token: id of the context token (corresponds to the initial token given to the decoder in encoder-decoder models)

```yaml
env:
  n_envs: 10
  args:
    max_prompt_length: 512
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0
```
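As a mental model for how these settings shape an episode (this is not the actual RL4LMs environment code), each episode emits one stimulus token per step until EOS or max_episode_length is reached, and the shaped reward arrives at the end:

```python
# Mental model of one generation episode under the settings above
# (not the actual RL4LMs environment implementation).
def run_episode(policy_step, eos_token_id, max_episode_length=100, terminate_on_eos=True):
    tokens, done = [], False
    while not done and len(tokens) < max_episode_length:
        action = policy_step(tokens)                 # policy LM proposes the next stimulus token
        tokens.append(action)
        done = terminate_on_eos and action == eos_token_id
    return tokens                                    # the stimulus, rewarded at episode end
```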
- On-policy alg: We provide implementations of 4 on-policy algorithms: PPO, NLPO, A2C, and TRPO, adapted from stable-baselines3 and tailored to work with NLP tasks, which can be used out of the box with either a causal LM policy or a seq2seq LM policy. (See how to create your own on-policy algorithm or policy below.)
  - We also provide a supervised trainer for benchmarking purposes. Supervised warm-start models are already uploaded to the Hugging Face Hub and specified in the respective config files.
  - Hyper-parameters for the algorithm can be specified at alg/args.
  - Further, all RL algorithms use an adaptive KL controller to keep the LM close to the original LM, configured via the initial KL coefficient (alg/kl_div/coeff) and the target KL (alg/kl_div/target_kl); a sketch of one such adaptive schedule is shown after the config below.
  - We support two types of LM policy: a causal LM policy (for decoder-only models) and a seq2seq LM policy (for encoder-decoder models). For NLPO, we also provide maskable variants of these. Policy implementations can be found here and can be attached to algorithms by specifying alg/policy/id and alg/policy/args:

```yaml
alg:
  id: nlpo
  args:
    n_steps: 512
    batch_size: 1
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
    vf_coef: 0.5
  kl_div:
    coeff: 0.005
    target_kl: 0.5
  policy:
    id: maskable_seq2seq_lm_actor_critic_policy
    args:
      model_name: $MODEL_PATH  # the initial policy LM checkpoint; use t5-base or the checkpoint trained with SFT in the first step
      apply_model_parallel: True
      prompt_truncation_side: "right"
      min_tokens_to_keep: 100
      top_mask: 0.9
      mask_type: "learned_top_p"
      target_update_iterations: 20
      generation_kwargs:
        min_length: 8
        max_new_tokens: 64
        do_sample: True
        top_k: 100
```
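For reference, the adaptive KL controller mentioned above adjusts the KL coefficient so that the measured KL against the initial LM tracks target_kl. The sketch below follows the common schedule of Ziegler et al. (2019); RL4LMs' exact update may differ:

```python
# Sketch of an adaptive KL controller in the style of Ziegler et al. (2019);
# the coefficient grows when the policy drifts past target_kl and shrinks otherwise.
class AdaptiveKLController:
    def __init__(self, init_coeff: float = 0.005, target_kl: float = 0.5, horizon: int = 10000):
        self.coeff = init_coeff
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl: float, n_steps: int) -> float:
        # Proportional error, clipped to keep the update stable.
        proportional_error = max(min(observed_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.coeff *= 1.0 + proportional_error * n_steps / self.horizon
        return self.coeff
```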
- Trainer Config: We provide an on-policy trainer, a feature-complete wrapper that instantiates the building blocks from their corresponding configs and runs an outer training loop consisting of train and eval iterations (train_evaluation/n_iters).
  - Each iteration corresponds to performing updates with alg/args/n_steps x env/n_envs steps of the chosen algorithm.
  - Every eval_every iterations, the LM is evaluated on the validation split using the metrics listed in train_evaluation/metrics, with generation kwargs provided in train_evaluation/generation_kwargs (these override the rollout alg/policy/generation_kwargs for inference purposes only).
We customize the evaluation function, which evaluates the generations of both the LLM and the trained T5 policy LM:
```yaml
# train and evaluation
train_evaluation:
  eval_batch_size: 10
  n_iters: 20
  eval_every: 2
  save_every: 2
  metrics:
    - id: summarization_with_hint
      args:
        gpt3_model: 'gpt-3.5-turbo'
        interval: 0.5
        timeout: 20.0
        exp: 2
        patience: 10
        temperature: 0.7
        max_tokens: 128
        num_seqs: 3
        top_p: 1.0
        stop_words: ["Article:", "Q:", "A:"]
        selection_strategy: "choose_all"
        split_token: ";"
        split_token_id: 117  # token id of ";" for T5
        prompt_prefix: "Extract the keywords: "
        prompt_path: "./prompts/cnn_fs.txt"
        hint_prompt_path: "./prompts/cnn_hint_fs.txt"
        use_lower_baseline: False
        use_upper_baseline: False
        gpt3_metrics:
          - id: meteor
            args: {}
          - id: rouge
            args:
              use_single_ref: False
          - id: bleu
            args: {}
          - id: bert_score
            args:
              language: en
        t5_metrics:
          - id: "hint_hit"
            args:
              split: ";"
  generation_kwargs:  # for the trained policy LM T5
    min_length: 8
    max_new_tokens: 64
    do_sample: True
    top_k: 0
    temperature: 0.7
```
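The t5_metrics entry "hint_hit" above evaluates the keywords produced by the policy LM itself. A plausible reading of such a metric (a hypothetical sketch; the repository's definition may differ) is the fraction of generated keywords that appear in the reference text:

```python
# Hypothetical sketch of a "hint hit" style metric: the share of generated keywords
# found in the reference text (the repository's exact definition may differ).
def hint_hit(generated_hint: str, reference: str, split_token: str = ";") -> float:
    keywords = [k.strip().lower() for k in generated_hint.split(split_token) if k.strip()]
    if not keywords:
        return 0.0
    hits = sum(1 for k in keywords if k in reference.lower())
    return hits / len(keywords)
```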
RL4LMs provides complete customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms, and actor-critic policies.
Users can create their own datasets by sub-classing TextGenPool and overriding the prepare(cls, split: str, **args) -> 'TextGenPool' method to return an instance of TextGenPool. An example is shown below:
```python
from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool

class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str):
        ..
        samples = []
        for ix, item in enumerate(..):
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]]
                            )
            samples.append(sample)
        pool_instance = cls(samples)
        return pool_instance
```
Custom reward functions can be implemented easily by sub-classing RewardFunction (a callable), which takes the previous observation, the action, the current observation, a done flag, and meta info, and returns a scalar reward:
```python
from typing import Any, Dict

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction

class MyRewardFunction(RewardFunction):
    def __init__(self, *args) -> None:
        super().__init__()

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        if done:
            reward = ..
            return reward
        return 0
```
💡 In addition to traditional NLG metrics, for quick prototyping we provide two synthetic reward functions which train LMs to generate numbers in increasing order and to generate dates. These can be used to quickly test different algorithms and policies. The corresponding configs can be found here (numbers, dates).
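For a flavour of what such a synthetic reward looks like, a toy "increasing numbers" reward (illustrative only, not the RL4LMs implementation) could be:

```python
# Toy "increasing numbers" reward (illustrative only): the fraction of adjacent
# number pairs in the generated text that are in increasing order.
def increasing_numbers_reward(generated_text: str) -> float:
    numbers = [float(tok) for tok in generated_text.split()
               if tok.replace(".", "", 1).isdigit()]
    if len(numbers) < 2:
        return 0.0
    increasing = sum(1 for a, b in zip(numbers, numbers[1:]) if b > a)
    return increasing / (len(numbers) - 1)
```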
Users can create their own evaluation metric, which will then be used to periodically evaluate the model on the validation split of the dataset. This can be done by sub-classing BaseMetric, which takes prompt texts, generated texts, reference texts, meta infos, the current LM model, and the split name as inputs and returns a dict with the metric name as key and a tuple of sentence-level scores and corpus-level score as value. An example is as follows:
```python
from typing import Any, Dict, List

from transformers import PreTrainedModel

from rl4lms.envs.text_generation.metric import BaseMetric

class MyMetric(BaseMetric):
    def __init__(self) -> None:
        super().__init__()

    def compute(self,
                prompt_texts: List[str],
                generated_texts: List[str],
                reference_texts: List[List[str]],
                meta_infos: List[Dict[str, Any]] = None,
                model: PreTrainedModel = None,
                split_name: str = None):
        metric_dict = {
            "custom_metrics/my_metric": ([0.4, 0.7, 0.9], 0.7)
        }
        return metric_dict
```
In addition to the supported on-policy algorithms (PPO, NLPO, A2C, TRPO), users can implement their own on-policy algorithms with ease by sub-classing stable-baselines3's OnPolicyAlgorithm. Since we provide wrappers for on-policy algorithms that handle rollouts using LM policies, the environment, reward computation, etc., users just need to implement the train() method with custom loss functions.
```python
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm

class MyOnPolicyAlgorithm(OnPolicyAlgorithm):
    def __init__(self, **args):
        super().__init__(**args)

    def train(self) -> None:
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # compute loss
                ...
```
We provide LM-based actor-critic policy implementations that wrap causal LMs and seq2seq LMs. These can also be extended (e.g., to use a different critic architecture) by overriding the appropriate methods (e.g., evaluate_actions()).
Finally, just register your custom components by adding them to the corresponding registry, after which they can be used directly from configs just like the pre-defined components.
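For example, registration might look like the following, assuming the registries expose an add(id, cls) helper as in RL4LMs (check rl4lms/envs/text_generation/registry.py for the actual helper names before relying on this):

```python
# Assumed registration pattern (verify the helper names in
# rl4lms/envs/text_generation/registry.py); components are then referenced
# from the YAML configs by their string id.
from rl4lms.envs.text_generation.registry import DataPoolRegistry, RewardFunctionRegistry

# MyDataPool and MyRewardFunction are the example classes defined above.
DataPoolRegistry.add("my_data_pool", MyDataPool)                    # hypothetical id
RewardFunctionRegistry.add("my_reward_function", MyRewardFunction)  # hypothetical id
```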
We have provided the crowdsourcing templates we used on Mechanical Turk, along with example inputs, in scripts/crowdworking_templates. You might find these a helpful starting point either for evaluating your own model's generations or for gathering training data for a learned reward function.
Additionally, we support WANDB logging and warm-starting of training by storing checkpoints and other training artifacts in a user-specified path. This is especially useful for running preemptible jobs on large, scheduled clusters.
Artifacts include:
1. a jsonl file containing rollout infos at specified intervals
2. a jsonl file containing training infos at specified intervals
3. a jsonl file containing validation metrics at specified intervals
4. a jsonl file containing test metrics before and after training
5. a json file with validation predictions at specified intervals
6. a json file with test predictions before and after training
7. the trained LM model
8. the config json used to run the experiment
Complete usage is as follows:

```bash
WANDB_API_KEY=<YOUR-WANDB-API-KEY-HERE> python scripts/training/train_text_generation.py \
  --config_path <PATH-TO-CONFIG-FILE> \
  --experiment_name <EXPERIMENT-NAME> \
  --base_path_to_store_results <PATH-TO-STORE-RESULTS> \
  --log_to_wandb
```
```bibtex
@article{li2023guiding,
  title={Guiding Large Language Models via Directional Stimulus Prompting},
  author={Li, Zekun and Peng, Baolin and He, Pengcheng and Galley, Michel and Gao, Jianfeng and Yan, Xifeng},
  journal={arXiv preprint arXiv:2302.11520},
  year={2023}
}
```
We thank the authors of RL4LMs for sharing their code. You can contact Zekun Li ([email protected]) if there are questions related to the code.