feat: train and clean data
zanussbaum committed Mar 25, 2023
1 parent 2568d94 commit 723a50b
Showing 7 changed files with 481 additions and 0 deletions.
71 changes: 71 additions & 0 deletions clean.py
@@ -0,0 +1,71 @@
import glob
import os
import json

import pandas as pd


prompt_generation_dir = "prompts-reponses"
for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
    data = []
    print(file)
    with open(file) as f:
        for line in f:
            try:
                contents = json.loads(line)
                data.append(contents)
            except json.JSONDecodeError:
                # skip lines that are not valid JSON
                pass

    processed = []

    for item in data:
        if 'source' not in item:
            item['source'] = 'unspecified'
        if 'model_settings' in item:
            item.pop('model_settings', None)

        # keep only the source, prompt, and response fields
        for key in list(item.keys()):
            if key not in ['source', 'prompt', 'response']:
                item.pop(key, None)

        if isinstance(item['prompt'], dict):
            if "value" in item["prompt"]:
                item["prompt"] = item["prompt"]["value"]
            elif "description" in item["prompt"]:
                item["prompt"] = item["prompt"]["description"]
            else:
                continue

        elif not isinstance(item['prompt'], str):
            continue

        if isinstance(item['response'], dict):
            if "value" in item["response"]:
                item["response"] = item["response"]["value"]
            elif "description" in item["response"]:
                item["response"] = item["response"]["description"]
            else:
                continue
        elif not isinstance(item['response'], str):
            continue

        if item:
            processed.append(item)

    df = pd.DataFrame(processed)
    prev_len = len(df)

    # drop empty or null prompts/responses
    df = df.dropna(subset=['prompt', 'response'])
    df = df[df['prompt'] != '']
    df = df[df['response'] != '']
    curr_len = len(df)

    print(f"Removed {prev_len - curr_len} rows")

    clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
    print(f"writing to {clean_name}")
    df.to_json(clean_name, orient="records", lines=True)
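
For illustration, here is a hypothetical raw record and the form clean.py reduces it to. Only source, prompt, and response survive, nested {"value": ...} or {"description": ...} wrappers are flattened to plain strings, and the extra fields below are invented for the example, not taken from any real prompt file.

# hypothetical raw line from one of the *.jsonl files (extra fields are illustrative only)
raw = {
    "source": "stackoverflow",
    "prompt": {"value": "How do I reverse a list in Python?"},
    "response": {"value": "Use reversed() or slicing: my_list[::-1]."},
    "model_settings": {"temperature": 0.7},
    "id": 12345,
}

# after cleaning, only the three whitelisted keys remain, as strings
cleaned = {
    "source": "stackoverflow",
    "prompt": "How do I reverse a list in Python?",
    "response": "Use reversed() or slicing: my_list[::-1].",
}
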
48 changes: 48 additions & 0 deletions configs/deepspeed/ds_config.json
@@ -0,0 +1,48 @@
{
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "fp16": {
        "enabled": "auto",
        "min_loss_scale": 1,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "initial_scale_power": 32
    },
    "bf16": {
        "enabled": "auto"
    },
    "gradient_clipping": 1,
    "zero_optimization": {
        "stage": 2,
        "offload_param": {
            "device": "none"
        },
        "offload_optimizer": {
            "device": "none"
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": [
                0.9,
                0.999
            ],
            "eps": 1e-08
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "warmup_type": "linear"
        }
    }
}
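
Most values in this config are "auto" placeholders that the launcher fills in at runtime. The training entry point is not among the files shown in this commit view; as one common pattern, the Hugging Face Trainer's DeepSpeed integration resolves the "auto" fields from TrainingArguments when the config path is passed via deepspeed=. A minimal sketch under that assumption, with hyperparameters mirroring configs/train/finetune.yaml:

from transformers import TrainingArguments

# "auto" entries in ds_config.json (batch sizes, lr, warmup, fp16/bf16)
# are filled in from the matching TrainingArguments fields
args = TrainingArguments(
    output_dir="ckpts/llama-7b",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    warmup_steps=100,
    fp16=True,
    deepspeed="configs/deepspeed/ds_config.json",
)
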
28 changes: 28 additions & 0 deletions configs/train/finetune.yaml
@@ -0,0 +1,28 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
gradient_checkpointing: true

# dataset
streaming: false
num_proc: 64
dataset_path: "data.jsonl"
max_length: 512
batch_size: 32

# train dynamics
lr: 5.0e-5
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: "ckpts/llama-7b"
checkpoint: null
lora: false
warmup_steps: 100

# logging
wandb: false
wandb_entity: zanussbaum
wandb_project: llama
seed: 42

29 changes: 29 additions & 0 deletions configs/train/finetune_lora.yaml
@@ -0,0 +1,29 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
gradient_checkpointing: false
save_name: "zpn/vicuna-lora"

# dataset
streaming: false
num_proc: 64
dataset_path: "data"
max_length: 512
batch_size: 8

# train dynamics
lr: 5.0e-5
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: "ckpts/llama-7b"
checkpoint: null
lora: true
warmup_steps: 100

# logging
wandb: false
wandb_entity: zanussbaum
wandb_project: llama
seed: 42
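
This config differs from finetune.yaml mainly in lora: true, a smaller batch size, and a save_name. The training code that consumes the flag is not part of the files shown here; below is a hypothetical sketch of how a LoRA flag is commonly wired with the peft library. The rank, alpha, dropout, and target modules are illustrative assumptions, not values from this commit.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("zpn/llama-7b")

# illustrative adapter settings; the actual values used are not part of this diff
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()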

108 changes: 108 additions & 0 deletions data.py
@@ -0,0 +1,108 @@
import glob
import os

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator


def tokenize_inputs(config, tokenizer, examples):
    max_length = config["max_length"]
    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
    # ignore bos
    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]

    out = {"labels": [], "attention_mask": []}
    for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
        # HACK to get 512 to work for now
        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
        input_len = len(input_tokens)

        # plus one since we remove bos from response
        remaining_tokens = max_length - input_len - len(newline_tokens) + 1

        target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]

        input_ids[i, :input_len] = input_tokens
        # add newline between prompt and response
        newline_plus_inputs = input_len + len(newline_tokens)
        input_ids[i, input_len: newline_plus_inputs] = newline_tokens
        # add target tokens, remove bos
        input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens

        labels = input_ids[i].clone()
        labels[: newline_plus_inputs] = -100
        labels[labels == tokenizer.pad_token_id] = -100
        # to debug this, can set all values == -100 to the pad token, then assert that
        # tokenizer.decode(labels, skip_special_tokens=True).strip() == response

        attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()

        out["labels"].append(labels)
        out["attention_mask"].append(attention_mask)

    out["input_ids"] = input_ids

    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}

    return out


def load_data(config, tokenizer):
    dataset_path = config["dataset_path"]

    if os.path.exists(dataset_path):
        # check if path is a directory
        if os.path.isdir(dataset_path):
            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
        else:
            files = [dataset_path]

        dataset = load_dataset("json", data_files=files, split="train")

    else:
        # load from the hub; request the train split so we get a Dataset,
        # not a DatasetDict, before calling train_test_split below
        dataset = load_dataset(dataset_path, split="train")

    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])

    train_dataset, val_dataset = dataset["train"], dataset["test"]

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
    else:
        kwargs = {}

    # tokenize inputs and return labels and attention mask
    train_dataset = train_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt"],
        **kwargs
    )
    val_dataset = val_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt"],
        **kwargs
    )

    train_dataset = train_dataset.with_format("torch")
    val_dataset = val_dataset.with_format("torch")

    # create dataloaders with the default data collator since we already have labels
    train_dataloader = DataLoader(
        train_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    val_dataloader = DataLoader(
        val_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    return train_dataloader, val_dataloader
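
In tokenize_inputs, every prompt, newline, and padding position in labels is set to -100, the index that PyTorch's cross-entropy loss ignores, so only response tokens contribute to the loss. A small self-contained illustration of that convention, using toy tensors unrelated to any real tokenizer:

import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(6, vocab_size)                  # one sequence of 6 positions
labels = torch.tensor([-100, -100, 3, 7, 2, -100])   # prompt/pad positions masked out

# ignore_index=-100 is the default; shown explicitly here
loss = F.cross_entropy(logits, labels, ignore_index=-100)
# only the three unmasked positions (targets 3, 7, 2) contribute to the loss
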
10 changes: 10 additions & 0 deletions read.py
@@ -0,0 +1,10 @@
import yaml


def read_config(path):
    # read yaml and return contents
    with open(path, 'r') as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
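
Taken together, read_config loads one of the YAML files above and load_data turns the cleaned JSONL data into train and validation dataloaders. The training loop itself is not among the files shown, so the following is only a sketch of how these pieces appear to fit; the pad-token fallback is an assumption, since tokenize_inputs needs tokenizer.pad_token_id and LLaMA-style tokenizers often ship without one.

from transformers import AutoTokenizer

from read import read_config
from data import load_data

config = read_config("configs/train/finetune.yaml")

tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
if tokenizer.pad_token is None:
    # assumption: give the tokenizer a pad token so tokenize_inputs can pad
    tokenizer.pad_token = tokenizer.eos_token

train_dataloader, val_dataloader = load_data(config, tokenizer)
for batch in train_dataloader:
    print(batch["input_ids"].shape, batch["labels"].shape)
    break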