Update finetuning scripts (Lightning-AI#1010)
awaelchli committed Mar 15, 2024
1 parent d9e16d2 commit b1423fa
Showing 12 changed files with 58 additions and 33 deletions.
1 change: 1 addition & 0 deletions config_hub/finetune/llama-2-7b/full.yaml
@@ -7,6 +7,7 @@ data:
init_args:
mask_prompt: false
test_split_fraction: 0.03847
prompt_style: "alpaca"
ignore_index: -1
seed: 42
num_workers: 4
1 change: 1 addition & 0 deletions config_hub/finetune/llama-2-7b/lora.yaml
@@ -16,6 +16,7 @@ data:
init_args:
mask_prompt: false
test_split_fraction: 0.03847
prompt_style: "alpaca"
ignore_index: -1
seed: 42
num_workers: 4
1 change: 1 addition & 0 deletions config_hub/finetune/tiny-llama/lora.yaml
@@ -16,6 +16,7 @@ data:
init_args:
mask_prompt: false
test_split_fraction: 0.03847
prompt_style: "alpaca"
ignore_index: -1
seed: 42
num_workers: 4
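All three config files above gain the same new key, prompt_style: "alpaca", under data.init_args. As a rough illustration of what those init args correspond to (not part of this commit; the Alpaca import path and keyword names are assumed from the config keys shown above, not verified against the repository):

# Hypothetical sketch (not from this commit): constructing the Alpaca data module
# with the same init args the YAML files above declare. The import path and
# keyword names are assumptions based on the config keys.
from lit_gpt.data import Alpaca

data = Alpaca(
    mask_prompt=False,
    test_split_fraction=0.03847,
    prompt_style="alpaca",  # the key added by this commit
    ignore_index=-1,
    seed=42,
    num_workers=4,
)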
10 changes: 5 additions & 5 deletions finetune/adapter.py
@@ -225,7 +225,7 @@ def fit(
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
fabric.barrier()
if not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
@@ -241,10 +241,10 @@ def validate(
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
batch = next(val_iterator)
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
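The rewritten validation loop in adapter.py sizes the losses buffer to min(len(val_dataloader), eval.max_iters) and breaks out of the dataloader once the cap is reached, so it no longer assumes the dataloader can always supply eval.max_iters batches. A self-contained sketch of that pattern, using a toy dataloader and a stand-in loss rather than the project's own:

# Self-contained illustration of the bounded validation loop introduced above.
# The dataloader and the per-batch "loss" are toy stand-ins.
import torch
from torch.utils.data import DataLoader, TensorDataset

val_dataloader = DataLoader(TensorDataset(torch.randn(10, 4)), batch_size=2)  # 5 batches
max_iters = 100  # requested cap, here larger than the dataloader

losses = torch.zeros(min(len(val_dataloader), max_iters))  # buffer never exceeds the data
for k, (batch,) in enumerate(val_dataloader):
    if k >= max_iters:
        break  # stop early when the cap is the smaller of the two
    losses[k] = batch.pow(2).mean()  # stand-in for the real loss computation

val_loss = losses.mean()  # no unused zero entries dragging the mean down
print(val_loss)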
10 changes: 5 additions & 5 deletions finetune/adapter_v2.py
@@ -225,7 +225,7 @@ def fit(
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
fabric.barrier()
if not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_adapter_v2_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
@@ -241,10 +241,10 @@ def validate(
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
batch = next(val_iterator)
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
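adapter_v2.py picks up the same two changes as adapter.py: checkpoints are written under step-numbered directories instead of iteration-numbered ones, and validation is bounded by the dataloader length. The new directory naming, shown with a made-up output path and step count:

# Illustration of the step-based checkpoint layout (path and step count are made up).
from pathlib import Path

out_dir = Path("out/finetune/adapter-v2")
step_count = 3

checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
print(checkpoint_file)  # out/finetune/adapter-v2/step-000003/lit_model.pth
# checkpoint_file.parent.mkdir(parents=True, exist_ok=True)  # as in the scripts above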
12 changes: 7 additions & 5 deletions finetune/full.py
@@ -221,12 +221,13 @@ def fit(
"loss": loss,
"iter": state["iter_num"],
"step": state["step_count"],
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": state["iter_num"] * train.micro_batch_size * model.config.block_size,
"total_tokens": (
state["iter_num"] * train.micro_batch_size * model.config.block_size * fabric.world_size
),
# TODO: log learning rate
"learning_rate": scheduler.get_last_lr()[0],
}
fabric.print(
f"iter {metrics['iter']} | step {metrics['step']}: loss {metrics['loss']:.4f}, iter time:"
@@ -260,13 +261,14 @@ def validate(
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
batch = next(val_iterator)
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)

val_loss = losses.mean()

# produce an example:
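In full.py the logged metrics now include the epoch from the cycling train iterator and the current learning rate, replacing the old TODO comment. get_last_lr() returns one value per parameter group, so index 0 is the rate of the first group. A minimal sketch with a stand-in optimizer and scheduler (not the script's actual setup):

# Minimal sketch of reading the current learning rate for logging.
# The parameter, optimizer, and scheduler are stand-ins, not the script's own setup.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

optimizer.step()
scheduler.step()
metrics = {"learning_rate": scheduler.get_last_lr()[0]}  # first (and only) param group
print(metrics)  # {'learning_rate': 0.001} until the first decay at step 10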
41 changes: 31 additions & 10 deletions finetune/lora.py
@@ -1,5 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import dataclasses
import math
import os
import sys
import time
@@ -13,6 +14,7 @@
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor
from torchmetrics import RunningMean

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
@@ -127,7 +129,6 @@ def setup(

def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(train, eval)

check_valid_checkpoint_dir(checkpoint_dir)

tokenizer = Tokenizer(checkpoint_dir)
@@ -209,6 +210,9 @@ def fit(

train_iterator = CycleIterator(train_dataloader)
throughput = ThroughputMonitor(fabric, window_size=50)
running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
fabric.device
)
max_steps = train.max_steps or float("inf")
step_count = 0
iter_num = 0
@@ -218,7 +222,6 @@
while step_count < max_steps and train_iterator.epoch < train.epochs:
iter_num += 1
iter_t0 = time.perf_counter()

batch = next(train_iterator)
input_ids, targets = batch["input_ids"], batch["labels"]

@@ -230,6 +233,8 @@ def fit(
loss = chunked_cross_entropy(logits, targets[..., 1:])
fabric.backward(loss / train.gradient_accumulation_iters(devices))

running_loss.update(loss.detach())

if not is_accumulating:
optimizer.step()
optimizer.zero_grad()
@@ -238,25 +243,40 @@

total_lengths += input_ids.numel()
if iter_num % train.log_interval == 0:
loss_item = loss.item() # expensive device-to-host synchronization
loss = running_loss.compute().item() # expensive device-to-host synchronization
t1 = time.perf_counter()
throughput.update(
time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
)
throughput.compute_and_log(step=iter_num)
metrics = {
"loss": loss,
"iter": iter_num,
"step": step_count,
"epoch": train_iterator.epoch,
"iter_time": t1 - iter_t0,
"tokens": iter_num * train.micro_batch_size * model.config.block_size,
"total_tokens": (
iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
),
"learning_rate": scheduler.get_last_lr()[0],
}
fabric.print(
f"iter {iter_num} | step {step_count}: loss {loss_item:.4f}, iter time:"
f" {(t1 - iter_t0) * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}"
f"iter {metrics['iter']} | step {metrics['step']}: loss {metrics['loss']:.4f}, iter time:"
f" {metrics['iter_time'] * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}"
)
fabric.log_dict(metrics, step=iter_num)

if not is_accumulating and step_count % eval.interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
t1 = time.perf_counter() - t0
fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
fabric.log_dict(metrics, step=iter_num)
fabric.barrier()
if not is_accumulating and step_count % train.save_interval == 0:
checkpoint_file = out_dir / f"iter-{iter_num:06d}" / "lit_model.pth"
checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
save_lora_checkpoint(fabric, model, checkpoint_file)
if fabric.global_rank == 0:
@@ -272,13 +292,14 @@ def validate(
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval.max_iters)
val_iterator = iter(val_dataloader)
for k in range(eval.max_iters):
batch = next(val_iterator)
losses = torch.zeros(min(len(val_dataloader), eval.max_iters))
for k, batch in enumerate(val_dataloader):
if k >= eval.max_iters:
break
input_ids, targets = batch["input_ids"], batch["labels"]
logits = model(input_ids)
losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)

val_loss = losses.mean()

# produce an example:
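lora.py gains the largest set of changes: the logged training loss is now a running mean over the gradient-accumulation window (via torchmetrics' RunningMean) rather than the raw loss of the last micro-batch, the metrics dict is sent to the logger with fabric.log_dict, and validation additionally reports perplexity as exp(val_loss). A rough sketch of the smoothing and the perplexity computation, with made-up window size and loss values:

# Sketch of the running-mean loss and perplexity logging added to lora.py.
# The window size and loss values are made up for illustration.
import math
import torch
from torchmetrics import RunningMean

running_loss = RunningMean(window=4, sync_on_compute=False)  # e.g. gradient_accumulation_iters(devices)
for value in (2.0, 1.8, 1.9, 1.7):
    running_loss.update(torch.tensor(value))  # one update per micro-batch

loss = running_loss.compute().item()  # mean over the last `window` updates -> 1.85
val_loss = torch.tensor(1.5)          # stand-in validation loss
metrics = {"loss": loss, "val_ppl": math.exp(val_loss)}  # exp(1.5) ~= 4.48
print(metrics)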
1 change: 0 additions & 1 deletion generate/full.py
@@ -19,7 +19,6 @@
from lit_gpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint



def main(
prompt: str = "What food do llamas eat?",
input: str = "",
8 changes: 4 additions & 4 deletions lit_gpt/pretrain.py
@@ -285,12 +285,12 @@ def validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max
fabric.print("Validating ...")
model.eval()

losses = torch.zeros(max_iters, device=fabric.device)
for k, val_data in enumerate(val_dataloader):
losses = torch.zeros(min(len(val_dataloader), max_iters))
for k, batch in enumerate(val_dataloader):
if k >= max_iters:
break
input_ids = val_data[:, 0 : model.config.block_size].contiguous().long()
targets = val_data[:, 1 : (model.config.block_size + 1)].contiguous().long()
input_ids = batch[:, 0 : model.config.block_size].contiguous().long()
targets = batch[:, 1 : (model.config.block_size + 1)].contiguous().long()
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets)
losses[k] = loss
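The pretrain.py validation loop gets the same bound on the number of batches, but here each batch is a flat token tensor, so inputs and targets are sliced out of the same block with the targets shifted one position ahead. A tiny worked example of that slicing (block size and token ids are made up):

# Worked example of the input/target slicing used in pretrain.py's validate().
# The block size and token ids are made up.
import torch

block_size = 4
batch = torch.tensor([[10, 11, 12, 13, 14]])  # block_size + 1 tokens per row

input_ids = batch[:, 0:block_size].contiguous().long()       # tensor([[10, 11, 12, 13]])
targets = batch[:, 1:(block_size + 1)].contiguous().long()    # tensor([[11, 12, 13, 14]])
# Each target is simply the next token after the corresponding input position.
print(input_ids, targets)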
2 changes: 1 addition & 1 deletion tests/test_adapter.py
@@ -84,7 +84,7 @@ def test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
)

out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"iter-000002", "iter-000004", "iter-000006", "final"}
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
2 changes: 1 addition & 1 deletion tests/test_adapter_v2.py
@@ -107,7 +107,7 @@ def test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path)
)

out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"iter-000002", "iter-000004", "iter-000006", "final"}
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
2 changes: 1 addition & 1 deletion tests/test_lora.py
@@ -214,7 +214,7 @@ def test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):
)

out_dir_contents = set(os.listdir(out_dir))
checkpoint_dirs = {"iter-000002", "iter-000004", "iter-000006", "final"}
checkpoint_dirs = {"step-000002", "step-000004", "step-000006", "final"}
assert checkpoint_dirs.issubset(out_dir_contents)
assert all((out_dir / p).is_dir() for p in checkpoint_dirs)
for checkpoint_dir in checkpoint_dirs:
