forked from Lightning-AI/litgpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_dolly.py
155 lines (126 loc) · 5.52 KB
/
prepare_dolly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Implementation derived from https://github.com/tloen/alpaca-lora"""
import json
import sys
from pathlib import Path
from typing import Optional
import requests
import torch
from torch.utils.data import random_split
from tqdm import tqdm
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from lit_gpt.tokenizer import Tokenizer
def prepare(
    destination_path: Path = Path("data/dolly"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    test_split_fraction: float = 0.1,
    seed: int = 42,
    mask_inputs: bool = False,
    data_file_name: str = "dolly_data_cleaned.json",
    data_file_url: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl",
    ignore_index: int = -1,
    max_seq_length: Optional[int] = None,
) -> None:
    """Download, split and tokenize the Dolly 15k dataset for instruction tuning.

    Two files, `train.pt` and `test.pt`, are written into `destination_path`. Each
    holds the list of samples produced by `prepare_sample` for that split.

    When `max_seq_length` is not given, it is read from the `block_size` entry of
    `lit_config.json` inside `checkpoint_dir`.
    """
    if max_seq_length is None:
        config_text = (checkpoint_dir / "lit_config.json").read_text(encoding="utf-8")
        max_seq_length = json.loads(config_text)["block_size"]

    destination_path.mkdir(parents=True, exist_ok=True)
    raw_file = destination_path / data_file_name
    print("Loading data file...")
    download_if_missing(raw_file, data_file_url)

    # The raw file is JSON Lines: every line is parsed as one JSON record.
    with open(raw_file, "r", encoding="utf-8") as handle:
        records = [json.loads(raw_line) for raw_line in handle]

    # Rename the Dolly fields to the "input"/"output" keys the prompt helpers read.
    for record in records:
        record["input"] = record.pop("context")
        record["output"] = record.pop("response")

    print("Loading tokenizer...")
    tokenizer = Tokenizer(checkpoint_dir)

    def encode_split(samples: list) -> list:
        """Run `prepare_sample` over every sample of one split, with a progress bar."""
        return [
            prepare_sample(
                example=sample,
                tokenizer=tokenizer,
                max_length=max_seq_length,
                mask_inputs=mask_inputs,
                ignore_index=ignore_index,
            )
            for sample in tqdm(samples)
        ]

    # Seeded generator so the train/test partition is reproducible.
    rng = torch.Generator().manual_seed(seed)
    fractions = [1.0 - test_split_fraction, test_split_fraction]
    train_subset, test_subset = random_split(records, fractions, generator=rng)
    train_samples = list(train_subset)
    test_samples = list(test_subset)

    print(f"train has {len(train_samples):,} samples")
    print(f"test has {len(test_samples):,} samples")

    print("Processing train split ...")
    torch.save(encode_split(train_samples), destination_path / "train.pt")

    print("Processing test split ...")
    torch.save(encode_split(test_samples), destination_path / "test.pt")
def download_if_missing(file_path: Path, file_url: str) -> None:
    """Download the raw data file to `file_path` unless that file already exists.

    Args:
        file_path: Destination of the downloaded file. If it exists, nothing is done.
        file_url: URL to fetch the file from.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: If the request fails (connection error, timeout, ...).
    """
    if file_path.exists():
        return

    # Fetch before opening the destination: opening first would leave an empty
    # file behind when the request fails, and later runs would then treat that
    # empty file as an already-downloaded dataset.
    response = requests.get(file_url, timeout=60)
    # Refuse to store an HTTP error page as if it were the dataset.
    response.raise_for_status()

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(response.text)
def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict:
    """Tokenize a single sample into model inputs and training labels.

    Each sample in the dataset consists of:
    - instruction: A string describing the task
    - input: A string holding a special input value for the instruction.
      This only applies to some samples, and in others this is empty.
    - output: The response string

    The prompt text is built by `generate_prompt` from the instruction and the
    optional input. The training text is that prompt with the response appended.
    Both are encoded with `tokenizer.encode`, limited to `max_length`; only the
    training text is encoded with `eos=True`.

    Args:
        example: The sample; must provide the "instruction", "input" and "output" keys.
        tokenizer: Tokenizer used to encode both texts. The encoded sequences must
            support `.clone()`, `len()` and slice assignment (presumably torch tensors).
        max_length: Maximum length passed to `tokenizer.encode`.
        mask_inputs: When true, the label positions that cover the prompt are
            overwritten with `ignore_index`.
        ignore_index: Value written into masked label positions.

    Returns:
        A dict holding all entries of `example` plus:
        - "input_ids": the encoded prompt with response
        - "input_ids_no_response": the encoded prompt alone
        - "labels": a copy of "input_ids", with the prompt part masked if requested
    """
    full_prompt = generate_prompt(example)
    full_prompt_and_response = full_prompt + example["output"]
    encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)
    encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)

    # The labels start as a copy of the full sequence so masking does not touch "input_ids".
    labels = encoded_full_prompt_and_response.clone()
    if mask_inputs:
        labels[: len(encoded_full_prompt)] = ignore_index

    return {
        **example,
        "input_ids": encoded_full_prompt_and_response,
        "input_ids_no_response": encoded_full_prompt,
        "labels": labels,
    }
def generate_prompt(example: dict) -> str:
    """Build the standardized prompt for a sample, ending in an open 'Response' section.

    The template with an '### Input:' section is used when `example["input"]` is
    non-empty; otherwise the shorter instruction-only template is used.
    """
    if not example["input"]:
        # No extra context: instruction-only template.
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n### Response:"
        )

    return (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
    )
if __name__ == "__main__":
    # Script entry point: jsonargparse builds the command-line interface from
    # the signature of `prepare` and then calls it.
    import jsonargparse

    jsonargparse.CLI(prepare)