forked from axolotl-ai-cloud/axolotl
Commit
Showing 16 changed files with 497 additions and 1 deletion.

.editorconfig
@@ -0,0 +1,14 @@
root = true

[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[*.py]
indent_style = space
indent_size = 4

[**.yml]
indent_style = space
indent_size = 2

.gitignore
@@ -0,0 +1,3 @@
**/axolotl.egg-info
**/__pycache__
.idea

README.md
@@ -1,6 +1,13 @@
# Axolotl

### You know you're going to axolotl questions
#### You know you're going to axolotl questions


### Converting JSON data files to JSONL

```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
```

configs/pythia_1_2B_alpaca.yml
@@ -0,0 +1,37 @@
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
  - path: ./data/alpaca_data_gpt4.jsonl
    type: alpaca
  - path: ./data/vicuna_cleaned.jsonl
    type: sharegpt
  - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
    type: gpteacher
  - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
    type: gpteacher
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
wandb_project:
wandb_watch:
wandb_run_name:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
bf16: True
fp16: True
resume_from_checkpoint:
local_rank:
deepspeed:
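
A hedged usage sketch (not part of the diff): assuming this is the configs/pythia_1_2B_alpaca.yml that scripts/finetune.py defaults to, training would be launched through fire's CLI, and any extra flag that matches a key in the yaml overrides that value:

```shell
# sketch only: assumes dependencies are installed and the config is saved at the
# default path referenced by scripts/finetune.py
python3 scripts/finetune.py --config=configs/pythia_1_2B_alpaca.yml
# cli overrides for keys present in the yaml, e.g.:
python3 scripts/finetune.py --config=configs/pythia_1_2B_alpaca.yml --micro_batch_size=4 --num_epochs=3
```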

data/README.md
@@ -0,0 +1,8 @@


```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o raw/roleplay-similarity_0.6-instruct-dataset.json
```
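
A small assumption worth stating (not spelled out in the diff): the curl targets are relative raw/ paths, so the commands are presumably run from the data/ directory with the raw/ subdirectory already created:

```shell
# assumed working-directory setup before running the curl commands above
mkdir -p data/raw
cd data
```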

data/raw/.gitignore
@@ -0,0 +1 @@
**

pyproject.toml
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

requirements.txt
@@ -0,0 +1,6 @@
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/peft.git
attrdict
fire
PyYAML==6.0
black

scripts/alpaca_json_to_jsonl.py
@@ -0,0 +1,36 @@
import os
import sys
from pathlib import Path

import fire
from typing import Optional

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.convert import *

def main(
    input: Path,
    output: Optional[Path] = None,
    to_stdout: Optional[bool] = False,
):
    file_reader = FileReader()
    if to_stdout or output is None:
        writer = StdoutWriter()
    else:
        writer = FileWriter(output)
    json_parser = JsonParser()
    jsonl_serializer = JsonlSerializer()

    converter = JsonToJsonlConverter(
        file_reader, writer, json_parser, jsonl_serializer
    )

    converter.convert(input, output)


if __name__ == "__main__":
    fire.Fire(main)
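
Because main() is exposed through fire, the converter can also write straight to a file instead of relying on shell redirection; a hedged example (paths are illustrative):

```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
# with no --output (or with --to_stdout), the JSONL goes to stdout as in the README examples
```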

scripts/finetune.py
@@ -0,0 +1,129 @@
import os
import sys
from pathlib import Path

import fire
import torch
import transformers
import yaml
from attrdict import AttrDict
from datasets import load_dataset, IterableDataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
    LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter

def setup_wandb_env_vars(cfg):
    # wandb keys may be empty (None) in the yaml, so guard before checking length
    if cfg.wandb_project and len(cfg.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = cfg.wandb_project
        cfg.use_wandb = True
    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = cfg.wandb_watch
    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model


def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
    if adapter != "lora":
        raise NotImplementedError(f"{adapter} peft adapter not available")
    try:
        model = getattr(transformers, model_type).from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )
    except Exception:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )

    try:
        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    if tokenizer.__class__.__name__ == "LlamaTokenizer":
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

    if cfg.load_in_8bit:
        model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if cfg.ddp:
        model.to(f"cuda:{cfg.local_rank}")

    # TODO resume_from_checkpoint handling

    model.print_trainable_parameters()
    return model, tokenizer


def train(
    config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
    **kwargs,
):
    # load the config from the yaml file
    with open(config, 'r') as f:
        cfg: AttrDict = AttrDict(yaml.safe_load(f))
    # if an option passed on the cli matches a key in the yaml config, overwrite the value
    for k, v in kwargs.items():
        if k in cfg:
            cfg[k] = v

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.device_map = "auto"
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    cfg.ddp = cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
    setup_wandb_env_vars(cfg)

    # Load the model and tokenizer
    model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
    datasets = []
    for d in cfg.datasets:
        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
        if d.type == "alpaca":
            ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif d.type == "gpteacher":
            ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif d.type == "sharegpt":
            ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)


if __name__ == "__main__":
    fire.Fire(train)
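
The hunk above stops after the tokenized datasets are built; no trainer is constructed yet. For launching, the script reads WORLD_SIZE and LOCAL_RANK from the environment to decide on DDP, which torchrun provides; a hedged launch sketch (script and config paths assumed as above):

```shell
# DDP sketch: torchrun exports WORLD_SIZE/LOCAL_RANK, which the script uses to pick
# a per-rank device_map and to scale gradient_accumulation_steps by world size
torchrun --nproc_per_node=2 scripts/finetune.py --config=configs/pythia_1_2B_alpaca.yml
```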

setup.cfg
@@ -0,0 +1,23 @@
[metadata]
name = axolotl
version = 0.1.0
description = You know you're going to axolotl questions
author = Wing Lian
author_email = [email protected]
license = MIT

[options]
package_dir =
    =src
packages = find:
install_requires =
    transformers @ git+https://github.com/huggingface/transformers.git@main
    peft @ git+https://github.com/huggingface/peft.git@main
    attrdict
    fire
    PyYAML == 6.0
    black

[options.packages.find]
where = src
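
With the pyproject.toml and setup.cfg above, the package should be installable in editable mode (a hedged note on the standard setuptools flow, not something the commit documents), which would make the sys.path manipulation in the scripts unnecessary:

```shell
# editable install of the axolotl package from the repo root
pip install -e .
# or install only the runtime dependencies pinned in requirements.txt
pip install -r requirements.txt
```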

src/axolotl/__init__.py (empty file)

src/axolotl/convert.py
@@ -0,0 +1,50 @@
import json
import sys


class FileReader:
    def read(self, file_path):
        with open(file_path, "r") as file:
            return file.read()


class FileWriter:
    def __init__(self, file_path):
        self.file_path = file_path

    def write(self, content):
        with open(self.file_path, "w") as file:
            file.write(content)


class StdoutWriter:
    def write(self, content):
        sys.stdout.write(content)
        sys.stdout.write("\n")


class JsonParser:
    def parse(self, content):
        return json.loads(content)


class JsonlSerializer:
    def serialize(self, data):
        lines = [json.dumps(item) for item in data]
        return "\n".join(lines)


class JsonToJsonlConverter:
    def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
        self.file_reader = file_reader
        self.file_writer = file_writer
        self.json_parser = json_parser
        self.jsonl_serializer = jsonl_serializer

    def convert(self, input_file_path, output_file_path):
        content = self.file_reader.read(input_file_path)
        data = self.json_parser.parse(content)
        jsonl_content = self.jsonl_serializer.serialize(data)
        self.file_writer.write(jsonl_content)
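
A short, hypothetical sketch (not in the commit) of how these classes compose programmatically, mirroring what scripts/alpaca_json_to_jsonl.py does; the input path is illustrative:

```python
from axolotl.convert import (
    FileReader,
    JsonParser,
    JsonlSerializer,
    JsonToJsonlConverter,
    StdoutWriter,
)

# wire the pieces together the same way the CLI script does, writing JSONL to stdout
converter = JsonToJsonlConverter(FileReader(), StdoutWriter(), JsonParser(), JsonlSerializer())
converter.convert("data/raw/alpaca_data_gpt4.json", None)  # the output path argument is unused by convert()
```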
(The remaining files in this commit did not load.)