WIP for axolotl trainer
winglian committed Apr 14, 2023
1 parent e9da4b9 commit ce24f5e
Showing 16 changed files with 497 additions and 1 deletion.
14 changes: 14 additions & 0 deletions .editorconfig
@@ -0,0 +1,14 @@
root = true

[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[*.py]
indent_style = space
indent_size = 4

[**.yml]
indent_style = space
indent_size = 2
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
**/axolotl.egg-info
**/__pycache__
.idea
9 changes: 8 additions & 1 deletion README.md
@@ -1,6 +1,13 @@
# Axolotl

### You know you're going to axolotl questions
#### You know you're going to axolotl questions


### Converting JSON data files to JSONL

```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
37 changes: 37 additions & 0 deletions configs/pythia_1_2B_alpaca.yml
@@ -0,0 +1,37 @@
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
- path: ./data/alpaca_data_gpt4.jsonl
type: alpaca
- path: ./data/vicuna_cleaned.jsonl
type: sharegpt
- path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher
- path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
type: gpteacher
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
wandb_project:
wandb_watch:
wandb_run_name:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
bf16: True
fp16: True
resume_from_checkpoint:
local_rank:
deepspeed:
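
For reference, the batch-size fields above are not used directly; scripts/finetune.py (later in this commit) derives the gradient accumulation from them and splits it across ranks under DDP. A minimal sketch of that arithmetic using this config's values (the world_size of 2 is a hypothetical example, not something set in the config):

```python
# Derived hyperparameters as computed in scripts/finetune.py for this config.
batch_size = 128       # effective batch size (from the yml above)
micro_batch_size = 8   # per-device batch size (from the yml above)
world_size = 2         # hypothetical GPU count; finetune.py reads WORLD_SIZE

gradient_accumulation_steps = batch_size // micro_batch_size  # 16
if world_size != 1:
    # under DDP, each rank accumulates its share of the steps
    gradient_accumulation_steps //= world_size  # 8

print(gradient_accumulation_steps)
```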
8 changes: 8 additions & 0 deletions data/README.md
@@ -0,0 +1,8 @@


```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o raw/roleplay-similarity_0.6-instruct-dataset.json
```
1 change: 1 addition & 0 deletions data/raw/.gitignore
@@ -0,0 +1 @@
**
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
6 changes: 6 additions & 0 deletions requirements.txt
@@ -0,0 +1,6 @@
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/peft.git
attrdict
fire
PyYAML==6.0
black
36 changes: 36 additions & 0 deletions scripts/alpaca_json_to_jsonl.py
@@ -0,0 +1,36 @@
import os
import sys
from pathlib import Path

import fire
from typing import Optional

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.convert import *

def main(
input: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
file_reader = FileReader()
if to_stdout or output is None:
writer = StdoutWriter()
else:
writer = FileWriter(output)
json_parser = JsonParser()
jsonl_serializer = JsonlSerializer()

converter = JsonToJsonlConverter(
file_reader, writer, json_parser, jsonl_serializer
)

converter.convert(input, output)


if __name__ == "__main__":
fire.Fire(main)
129 changes: 129 additions & 0 deletions scripts/finetune.py
@@ -0,0 +1,129 @@
import os
import sys
from pathlib import Path

import fire
import torch
import transformers
import yaml
from attrdict import AttrDict
from datasets import load_dataset, IterableDataset
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter

def setup_wandb_env_vars(cfg):
# empty keys in the yaml config parse to None, so guard before taking len()
if cfg.wandb_project and len(cfg.wandb_project) > 0:
os.environ["WANDB_PROJECT"] = cfg.wandb_project
cfg.use_wandb = True
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model


def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
if adapter != "lora":
raise NotImplementedError(f"{adapter} peft adapter not available")
try:
model = getattr(transformers, model_type).from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
device_map=cfg.device_map,
)
except Exception:
# fall back to AutoModelForCausalLM if the named model class fails to load
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
device_map=cfg.device_map,
)

try:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
except Exception:
# fall back to AutoTokenizer if the named tokenizer class fails to load
tokenizer = AutoTokenizer.from_pretrained(base_model)

if tokenizer.__class__.__name__ == "LlamaTokenizer":
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

if cfg.load_in_8bit:
model = prepare_model_for_int8_training(model)

lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=cfg.lora_target_modules,
lora_dropout=cfg.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
if cfg.ddp:
model.to(f"cuda:{cfg.local_rank}")

# TODO resume_from_checkpoint handling

model.print_trainable_parameters()
return model, tokenizer


def train(
config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
**kwargs,
):
# load the config from the yaml file
with open(config, 'r') as f:
cfg: AttrDict = AttrDict(yaml.safe_load(f))
# overwrite config values with any matching options passed on the cli
for k, v in kwargs.items():
if k in cfg:
cfg[k] = v

# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
cfg.device_map = "auto"
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.ddp = cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
setup_wandb_env_vars(cfg)

# Load the model and tokenizer
model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
datasets = []
for d in cfg.datasets:
ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d.type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d.type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)


if __name__ == "__main__":
fire.Fire(train)
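
As committed, train() stops after building the TokenizedPromptDataset wrappers; the trainer itself is the WIP part. Purely as a sketch of one way the pieces could be wired together later, and not the author's implementation, the snippet below assumes the wrappers behave like map-style torch datasets and uses hypothetical TrainingArguments values taken from the config keys above:

```python
from torch.utils.data import ConcatDataset
from transformers import Trainer, TrainingArguments


def build_trainer(model, tokenizer, datasets, cfg):
    # hypothetical wiring, not part of this commit
    train_dataset = ConcatDataset(datasets)  # assumes map-style datasets
    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_epochs,
        learning_rate=cfg.learning_rate,
        bf16=cfg.bf16,
    )
    # a padding data collator would likely be needed as well
    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
```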
23 changes: 23 additions & 0 deletions setup.cfg
@@ -0,0 +1,23 @@
[metadata]
name = axolotl
version = 0.1.0
description = You know you're going to axolotl questions
author = Wing Lian
author_email = [email protected]
license = MIT
[options]
package_dir =
=src
packages = find:
install_requires =
transformers @ git+https://github.com/huggingface/transformers.git@main
peft @ git+https://github.com/huggingface/peft.git@main
attrdict
fire
PyYAML == 6.0
black
[options.packages.find]
where = src
Empty file added src/axolotl/__init__.py
Empty file.
50 changes: 50 additions & 0 deletions src/axolotl/convert.py
@@ -0,0 +1,50 @@
import json
import sys


class FileReader:
def read(self, file_path):
with open(file_path, "r") as file:
return file.read()


class FileWriter:
def __init__(self, file_path):
self.file_path = file_path

def write(self, content):
with open(self.file_path, "w") as file:
file.write(content)


class StdoutWriter:
def write(self, content):
sys.stdout.write(content)
sys.stdout.write("\n")


class JsonParser:
def parse(self, content):
return json.loads(content)


class JsonlSerializer:
def serialize(self, data):
lines = [json.dumps(item) for item in data]
return "\n".join(lines)


class JsonToJsonlConverter:
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
self.file_reader = file_reader
self.file_writer = file_writer
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer

def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
jsonl_content = self.jsonl_serializer.serialize(data)
self.file_writer.write(jsonl_content)

