Commit

add finetuning code (without rl)
henryhungle committed Sep 26, 2022
1 parent b3e2ac9 commit cf71699
Showing 7 changed files with 190 additions and 10 deletions.
File renamed without changes.
16 changes: 8 additions & 8 deletions datasets/apps_dataset.py
@@ -40,7 +40,7 @@ def __init__(self, dataroot, problem_dirs, model, max_tokens, sample_mode,
self.all_error_types, self.all_error_subtypes = [], []
self.initialize()

if self.model in ['codet5-base']:
if self.model in ['codet5-base', 'codet5-large']:
self.tokenizer = transformers.RobertaTokenizer.from_pretrained('Salesforce/codet5-base')

def load_gen_samples(self, sols, answer_type, starter_code, question_str):
@@ -196,7 +196,7 @@ def pack_samples(self, idx, sample_type=None):
curr_samples.append((curr_q, curr_s, curr_a, curr_q_prefix))

# only pack 1 sample each sequence for codeT5
if self.model in ['codet5-base']:
if self.model in ['codet5-base', 'codet5-large']:
break

if self.sample_mode == 'uniform_sol':
@@ -215,7 +215,7 @@ def __getitem__(self, idx):

else:
raw_samples = self.pack_samples(idx)
inputs = self.sample_gpt_task(raw_samples)
inputs = self.sample_task(raw_samples)

gc.collect()
return inputs
@@ -240,7 +240,7 @@ def sample_task(self, samples, sample_type=None):
input_ids.extend(question_token_ids)

answer_token_ids = self.tokenizer.encode(a_str, verbose=False)
if self.model not in ['codet5-base']:
if self.model not in ['codet5-base', 'codet5-large']:
label_ids.extend([-100] * len(question_token_ids))
answer_token_ids.append(self.tokenizer.eos_token_id)
input_ids.extend(answer_token_ids)
@@ -250,23 +250,23 @@ def sample_task(self, samples, sample_type=None):
error_types.append(dsutils.get_error_type(result))

# Sanity checks and padding
input_ids_max_len = self.max_src_tokens if self.model in ['codet5-base'] else self.max_tokens
input_ids_max_len = self.max_src_tokens if self.model in ['codet5-base', 'codet5-large'] else self.max_tokens
if len(input_ids) < input_ids_max_len:
new_input_ids = [self.tokenizer.eos_token_id] * input_ids_max_len
new_input_ids[:len(input_ids)] = input_ids
input_ids = new_input_ids

if self.model not in ['codet5-base']:
if self.model not in ['codet5-base', 'codet5-large']:
new_label_ids = [-100] * input_ids_max_len
new_label_ids[:len(label_ids)] = label_ids
label_ids = new_label_ids

if self.model in ['codet5-base'] and len(label_ids) < self.max_tokens:
if self.model in ['codet5-base', 'codet5-large'] and len(label_ids) < self.max_tokens:
new_label_ids = [-100] * self.max_tokens
new_label_ids[:len(label_ids)] = label_ids
label_ids = new_label_ids

if self.model not in ['codet5-base'] and len(input_ids) != len(label_ids): pdb.set_trace()
if self.model not in ['codet5-base', 'codet5-large'] and len(input_ids) != len(label_ids): pdb.set_trace()

if self.tuning_mode in ['critic'] and sample_type == 'gen':
assert len(error_types) == 1
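The hunks above extend every `codet5-base` branch to also cover `codet5-large`, and swap `sample_gpt_task` for the model-agnostic `sample_task`. For readers skimming the diff, a small self-contained sketch of the two padding/label-masking schemes those branches select between is given below (hypothetical function and variable names; not code from this commit):

# Hypothetical sketch (not part of this commit) of the two padding schemes in sample_task.

def pad_codet5_example(input_ids, label_ids, eos_id, max_src_tokens=600, max_tokens=512):
    # Encoder-decoder case (codet5-*): pad the source with EOS up to max_src_tokens,
    # and pad the labels separately with -100 up to max_tokens.
    if len(input_ids) < max_src_tokens:
        padded = [eos_id] * max_src_tokens
        padded[:len(input_ids)] = input_ids
        input_ids = padded
    if len(label_ids) < max_tokens:
        padded = [-100] * max_tokens
        padded[:len(label_ids)] = label_ids
        label_ids = padded
    return input_ids, label_ids

def pad_decoder_only_example(question_ids, answer_ids, eos_id, max_tokens=1024):
    # Decoder-only case: question and answer share one packed sequence;
    # the question positions are masked out of the labels with -100.
    input_ids = question_ids + answer_ids + [eos_id]
    label_ids = [-100] * len(question_ids) + answer_ids + [eos_id]
    if len(input_ids) < max_tokens:
        padded = [eos_id] * max_tokens
        padded[:len(input_ids)] = input_ids
        input_ids = padded
        padded_labels = [-100] * max_tokens
        padded_labels[:len(label_ids)] = label_ids
        label_ids = padded_labels
    return input_ids, label_ids

The -100 value is the ignore index of the cross-entropy loss in Hugging Face models, so padded and question positions contribute nothing to the training loss.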
17 changes: 17 additions & 0 deletions scripts/train_actor.sh
@@ -0,0 +1,17 @@
#
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

# Run code in debugging mode (without deepspeed)
python \
train.py \
--batch-size-per-replica 1 --grad-acc-steps 4 \
--epochs 10 --lr 2e-5 \
--save-freq 1000 --log-freq 10 --save_total_limit 5 \
--fp16 \
--tuning_mode none --model codet5-large \
--db

17 changes: 17 additions & 0 deletions scripts/train_actor_deepspeed.sh
@@ -0,0 +1,17 @@
#
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

# Run code with deepspeed
USE_TF=NO deepspeed --master_port 62000 \
train.py \
--batch-size-per-replica 1 --grad-acc-steps 4 \
--epochs 10 --lr 2e-5 \
--save-freq 1000 --log-freq 10 --save_total_limit 5 \
--fp16 \
--tuning_mode none --model codet5-large \
--deepspeed configs/deepspeed_configs.json

2 changes: 1 addition & 1 deletion scripts/train_critic.sh
@@ -7,7 +7,7 @@

# Run code in debugging mode (without deepspeed)
python \
train_critic.py \
train.py \
--batch-size-per-replica 8 --grad-acc-steps 1 \
--epochs 10 --lr 2e-5 \
--save-freq 1000 --log-freq 10 --save_total_limit 5 \
2 changes: 1 addition & 1 deletion scripts/train_critic_deepspeed.sh
@@ -7,7 +7,7 @@

# Run code with deepspeed
USE_TF=NO deepspeed --master_port 62000 \
train_critic.py \
train.py \
--batch-size-per-replica 8 --grad-acc-steps 1 \
--epochs 10 --lr 2e-5 \
--save-freq 1000 --log-freq 10 --save_total_limit 5 \
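Both critic scripts now route through the unified `train.py` entry point added below, presumably with `--tuning_mode critic` (that flag sits outside the visible hunk), in which case `run_training` constructs `Trainer_Critic` from `trainers/trainer_critic.py`. That trainer is not part of this commit; purely as a hedged illustration of how such a subclass typically hooks into the Hugging Face `Trainer`, it would accept the extra `tuning_mode` argument and override `compute_loss`:

# Hypothetical sketch only -- the real Trainer_Critic lives in
# trainers/trainer_critic.py and is not included in this commit.
from transformers import Trainer

class CriticTrainerSketch(Trainer):
    def __init__(self, *args, tuning_mode=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.tuning_mode = tuning_mode  # e.g. 'critic'

    def compute_loss(self, model, inputs, return_outputs=False):
        # Delegate to the model's standard forward pass; a real critic trainer
        # would add its own objective (e.g. predicting unit-test outcomes)
        # on top of, or instead of, this language-modeling loss.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss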
146 changes: 146 additions & 0 deletions train.py
@@ -0,0 +1,146 @@
#
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

import io
import logging
import math
import os
import pprint
import sys
import time
import json
import pdb
from tqdm import tqdm
from datetime import datetime

import transformers
import torch

from datasets.apps_dataset import APPSBaseDataset
from trainers.trainer_critic import Trainer_Critic
from transformers import Trainer

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')


def run_training(args, train_data):
if args.model in ['codet5-base', 'codet5-large']:
model_path = args.model_path if args.model_path is not None else 'Salesforce/{}'.format(args.model)
print("Loading model from {}...".format(model_path))
model = transformers.T5ForConditionalGeneration.from_pretrained(
model_path,
tuning_mode=args.tuning_mode)
print('Finished loading model {}'.format(args.model))

start_iteration = 0
train_data.start_iteration = start_iteration
print(f"Starting main loop")

training_args = transformers.TrainingArguments(
output_dir=args.save_dir,
overwrite_output_dir=True,

do_train=True,
do_eval=False,
do_predict=True,
evaluation_strategy='no',
eval_steps=0,

num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size_per_replica,
gradient_accumulation_steps=args.grad_acc_steps,

learning_rate=args.lr,
weight_decay=0.05,
lr_scheduler_type='constant_with_warmup',

logging_dir=args.save_dir,
logging_first_step=True,
logging_steps=args.log_freq,
save_steps=args.save_freq,
save_total_limit=args.save_total_limit,

dataloader_drop_last=True,
dataloader_num_workers=0 if args.db else 8,

local_rank=args.local_rank,

deepspeed=args.deepspeed,
fp16=args.fp16,

)

if args.tuning_mode in ['critic']:
trainer = Trainer_Critic(
model=model,
args=training_args,
train_dataset=train_data,
tuning_mode=args.tuning_mode,
)
else:
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_data,
)

trainer.train()

if args.local_rank == 0:
model.save_pretrained(os.path.join(args.save_dir, "final_checkpoint"))


def get_dataset(args):

fnames = os.listdir(args.train_path)

# train in debugging mode with small data split
if args.db:
fnames = fnames[:50]

if args.model in ['codet5-base', 'codet5-large']:
max_tokens = 512
max_src_tokens = 600
else:
max_tokens = 1024
max_src_tokens = -1

train_data = APPSBaseDataset(
dataroot=args.train_path,
problem_dirs=fnames,
model=args.model,
max_tokens=max_tokens,
max_src_tokens=max_src_tokens,
sample_mode=args.sample_mode,
tuning_mode=args.tuning_mode,
)

return train_data


def main(args):

argsdict = vars(args)
print(pprint.pformat(argsdict))

os.makedirs(args.save_dir, exist_ok=True)

# Load dataset
train_data = get_dataset(args)

# Save args to file
json.dump(argsdict, open(os.path.join(args.save_dir, "args.json"), 'w'))

# Load and train model; save model checkpoints
run_training(args, train_data)


if __name__ == "__main__":
from configs.train_configs import *

main(args)
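
`train.py` obtains its `args` from `configs/train_configs.py` through the wildcard import in the `__main__` block; that file is not touched by this commit. Purely as an illustration, an argparse module matching the flags used by the new scripts and by `get_dataset`/`run_training` might look like the following sketch (flag spellings for options not visible in the scripts, and all defaults, are guesses):

# Hypothetical sketch of configs/train_configs.py (not included in this commit);
# it exposes a module-level `args` consumed by train.py via wildcard import.
import argparse

parser = argparse.ArgumentParser(description="Fine-tuning configuration")

parser.add_argument('--model', default='codet5-large', type=str)
parser.add_argument('--model_path', default=None, type=str)
parser.add_argument('--tuning_mode', default='none', type=str)

parser.add_argument('--train_path', default='data/APPS/train', type=str)
parser.add_argument('--sample_mode', default='uniform_sol', type=str)
parser.add_argument('--save_dir', default='exps/finetune', type=str)

# argparse maps dashed flags to underscored attributes automatically,
# e.g. --batch-size-per-replica becomes args.batch_size_per_replica.
parser.add_argument('--epochs', default=10, type=int)
parser.add_argument('--lr', default=2e-5, type=float)
parser.add_argument('--batch-size-per-replica', default=1, type=int)
parser.add_argument('--grad-acc-steps', default=4, type=int)

parser.add_argument('--save-freq', default=1000, type=int)
parser.add_argument('--log-freq', default=10, type=int)
parser.add_argument('--save_total_limit', default=5, type=int)

parser.add_argument('--fp16', action='store_true')
parser.add_argument('--deepspeed', default=None, type=str)
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--db', action='store_true', help='debugging mode with a small data split')

args = parser.parse_args()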
