Skip to content

Commit

Permalink
Pretrain script with the Trainer (Lightning-AI#228)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
carmocca and awaelchli authored Jul 14, 2023
1 parent f091660 commit f50e2d8
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 17 deletions.
2 changes: 1 addition & 1 deletion finetune/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lit_gpt.adapter import GPT, Config, mark_only_adapter_as_trainable, Block, adapter_filter
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import lazy_load, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy
from lit_gpt.speed_monitor import SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from scripts.prepare_alpaca import generate_prompt

eval_interval = 600
Expand Down
2 changes: 1 addition & 1 deletion finetune/adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import lazy_load, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy
from lit_gpt.speed_monitor import SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from scripts.prepare_alpaca import generate_prompt

eval_interval = 600
Expand Down
2 changes: 1 addition & 1 deletion finetune/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lit_gpt.model import GPT, Config, Block
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import lazy_load, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy
from lit_gpt.speed_monitor import SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from scripts.prepare_alpaca import generate_prompt

eval_interval = 600
Expand Down
2 changes: 1 addition & 1 deletion finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from lit_gpt.lora import mark_only_lora_as_trainable, lora_filter, GPT, Config
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.utils import lazy_load, check_valid_checkpoint_dir, step_csv_logger, chunked_cross_entropy
from lit_gpt.speed_monitor import SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from scripts.prepare_alpaca import generate_prompt


Expand Down
94 changes: 86 additions & 8 deletions lit_gpt/speed_monitor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import time
from collections import deque
from contextlib import nullcontext
from typing import Deque, Optional
from typing import Deque, Optional, Any, Dict, Callable

import torch
from lightning import Fabric
from lightning import Fabric, Callback, Trainer, LightningModule
from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
from torch.utils.flop_counter import FlopCounterMode

from lit_gpt import GPT
Expand Down Expand Up @@ -114,7 +117,8 @@ def get_flops_available(device: torch.device, precision: str) -> Optional[float]

# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py

class SpeedMonitor:

class SpeedMonitorBase:
"""Logs the training throughput and utilization.
+-------------------------------------+-----------------------------------------------------------+
Expand Down Expand Up @@ -166,10 +170,15 @@ class SpeedMonitor:
'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
"""

def __init__(self, fabric: Fabric, window_size: int = 100, time_unit: str = "hours"):
self.fabric = fabric
# TODO: this will not work properly if a precision plugin is passed to Fabric
self.flops_available = get_flops_available(fabric.device, fabric._connector._precision_input)
def __init__(
self,
flops_available: float,
log_dict: Callable[[Dict, int], None],
window_size: int = 100,
time_unit: str = "hours",
):
self.flops_available = flops_available
self.log_dict = log_dict

# Track the batch num samples and wct to compute throughput over a window of batches
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
Expand Down Expand Up @@ -260,12 +269,81 @@ def on_train_batch_end(
}
)

self.fabric.log_dict(metrics, step)
self.log_dict(metrics, step)

def eval_end(self, eval_elapsed: float):
self.total_eval_wct += eval_elapsed # seconds


class SpeedMonitorFabric(SpeedMonitorBase):
def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
# TODO: this will not work properly if a precision plugin is passed to Fabric
flops_available = get_flops_available(fabric.device, fabric._connector._precision_input)
super().__init__(flops_available, fabric.log_dict, *args, **kwargs)

@fabric_rank_zero_only
def on_train_batch_end(self, *args: Any, **kwargs: Any):
super().on_train_batch_end(*args, **kwargs)


class SpeedMonitorCallback(Callback):
def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
super().__init__()
self.speed_monitor: Optional[SpeedMonitorBase] = None
self.speed_monitor_kwargs = kwargs
self.length_fn = length_fn
self.batch_size = batch_size
self.eval_t0: int = 0
self.train_t0: int = 0
self.total_lengths: int = 0

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
if self.speed_monitor is not None:
return # already setup
# TODO: this will not work properly if a precision plugin is passed to Trainer
flops_available = get_flops_available(
trainer.strategy.root_device, trainer._accelerator_connector._precision_flag
)
self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)

@trainer_rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
if trainer.fit_loop._should_accumulate():
return

self.train_t0 = time.time()

@trainer_rank_zero_only
def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
) -> None:
self.total_lengths += self.length_fn(batch)
if trainer.fit_loop._should_accumulate():
return
train_elapsed = time.time() - self.train_t0
assert self.speed_monitor is not None
iter_num = trainer.fit_loop.total_batch_idx
assert (measured_flops := pl_module.measured_flops) is not None
self.speed_monitor.on_train_batch_end(
(iter_num + 1) * self.batch_size,
train_elapsed,
# this assumes that device FLOPs are the same and that all devices have the same batch size
trainer.world_size,
flops_per_batch=measured_flops,
lengths=self.total_lengths,
)

@trainer_rank_zero_only
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.eval_t0 = time.time()

@trainer_rank_zero_only
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
eval_elapsed = time.time() - self.eval_t0
assert self.speed_monitor is not None
self.speed_monitor.eval_end(eval_elapsed)


def estimate_flops(model: GPT) -> int:
"""Measures estimated FLOPs for MFU: https://arxiv.org/abs/2205.05198"""
# using all parameters for this is a naive over estimation because not all model parameters actually contribute to
Expand Down
9 changes: 6 additions & 3 deletions lit_gpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from io import BytesIO
from pathlib import Path
from types import MethodType
from typing import Optional, Any, Union, List
from typing import Optional, Any, Union, List, TypeVar, Type

import torch
import torch.utils._device
Expand Down Expand Up @@ -401,8 +401,11 @@ def __exit__(self, type, value, traceback):
self.zipfile.write_end_of_file()


def step_csv_logger(*args: Any, **kwargs: Any) -> CSVLogger:
logger = CSVLogger(*args, **kwargs)
T = TypeVar("T")


def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T:
logger = cls(*args, **kwargs)

def merge_by(dicts, key):
from collections import defaultdict
Expand Down
2 changes: 1 addition & 1 deletion pretrain/openwebtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from lit_gpt import Config
from lit_gpt.model import GPT, Block
from lit_gpt.speed_monitor import SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor, measure_flops, estimate_flops
from lit_gpt.utils import step_csv_logger, chunked_cross_entropy

model_name = "pythia-70m"
Expand Down
200 changes: 200 additions & 0 deletions pretrain/openwebtext_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import math
import sys
import time
from functools import partial
from pathlib import Path
from typing import Optional, Any

import lightning as L
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.strategies import FSDPStrategy, XLAStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.model import GPT, Block
from lit_gpt.speed_monitor import measure_flops, estimate_flops, SpeedMonitorCallback
from lit_gpt.utils import step_csv_logger, chunked_cross_entropy

model_name = "pythia-70m"
name = "openwebtext"
out_dir = Path("out") / name
data_dir = Path("data") / name
save_interval = 1000
eval_interval = 1000
eval_iters = 100
log_interval = 1

# Hyperparameters
learning_rate = 6e-4
batch_size = 125
micro_batch_size = 5
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
max_iters = 600000 # num_epochs * (epoch_size // micro_batch_size) // devices
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
decay_lr = True
warmup_iters = 2000
lr_decay_iters = max_iters
min_lr = 6e-5

hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}


class LightningGPTModule(L.LightningModule):
def __init__(self, config: Config) -> None:
super().__init__()
self.config = config
self.module: Optional[torch.nn.Module] = None
self.measured_flops: Optional[int] = None

def configure_model(self) -> None:
self.module = GPT(self.config)
self.module.apply(self.module._init_weights)

def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.AdamW(
self.module.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
)

def on_fit_start(self) -> None:
trainer = self.trainer
with torch.device("meta"):
meta_model = GPT(self.module.config)
# estimated is too much of an optimistic estimate, left just for reference
estimated_flops = estimate_flops(meta_model) * micro_batch_size
self.print(f"Estimated TFLOPs: {estimated_flops * trainer.world_size / 1e12:.2f}")
x = torch.randint(0, 1, (micro_batch_size, meta_model.config.block_size))
self.measured_flops = measure_flops(meta_model, x)
self.print(f"Measured TFLOPs: {self.measured_flops * trainer.world_size / 1e12:.2f}")

def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
if not decay_lr:
return
# determine and set the learning rate for this iteration
lr = get_lr(self.trainer.fit_loop.total_batch_idx)
for optimizer in self.trainer.strategy.optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] = lr

def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
input_ids, targets = batch
logits = self.module(input_ids)
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True)
return loss

def validation_step(self, batch: Any, batch_idx: int) -> None:
input_ids, targets = batch
logits = self.module(input_ids)
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)


def main(devices: int = 1, precision: Optional[str] = None, tpu: bool = False) -> None:
if precision is None:
precision = "32-true" if tpu else "bf16-mixed"
if devices > 1:
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
devices = "auto"
strategy = XLAStrategy(sync_module_states=False)
else:
auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
strategy = FSDPStrategy(
auto_wrap_policy=auto_wrap_policy,
activation_checkpointing=Block,
# the argument is not available in the Trainer strategy, but it's the default anyways
# state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"

logger = step_csv_logger("out", name, cls=CSVLogger, flush_logs_every_n_steps=log_interval)
speed_monitor = SpeedMonitorCallback(
length_fn=lambda batch: batch[0].size(1), batch_size=micro_batch_size, window_size=50, time_unit="seconds"
)
model_checkpoint = ModelCheckpoint(dirpath=out_dir, every_n_train_steps=save_interval, save_last=True, verbose=True)
trainer = L.Trainer(
devices=devices,
strategy=strategy,
precision=precision,
logger=logger,
callbacks=[speed_monitor, model_checkpoint],
max_steps=max_iters,
max_epochs=1,
limit_val_batches=eval_iters,
accumulate_grad_batches=gradient_accumulation_steps,
log_every_n_steps=log_interval,
val_check_interval=eval_interval,
)

L.seed_everything(1337 + trainer.global_rank, workers=True)

trainer.print(hparams)

if trainer.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)

config = Config.from_name(model_name)
trainer.print(f"Loading model with {config.__dict__}")
t0 = time.time()
model = LightningGPTModule(config)
trainer.print(f"Time to instantiate model: {time.time() - t0:.02f} seconds.")

train_data = Dataset(str(data_dir / "train.bin"), config.block_size)
val_data = Dataset(str(data_dir / "val.bin"), config.block_size)

t0 = time.time()
trainer.fit(model, train_data, val_data, ckpt_path="last")
trainer.print(f"Training time: {(time.time()-t0):.2f}s")


class Dataset:
def __init__(self, bin: str, block_size: int) -> None:
self.data = np.memmap(bin, dtype=np.uint16, mode="r")
self.block_size = block_size

def __iter__(self):
while True:
ix = torch.randint(len(self.data) - self.block_size, (micro_batch_size,))
x = torch.stack([torch.from_numpy((self.data[i : i + self.block_size]).astype(np.int64)) for i in ix])
y = torch.stack(
[torch.from_numpy((self.data[i + 1 : i + 1 + self.block_size]).astype(np.int64)) for i in ix]
)
yield x, y


# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)


if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")

from jsonargparse import CLI

CLI(main)
Loading

0 comments on commit f50e2d8

Please sign in to comment.