Exercise solution for LoRA instruction finetuning (rasbt#240)
rasbt authored Jun 22, 2024
1 parent ec5baa1 commit 72f4629
Showing 2 changed files with 188 additions and 4 deletions.
126 changes: 126 additions & 0 deletions ch07/01_main-chapter-code/exercise-solutions.ipynb
@@ -873,6 +873,132 @@
"source": [
"The score is slightly lower than the score we obtained on the dataset we used in this chapter. However, note that the Alpaca test set contains more diverse and partly more challenging instructions than the dataset we used in the main chapter."
]
},
{
"cell_type": "markdown",
"id": "ca61fa6c-4e1d-4618-9e5e-d091f8303e30",
"metadata": {},
"source": [
"## Exercise 7.4: Parameter-efficient finetuning with LoRA"
]
},
{
"cell_type": "markdown",
"id": "01742cec-1f41-4415-8788-009d31b1ad38",
"metadata": {},
"source": [
"To instruction finetune the model using LoRA, use the relevant classes and functions from appendix E:\n",
"\n",
"```python\n",
" from appendix_E import LoRALayer, LinearWithLoRA, replace_linear_with_lora\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "871dca8f-3411-4735-b7b0-9d0e6e0599ac",
"metadata": {},
"source": [
"Next, add the following lines of code below the model loading code in section 7.5:\n",
"\n",
"\n",
"```python\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters before: {total_params:,}\")\n",
"\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters after: {total_params:,}\")\n",
"replace_linear_with_lora(model, rank=16, alpha=16)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable LoRA parameters: {total_params:,}\")\n",
"model.to(device)\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "1b26b925-dc95-4b91-b050-9676dd9608a4",
"metadata": {},
"source": [
"For your convenience, you can use the `exercise_experiments.py` code to finetune the model, using LoRA with rank 16 and alpa 16, as follows:"
]
},
{
"cell_type": "markdown",
"id": "01f02c7e-3b15-44b8-bf41-7892cd755766",
"metadata": {},
"source": [
"```bash\n",
"python exercise_experiments.py --exercise_solution lora\n",
"```\n",
"\n",
"Output:\n",
"\n",
"```\n",
"matplotlib version: 3.7.1\n",
"tiktoken version: 0.7.0\n",
"torch version: 2.3.0+cu121\n",
"tqdm version: 4.66.4\n",
"tensorflow version: 2.15.0\n",
"--------------------------------------------------\n",
"Training set length: 935\n",
"Validation set length: 55\n",
"Test set length: 110\n",
"--------------------------------------------------\n",
"Device: cuda\n",
"--------------------------------------------------\n",
"File already exists and is up-to-date: gpt2/355M/checkpoint\n",
"File already exists and is up-to-date: gpt2/355M/encoder.json\n",
"File already exists and is up-to-date: gpt2/355M/hparams.json\n",
"File already exists and is up-to-date: gpt2/355M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/355M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/355M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/355M/vocab.bpe\n",
"Loaded model: gpt2-medium (355M)\n",
"--------------------------------------------------\n",
"Total trainable parameters before: 406,286,336\n",
"Total trainable parameters after: 0\n",
"Total trainable LoRA parameters: 7,898,384\n",
"Initial losses\n",
" Training loss: 3.7684114456176756\n",
" Validation loss: 3.7619335651397705\n",
"Ep 1 (Step 000000): Train loss 2.509, Val loss 2.519\n",
"...\n",
"Ep 2 (Step 000230): Train loss 0.308, Val loss 0.652\n",
"...\n",
"--------------------------------------------------\n",
"Generating responses\n",
"100% 110/110 [01:52<00:00, 1.03s/it]\n",
"Responses saved as instruction-data-with-response-lora.json\n",
"Model saved as gpt2-medium355M-sft-lora.pth\n",
"```\n",
"\n",
"For comparison, you can run the original chapter 7 finetuning code via `python exercise_experiments.py --exercise_solution baseline`. \n",
"\n",
"Note that on an Nvidia L4 GPU, the code above, using LoRA, takes 1.30 min to run. In comparison, the Alpaca-style template takes 1.80 minutes to run. So, LoRA is approximately 28% faster.\n",
"\n",
"\n",
"We can evaluate the performance using the Ollama Llama 3 method, which is for your convenience, also implemented in the `python exercise_experiments.py` script, which we can run as follows:\n",
"\n",
"```python\n",
"python ollama_evaluate.py --file_path instruction-data-with-response-lora.json\n",
"```\n",
"\n",
"Output:\n",
"\n",
"```\n",
"Ollama running: True\n",
"Scoring entries: 100%|████████████████████████| 110/110 [01:13<00:00, 1.50it/s]\n",
"Number of scores: 110 of 110\n",
"Average score: 50.23\n",
"```\n",
"\n",
"The score is around 50, which is in the same ballpark as original model."
]
}
],
"metadata": {
66 changes: 62 additions & 4 deletions ch07/01_main-chapter-code/exercise_experiments.py
@@ -8,6 +8,7 @@
from functools import partial
from importlib.metadata import version
import json
import math
import os
import re
import time
@@ -107,6 +108,41 @@ def __len__(self):
return len(self.data)


class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
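        # Low-rank update: project the input down to `rank` dimensions via A and back up via B,
        # then scale the result by alpha (applied directly here, rather than as alpha / rank)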
        x = self.alpha * (x @ self.A @ self.B)
        return x


def replace_linear_with_lora(model, rank, alpha):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules
            replace_linear_with_lora(module, rank, alpha)
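
# Illustrative sketch (toy example only; `toy` is a made-up module, not used by the code below):
# after freezing all of a model's parameters, replace_linear_with_lora leaves only the LoRA
# matrices A and B trainable, and each nn.Linear(d_in, d_out) gains rank * (d_in + d_out)
# new trainable parameters:
#
#     toy = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 2))
#     for p in toy.parameters():
#         p.requires_grad = False
#     replace_linear_with_lora(toy, rank=2, alpha=2)
#     trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
#     print(trainable)  # 2 * (8 + 4) + 2 * (4 + 2) = 36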


def custom_collate_fn(
batch,
pad_token_id=50256,
@@ -256,7 +292,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, plot_name):
# plt.show()


def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False, lora=False):
#######################################
# Print package versions
#######################################
@@ -379,6 +415,21 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
print("Loaded model:", CHOOSE_MODEL)
print(50*"-")

    if lora:
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters before: {total_params:,}")

        for param in model.parameters():
            param.requires_grad = False

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters after: {total_params:,}")
        replace_linear_with_lora(model, rank=16, alpha=16)
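        # With rank=16, every nn.Linear(d_in, d_out) in the model now carries an extra
        # 16 * (d_in + d_out) trainable LoRA parameters, which for gpt2-medium (355M)
        # adds up to the roughly 7.9M parameters printed below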

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable LoRA parameters: {total_params:,}")
        model.to(device)

#######################################
# Finetuning the model
#######################################
@@ -418,7 +469,9 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
plot_name = plot_name.replace(".pdf", "-alpaca52k.pdf")
if phi3_prompt:
plot_name = plot_name.replace(".pdf", "-phi3-prompt.pdf")
if not any([mask_instructions, alpaca52k, phi3_prompt]):
if lora:
plot_name = plot_name.replace(".pdf", "-lora.pdf")
if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
plot_name = plot_name.replace(".pdf", "-baseline.pdf")

plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, plot_name)
@@ -460,7 +513,10 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
if phi3_prompt:
test_data_path = test_data_path.replace(".json", "-phi3-prompt.json")
file_name = file_name.replace(".pth", "-phi3-prompt.pth")
if not any([mask_instructions, alpaca52k, phi3_prompt]):
if lora:
test_data_path = test_data_path.replace(".json", "-lora.json")
file_name = file_name.replace(".pth", "-lora.pth")
if not any([mask_instructions, alpaca52k, phi3_prompt, lora]):
test_data_path = test_data_path.replace(".json", "-baseline.json")
file_name = file_name.replace(".pth", "-baseline.pth")

@@ -479,7 +535,7 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
parser = argparse.ArgumentParser(
description="Instruction finetune a GPT model"
)
options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt"}
options = {"baseline", "mask_instructions", "alpaca_52k", "phi3_prompt", "lora"}
parser.add_argument(
"--exercise_solution",
type=str,
Expand All @@ -498,5 +554,7 @@ def main(mask_instructions=False, alpaca52k=False, phi3_prompt=False):
main(alpaca52k=True)
elif args.exercise_solution == "phi3_prompt":
main(phi3_prompt=True)
elif args.exercise_solution == "lora":
main(lora=True)
else:
raise ValueError(f"{args.exercise_solution} is not a valid --args.exercise_solution option. Options: {options}")
