forked from nomic-ai/gpt4all

Commit 723a50b (parent: 2568d94)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 7 changed files with 481 additions and 0 deletions.
@@ -0,0 +1,71 @@

import glob
import json
import os

import pandas as pd

prompt_generation_dir = "prompts-reponses"
for file in glob.glob(os.path.join(prompt_generation_dir, "*.jsonl")):
    data = []
    print(file)
    with open(file) as f:
        for line in f:
            try:
                contents = json.loads(line)
                data.append(contents)
            except json.JSONDecodeError:
                # skip malformed lines rather than aborting the whole file
                continue

    processed = []

    for item in data:
        if 'source' not in item:
            item['source'] = 'unspecified'

        # keep only the three fields the training pipeline expects
        for key in list(item.keys()):
            if key not in ['source', 'prompt', 'response']:
                item.pop(key, None)

        # guard against records missing either field entirely
        if 'prompt' not in item or 'response' not in item:
            continue

        if isinstance(item['prompt'], dict):
            if "value" in item["prompt"]:
                item["prompt"] = item["prompt"]["value"]
            elif "description" in item["prompt"]:
                item["prompt"] = item["prompt"]["description"]
            else:
                continue
        elif not isinstance(item['prompt'], str):
            continue

        if isinstance(item['response'], dict):
            if "value" in item["response"]:
                item["response"] = item["response"]["value"]
            elif "description" in item["response"]:
                item["response"] = item["response"]["description"]
            else:
                continue
        elif not isinstance(item['response'], str):
            continue

        processed.append(item)

    df = pd.DataFrame(processed)
    prev_len = len(df)

    # drop rows with empty or null prompt/response strings
    df = df.dropna(subset=['prompt', 'response'])
    df = df[df['prompt'] != '']
    df = df[df['response'] != '']
    curr_len = len(df)

    print(f"Removed {prev_len - curr_len} rows")

    clean_name = file.split(".jsonl")[0] + "_clean.jsonl"
    print(f"writing to {clean_name}")
    df.to_json(clean_name, orient="records", lines=True)
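
For reference, a minimal sketch of the transformation this script applies to each record (the field values below are made up, not taken from the diff):

# Illustrative only -- these values are invented for the example.
raw = {"prompt": {"value": "What is 2+2?"}, "response": "4", "model_settings": {"temp": 0.7}}
# After cleaning: nested prompt/response dicts are flattened to their "value"
# (or "description") text, a default source is added, every other key is dropped.
clean = {"source": "unspecified", "prompt": "What is 2+2?", "response": "4"}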
@@ -0,0 +1,48 @@
{
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "fp16": {
    "enabled": "auto",
    "min_loss_scale": 1,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "initial_scale_power": 32
  },
  "bf16": {
    "enabled": "auto"
  },
  "gradient_clipping": 1,
  "zero_optimization": {
    "stage": 2,
    "offload_param": {
      "device": "none"
    },
    "offload_optimizer": {
      "device": "none"
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": [0.9, 0.999],
      "eps": 1e-08
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "warmup_type": "linear"
    }
  }
}
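
The training entry point is not part of this commit, so how this config is consumed is an assumption; one common wiring is to hand the JSON path to Hugging Face TrainingArguments, which fills each "auto" field from its own values:

from transformers import TrainingArguments

# Sketch only: the config's file name is assumed, not shown in the diff.
args = TrainingArguments(
    output_dir="ckpts/llama-7b",
    per_device_train_batch_size=32,
    learning_rate=5.0e-5,
    deepspeed="ds_config.json",  # assumed file name
)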
@@ -0,0 +1,28 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
gradient_checkpointing: true

# dataset
streaming: false
num_proc: 64
dataset_path: "data.jsonl"
max_length: 512
batch_size: 32

# train dynamics
lr: 5.0e-5
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: "ckpts/llama-7b"
checkpoint: null
lora: false
warmup_steps: 100

# logging
wandb: false
wandb_entity: zanussbaum
wandb_project: llama
seed: 42
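
One detail worth flagging in this config: PyYAML follows YAML 1.1, which only resolves scientific notation as a float when the mantissa contains a dot, so the learning rate must be written 5.0e-5; a bare 5e-5 would load as a string:

import yaml

assert yaml.safe_load("lr: 5.0e-5")["lr"] == 5e-05         # parsed as float
assert isinstance(yaml.safe_load("lr: 5e-5")["lr"], str)   # "5e-5" stays a string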
@@ -0,0 +1,29 @@
# model/tokenizer
model_name: "zpn/llama-7b"
tokenizer_name: "zpn/llama-7b"
gradient_checkpointing: false
save_name: "zpn/vicuna-lora"

# dataset
streaming: false
num_proc: 64
dataset_path: "data"
max_length: 512
batch_size: 8

# train dynamics
lr: 5.0e-5
eval_every: 2000
eval_steps: 100
save_every: 2000
output_dir: "ckpts/llama-7b"
checkpoint: null
lora: true
warmup_steps: 100

# logging
wandb: false
wandb_entity: zanussbaum
wandb_project: llama
seed: 42
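
The training script is not included in this commit, so how lora: true is consumed is an assumption; a typical wiring with the peft library would look like the sketch below:

from peft import LoraConfig, TaskType, get_peft_model

# Sketch under the assumption the trainer uses peft; the rank/alpha/target
# values here are illustrative defaults, not taken from this commit.
def maybe_wrap_lora(model, config):
    if config["lora"]:
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    return model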
@@ -0,0 +1,108 @@
import glob
import os

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator


def tokenize_inputs(config, tokenizer, examples):
    max_length = config["max_length"]
    input_ids = torch.full((len(examples["prompt"]), max_length), tokenizer.pad_token_id)
    # tokenize the newline separator once; slice off the bos token
    newline_tokens = tokenizer("\n", return_tensors="pt")["input_ids"][0, 1:]

    out = {"labels": [], "attention_mask": []}
    for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
        # HACK: cap the prompt at half the context so the response always fits
        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
        input_len = len(input_tokens)

        # plus one since we remove bos from the response
        remaining_tokens = max_length - input_len - len(newline_tokens) + 1

        target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]

        input_ids[i, :input_len] = input_tokens
        # add newline between prompt and response
        newline_plus_inputs = input_len + len(newline_tokens)
        input_ids[i, input_len:newline_plus_inputs] = newline_tokens
        # add target tokens after the newline (bos already removed)
        input_ids[i, newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens

        labels = input_ids[i].clone()
        # mask prompt, newline, and padding so loss is computed only on the response
        labels[:newline_plus_inputs] = -100
        labels[labels == tokenizer.pad_token_id] = -100
        # to debug: set all -100 values back to the pad token, then assert that
        # tokenizer.decode(labels, skip_special_tokens=True).strip() == response

        attention_mask = input_ids[i].ne(tokenizer.pad_token_id).int()

        out["labels"].append(labels)
        out["attention_mask"].append(attention_mask)

    out["input_ids"] = input_ids

    out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}

    return out
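
To make the packing scheme concrete, here is a toy, self-contained illustration of the layout tokenize_inputs produces (token ids are made up; no real tokenizer is needed to see the masking):

import torch

# Layout: [prompt][newline][response][padding], loss only on the response span.
max_length = 12
pad_id = 0
prompt_ids = torch.tensor([5, 6, 7])
newline_ids = torch.tensor([8, 9])
response_ids = torch.tensor([10, 11, 12, 13])

input_ids = torch.full((max_length,), pad_id)
input_len = len(prompt_ids)
input_ids[:input_len] = prompt_ids
start = input_len + len(newline_ids)
input_ids[input_len:start] = newline_ids
input_ids[start:start + len(response_ids)] = response_ids

labels = input_ids.clone()
labels[:start] = -100              # mask prompt + newline
labels[labels == pad_id] = -100    # mask padding
assert labels[start:start + len(response_ids)].tolist() == [10, 11, 12, 13]
assert (labels[start + len(response_ids):] == -100).all()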

def load_data(config, tokenizer):
    dataset_path = config["dataset_path"]

    if os.path.exists(dataset_path):
        # local path: either a directory of cleaned shards or a single file
        if os.path.isdir(dataset_path):
            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
        else:
            files = [dataset_path]

        dataset = load_dataset("json", data_files=files, split="train")

    else:
        # otherwise treat it as a Hugging Face Hub dataset name; request the
        # train split so we get a Dataset (train_test_split needs one)
        dataset = load_dataset(dataset_path, split="train")

    dataset = dataset.train_test_split(test_size=0.05, seed=config["seed"])

    train_dataset, val_dataset = dataset["train"], dataset["test"]

    if config["streaming"] is False:
        kwargs = {"num_proc": config["num_proc"]}
    else:
        kwargs = {}

    # tokenize inputs and return labels and attention mask; drop the raw text
    # columns so only tensor columns reach the collator
    train_dataset = train_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt", "response"],
        **kwargs
    )
    val_dataset = val_dataset.map(
        lambda ele: tokenize_inputs(config, tokenizer, ele),
        batched=True,
        remove_columns=["source", "prompt", "response"],
        **kwargs
    )

    train_dataset = train_dataset.with_format("torch")
    val_dataset = val_dataset.with_format("torch")

    # create dataloaders with the default data collator since we already have labels
    train_dataloader = DataLoader(
        train_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    val_dataloader = DataLoader(
        val_dataset,
        collate_fn=DefaultDataCollator(),
        batch_size=config["batch_size"],
    )

    return train_dataloader, val_dataloader
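
The training loop itself is not in this commit, so the following wiring is an assumption; a minimal sketch of how load_data would be driven:

from transformers import AutoTokenizer

from read import read_config  # assumed module name for the config reader below

config = read_config("configs/train/finetune.yaml")  # assumed config path
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
# llama tokenizers ship without a pad token, and tokenize_inputs needs one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
train_dataloader, val_dataloader = load_data(config, tokenizer)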
@@ -0,0 +1,10 @@
import yaml


def read_config(path):
    # read yaml and return contents; implicitly returns None if parsing fails
    with open(path, 'r') as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)