Fork of the official repository for the paper "Authorship Style Transfer with Policy Optimization", customized to support the Yerevan 2024 Summer School.
Log in to your node, forwarding port 8080 so you can monitor training with WandB:
ssh -L 8080:localhost:8080 -t <machine>
Commands for environment setup with conda:
conda create --name astrapop python=3.8
conda activate astrapop
pip install -U pip
pip install -r requirements.txt
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
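Optionally, verify that the CUDA build of PyTorch is active (a quick sanity check; run it on a GPU node):

```bash
# Should print 2.3.0+cu118 (or similar) and True on a GPU node.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```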
Please see the instructors for a link to the data. Create a directory called `data` and unpack the provided tarballs there.
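For example, assuming a tarball named ets.tar.gz (a hypothetical name; substitute the files you were actually given):

```bash
mkdir -p data
# ets.tar.gz is a placeholder -- use the tarball(s) from the instructors.
tar -xzf ets.tar.gz -C data
```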
Get access to Llama 2 (or a similar model to use as your backbone) by filling out the form at https://huggingface.co/meta-llama/Llama-2-7b-hf.
Obtain an access token and export it:
export HUGGINGFACE_ACCESS_TOKEN=<your token>
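To verify the token before training, one option is the Hugging Face CLI (a sketch, assuming huggingface_hub, which ships the CLI, is installed via requirements.txt):

```bash
# Logs in with the exported token; whoami prints your username if it is valid.
huggingface-cli login --token "$HUGGINGFACE_ACCESS_TOKEN"
huggingface-cli whoami
```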
Install wandb and set up a local server; follow the instructions through step 2 here: https://docs.wandb.ai/guides/hosting/self-managed/basic-setup
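Then point the client at your local server; a minimal sketch, assuming the server is reachable at localhost:8080 through the SSH tunnel above:

```bash
# Send all runs to the local server instead of wandb.ai.
export WANDB_BASE_URL=http://localhost:8080
# Log in with an API key from your local server's settings page.
wandb login --host http://localhost:8080
```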
It is recommended to reproduce the ETS results. To save time, only two of the original eleven languages are used; the scripts that run all eleven are in `scripts/ets/orig`.

To reproduce the results on the ETS dataset, run the scripts in `scripts/ets` in the order below (a full command sketch follows the list).
- Train the style reward model, the paraphrase model, and the reference SFT model by running `00_train_cls.sh`, `00_train_paraphraser.sh`, and `00_train_sft.sh`.
- Generate the data for DPO and CPO training by running `01_generate_dpo_cpo_data.sh`.
- Train the PO models using PPO/DPO/CPO by running `02_train_ppo.sh`/`02_train_dpo.sh`/`02_train_cpo.sh`.
- Transfer the texts in the test set by running `03_generate.sh`.
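Putting the steps together, a minimal sketch, assuming each script is run from inside `scripts/ets` and takes no extra arguments:

```bash
cd scripts/ets
bash 00_train_cls.sh              # style reward model
bash 00_train_paraphraser.sh      # paraphrase model
bash 00_train_sft.sh              # reference SFT model
bash 01_generate_dpo_cpo_data.sh  # DPO/CPO training data
bash 02_train_dpo.sh              # or 02_train_ppo.sh / 02_train_cpo.sh
bash 03_generate.sh               # transfer the test set
```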
For those interested, here is the corresponding procedure for Reddit. To reproduce the results on the Reddit dataset, run the scripts in `scripts/reddit` in the order below (again, a command sketch follows the list).
- Train the paraphrase model and the reference SFT model by running `00_train_paraphraser.sh` and `00_train_sft.sh`.
- Generate the data for DPO and CPO training by running `01_generate_dpo_cpo_data.sh`.
- Train the PO models using PPO/DPO/CPO by running `02_train_ppo.sh`/`02_train_dpo.sh`/`02_train_cpo.sh`.
- Transfer the texts in the test set by running `03_generate.sh`.
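The corresponding sketch for Reddit (same assumptions as above; note there is no classifier step here):

```bash
cd scripts/reddit
bash 00_train_paraphraser.sh
bash 00_train_sft.sh
bash 01_generate_dpo_cpo_data.sh
bash 02_train_dpo.sh  # or 02_train_ppo.sh / 02_train_cpo.sh
bash 03_generate.sh
```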