Commit
Merge branch 'fix-seq2seq' of github.com:allenai/RL4LMs into fix-seq2seq
rajcscw committed Nov 10, 2022
2 parents d94dfef + 3430e24 commit 0acc232
Showing 2 changed files with 52 additions and 0 deletions.
2 changes: 2 additions & 0 deletions rl4lms/envs/text_generation/registry.py
@@ -19,6 +19,7 @@
    ToTTo,
    WMT14PreprocessedEnDe,
    WMT16NewsOnlyDatasetEnDe,
    DailyDialog
)
from rl4lms.data_pools.text_generation_pool import TextGenPool
from rl4lms.envs.text_generation.alg_wrappers import wrap_onpolicy_alg
@@ -94,6 +95,7 @@ class DataPoolRegistry:
"wmt16newsonly": WMT16NewsOnlyDatasetEnDe,
"iwslt2017en_de": IWSLT2017EnDe,
"crd3": CRD3DialogueGeneration,
"daily_dialog": DailyDialog
}

@classmethod
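For context, DataPoolRegistry is a plain id-to-class lookup: registering "daily_dialog" lets task configs refer to the DailyDialog datapool by name. Below is a minimal, self-contained sketch of that pattern; the stand-in classes and the get/prepare helpers are illustrative assumptions, not the repository's exact API.

    # Minimal sketch of the id-to-class registry pattern the diff above extends.
    # The classes below are stand-ins; RL4LMs imports the real datapool classes
    # (e.g. DailyDialog) from its data_pools package, and the `get` helper here
    # is an assumed construction hook rather than the repository's exact API.
    from typing import Any, Dict, Type


    class TextGenPool:
        """Stand-in base class for a text-generation datapool."""
        @classmethod
        def prepare(cls, **kwargs: Any) -> "TextGenPool":
            return cls()


    class DailyDialog(TextGenPool):
        """Stand-in for the DailyDialog datapool registered by this commit."""


    class DataPoolRegistry:
        _registry: Dict[str, Type[TextGenPool]] = {
            "daily_dialog": DailyDialog,  # new entry added in registry.py above
        }

        @classmethod
        def get(cls, datapool_id: str, **kwargs: Any) -> TextGenPool:
            # Resolve the string id from a config to its datapool class.
            return cls._registry[datapool_id].prepare(**kwargs)


    # e.g. resolving the id used in the YAML config below:
    pool = DataPoolRegistry.get("daily_dialog", context_size=5)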
50 changes: 50 additions & 0 deletions scripts/training/task_configs/dialog/gpt2_supervised.yml
@@ -0,0 +1,50 @@
tokenizer:
  model_name: gpt2
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: True
  max_length: 64

datapool:
  id: "daily_dialog"
  args:
    context_size: 5

alg:
  id: supervised
  training_args:
    per_device_train_batch_size:
      expand: True
      values: [32]
    logging_steps: 200
    num_train_epochs: 10
    lr_scheduler_type: "constant"
    learning_rate:
      expand: True
      values: [0.00001]
    save_total_limit: 1
  model_type: causal
  model_name: gpt2
  generation_kwargs:
    do_sample: True
    min_length: 10
    max_new_tokens: 50
    post_processing_fn: null

train_evaluation:
  eval_batch_size: 256
  metrics:
    - id: diversity
      args: {}
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: sacre_bleu
      args:
        tokenize: "intl"
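The tokenizer, model, and generation_kwargs entries above map fairly directly onto HuggingFace transformers settings. The sketch below shows what they amount to at generation time; it is illustrative only (RL4LMs performs this setup internally, and the prompt string is made up).

    # Sketch of how the tokenizer and generation settings above translate to
    # HuggingFace transformers calls; RL4LMs wires this up internally.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.padding_side = "left"            # padding_side: left
    tokenizer.truncation_side = "left"         # truncation_side: left
    tokenizer.pad_token = tokenizer.eos_token  # pad_token_as_eos_token: True

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # model_type: causal, model_name: gpt2

    # Dialogue context truncated to max_length: 64 tokens, then a sampled continuation.
    inputs = tokenizer("Hi! How was your day?", truncation=True, max_length=64,
                       return_tensors="pt")
    outputs = model.generate(**inputs,
                             do_sample=True,     # do_sample: True
                             min_length=10,      # min_length: 10
                             max_new_tokens=50)  # max_new_tokens: 50
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In the repository itself, a config like this is normally passed to the training entry point (scripts/training/train_text_generation.py via its --config_path argument) rather than applied by hand as above.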