Type hints (Lightning-AI#633)
Andrei-Aksionov authored Oct 11, 2023
1 parent 1ec6b75 commit bf60124
Showing 24 changed files with 73 additions and 60 deletions.
2 changes: 1 addition & 1 deletion chat/base.py
@@ -30,7 +30,7 @@ def generate(
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     stop_tokens: Tuple[List[int], ...] = (),
-):
+) -> Iterator[torch.Tensor]:
     """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as possible.

     Args:
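The new `Iterator[torch.Tensor]` annotation reflects that `generate` yields tokens one at a time rather than returning a single tensor. A minimal sketch of that shape, assuming a hypothetical `stream_tokens` helper whose random sampling stands in for the model call:

```python
from typing import Iterator

import torch


def stream_tokens(prompt: torch.Tensor, max_new_tokens: int) -> Iterator[torch.Tensor]:
    """Yield one placeholder token per step; a real model forward and sampling step would go here."""
    for _ in range(max_new_tokens):
        next_token = torch.randint(0, 50_000, (1,))  # stand-in for model(prompt) + sampling
        prompt = torch.cat([prompt, next_token])
        yield next_token


for token in stream_tokens(torch.tensor([1, 2, 3]), max_new_tokens=5):
    print(token.item())
```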
6 changes: 3 additions & 3 deletions finetune/adapter.py
@@ -54,7 +54,7 @@ def setup(
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
     out_dir: Path = Path("out/adapter/alpaca"),
     precision: Optional[str] = None,
-):
+) -> None:
     precision = precision or get_default_supported_precision(training=True)

     fabric_devices = devices
@@ -75,7 +75,7 @@ def setup(
     fabric.launch(main, data_dir, checkpoint_dir, out_dir)


-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)

     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -277,7 +277,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


-def save_adapter_checkpoint(fabric, model, file_path: Path):
+def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
     fabric.print(f"Saving adapter weights to {str(file_path)!r}")
     fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
6 changes: 3 additions & 3 deletions finetune/adapter_v2.py
@@ -54,7 +54,7 @@ def setup(
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
     out_dir: Path = Path("out/adapter_v2/alpaca"),
     precision: Optional[str] = None,
-):
+) -> None:
     precision = precision or get_default_supported_precision(training=True)

     fabric_devices = devices
@@ -75,7 +75,7 @@ def setup(
     fabric.launch(main, data_dir, checkpoint_dir, out_dir)


-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)

     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -277,7 +277,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


-def save_adapter_v2_checkpoint(fabric, model, file_path: Path):
+def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
     fabric.print(f"Saving adapter v2 weights to {str(file_path)!r}")
     fabric.save(file_path, {"model": model}, filter={"model": adapter_filter})
6 changes: 3 additions & 3 deletions finetune/full.py
@@ -54,7 +54,7 @@ def setup(
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
     out_dir: Path = Path("out/full/alpaca"),
     precision: Optional[str] = None,
-):
+) -> None:
     precision = precision or get_default_supported_precision(training=True)

     fabric_devices = devices
@@ -75,7 +75,7 @@ def setup(
     fabric.launch(main, data_dir, checkpoint_dir, out_dir)


-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path):
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)

     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -270,7 +270,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


-def save_checkpoint(fabric, model, file_path: Path):
+def save_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
     fabric.print(f"Saving weights to {str(file_path)!r}")
     fabric.save(file_path, {"model": model})
8 changes: 4 additions & 4 deletions finetune/lora.py
@@ -63,7 +63,7 @@ def setup(
     out_dir: Path = Path("out/lora/alpaca"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
-):
+) -> None:
     precision = precision or get_default_supported_precision(training=True)

     plugins = None
@@ -96,7 +96,7 @@ def setup(
     fabric.launch(main, data_dir, checkpoint_dir, out_dir, quantize)


-def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, quantize: Optional[str] = None):
+def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path, quantize: Optional[str] = None) -> None:
     check_valid_checkpoint_dir(checkpoint_dir)

     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")
@@ -142,7 +142,7 @@ def main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path,
     else:
         optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay)
     optimizer = fabric.setup_optimizers(optimizer)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters//batch_size)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters // batch_size)

     # strict=False because missing keys due to LoRA weights not contained in state dict
     load_checkpoint(fabric, model, checkpoint_path, strict=False)
@@ -320,7 +320,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


-def save_lora_checkpoint(fabric, model, file_path: Path):
+def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
     fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
     fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
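The `save_lora_checkpoint` change keeps the existing behaviour of writing only the LoRA tensors via `fabric.save(..., filter={"model": lora_filter})`. A rough plain-PyTorch equivalent of that filtering idea, assuming (for illustration only) that LoRA parameters can be recognised by a `lora_` substring in their names:

```python
import torch
import torch.nn as nn


def lora_only_state_dict(model: nn.Module) -> dict:
    # Keep only tensors whose names mark them as LoRA weights; the frozen
    # pretrained weights are left out of the checkpoint entirely.
    return {name: tensor for name, tensor in model.state_dict().items() if "lora_" in name}


model = nn.Linear(4, 4)  # stand-in module; a real model would expose lora_A / lora_B parameters
torch.save(lora_only_state_dict(model), "lora_weights.pth")
```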
2 changes: 1 addition & 1 deletion lit_gpt/adapter.py
@@ -115,7 +115,7 @@ def __init__(self, config: Config, block_idx: int) -> None:

     def scaled_dot_product_attention(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ):
+    ) -> torch.Tensor:
         y = super().scaled_dot_product_attention(q, k, v, mask)
         if self.block_idx < self.config.adapter_start_layer:
             return y
8 changes: 4 additions & 4 deletions lit_gpt/lora.py
@@ -125,22 +125,22 @@ def __init__(
         self.scaling = self.lora_alpha / self.r
         self.reset_parameters()

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         """Reset all the weights, even including pretrained ones."""
         if hasattr(self, "lora_A"):
             # initialize A the same way as the default for nn.Linear and B to zero
             # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
             nn.init.zeros_(self.lora_B)

-    def merge(self):
+    def merge(self) -> None:
         """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
         if self.r > 0 and not self.merged:
             # Merge the weights and mark it
             self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
             self.merged = True

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;
         # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights
         pretrained = self.linear(x)
@@ -330,7 +330,7 @@ def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
             [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1  # (B, C_output', T)
         )  # (B, C_output, T)

-    def merge(self):
+    def merge(self) -> None:
         """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""

         # Let's assume that:
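The annotated `reset_parameters`, `merge`, and `forward` methods above describe the core LoRA mechanics: a frozen `nn.Linear` plus a low-rank update `B @ A` scaled by `alpha / r` that can be folded into the dense weight. A self-contained sketch following those comments; it is illustrative only and omits the dropout, enable/disable, and fused-QKV handling present in the repository's implementation:

```python
import math

import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)  # pretrained, kept frozen in practice
        self.r = r
        self.scaling = alpha / r
        self.merged = False
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))  # same init as reset_parameters above

    def merge(self) -> None:
        # Fold the low-rank update into the dense weight once (W = W + delta_W),
        # so inference pays no extra cost.
        if self.r > 0 and not self.merged:
            self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pretrained = self.linear(x)
        if self.r == 0 or self.merged:
            return pretrained
        lora = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return pretrained + lora
```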
2 changes: 1 addition & 1 deletion lit_gpt/model.py
@@ -233,7 +233,7 @@ def forward(

     def scaled_dot_product_attention(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ):
+    ) -> torch.Tensor:
         scale = 1.0 / math.sqrt(self.config.head_size)
         if (
             FlashAttention2Available
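The `-> torch.Tensor` annotation lands on the attention kernel wrapper, of which the diff shows only the softmax scale and the FlashAttention availability check. A naive reference version of the same computation, assuming a boolean mask where `True` marks positions to keep (the repository's mask convention and its FlashAttention dispatch are not reproduced here):

```python
import math
from typing import Optional

import torch


def naive_sdpa(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    scale = 1.0 / math.sqrt(q.size(-1))  # q.size(-1) plays the role of config.head_size
    scores = (q @ k.transpose(-2, -1)) * scale  # (..., T_q, T_k)
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v  # (..., T_q, head_size)
```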
2 changes: 1 addition & 1 deletion lit_gpt/rmsnorm.py
@@ -20,5 +20,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_normed = x * torch.rsqrt(norm_x + self.eps)
         return self.weight * x_normed

-    def reset_parameters(self):
+    def reset_parameters(self) -> None:
         torch.nn.init.ones_(self.weight)
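The two context lines above are the whole of RMSNorm's forward pass: scale each activation by the reciprocal root-mean-square of the features (no mean subtraction, unlike LayerNorm) and apply a learned gain. A self-contained sketch consistent with that forward; the `size`, `dim`, and `eps` defaults are assumptions for illustration:

```python
import torch


class SimpleRMSNorm(torch.nn.Module):
    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)  # mean of squares per feature vector
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.weight * x_normed

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)
```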
6 changes: 3 additions & 3 deletions lit_gpt/speed_monitor.py
@@ -231,7 +231,7 @@ def on_train_batch_end(
         world_size: int,
         flops_per_batch: Optional[int] = None,  # (per device)
         lengths: Optional[int] = None,  # total length of the samples seen (per device)
-    ):
+    ) -> None:
         self.step += 1
         step = self.step
         metrics = {}
@@ -291,7 +291,7 @@ def on_train_batch_end(

         self.log_dict(metrics, step)

-    def eval_end(self, eval_elapsed: float):
+    def eval_end(self, eval_elapsed: float) -> None:
         self.total_eval_wct += eval_elapsed  # seconds


@@ -322,7 +322,7 @@ def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
         super().__init__(flops_available, fabric.log_dict, *args, **kwargs)

     @fabric_rank_zero_only
-    def on_train_batch_end(self, *args: Any, **kwargs: Any):
+    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
         super().on_train_batch_end(*args, **kwargs)
2 changes: 1 addition & 1 deletion lit_gpt/tokenizer.py
@@ -62,7 +62,7 @@ def token_to_id(self, token: str) -> int:
             raise ValueError(f"token {token!r} not found in the collection.")
         return id_

-    def check_if_bos_token_used(self, checkpoint_dir) -> bool:
+    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
             return False
         with open(tokenizer_config_path) as fp:
3 changes: 2 additions & 1 deletion lit_gpt/utils.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import ContextManager, Dict, List, Mapping, Optional, TypeVar, Union

+import lightning as L
 import torch
 import torch.nn as nn
 import torch.utils._device
@@ -301,7 +302,7 @@ def get_default_supported_precision(training: bool) -> str:
     return "bf16-mixed" if training else "bf16-true"


-def load_checkpoint(fabric, model, checkpoint_path: Path, strict: bool = True) -> None:
+def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
     if isinstance(fabric.strategy, FSDPStrategy):
         fabric.load_raw(checkpoint_path, model, strict=strict)
     else:
18 changes: 12 additions & 6 deletions pretrain/openwebtext.py
@@ -2,7 +2,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Tuple, Union

 import lightning as L
 import numpy as np
@@ -17,8 +17,8 @@

 from lit_gpt import Config
 from lit_gpt.model import GPT, Block
+from lit_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops
 from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor
-from lit_gpt.speed_monitor import estimate_flops, measure_flops
 from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters

 model_name = "pythia-70m"
@@ -69,7 +69,7 @@ def setup(devices: int = 1, precision: Optional[str] = None, resume: Union[bool,
     fabric.launch(main, resume=resume)


-def main(fabric, resume) -> None:
+def main(fabric: L.Fabric, resume: bool) -> None:
     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")

     if fabric.global_rank == 0:
@@ -113,7 +113,13 @@ def main(fabric, resume) -> None:
     fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


-def train(fabric, state, train_dataloader, val_dataloader, speed_monitor):
+def train(
+    fabric: L.Fabric,
+    state: dict,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    speed_monitor: SpeedMonitorBase,
+) -> None:
     model = state["model"]
     optimizer = state["optimizer"]

@@ -204,7 +210,7 @@ def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoade
     return out


-def load_datasets(data_dir: Path, max_seq_length: int):
+def load_datasets(data_dir: Path, max_seq_length: int) -> Tuple["Dataset", "Dataset"]:
     train_data = Dataset(data_dir / "train.bin", max_seq_length)
     val_data = Dataset(data_dir / "val.bin", max_seq_length)
     return train_data, val_data
@@ -226,7 +232,7 @@ def __iter__(self):


 # learning rate decay scheduler (cosine with warmup)
-def get_lr(it):
+def get_lr(it: int) -> float:
     # 1) linear warmup for warmup_iters steps
     if it < warmup_iters:
         return learning_rate * it / warmup_iters
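The `get_lr` hunks here and in the other pretraining scripts are cut off after the warmup branch. For context, a sketch of the standard cosine-with-warmup schedule that the comment describes; the constants and the lines below the visible warmup branch are illustrative assumptions, not taken from this diff:

```python
import math

# illustrative hyperparameters, not the repository's values
learning_rate = 6e-4
warmup_iters = 2000
lr_decay_iters = 600_000
min_lr = 6e-5


def get_lr(it: int) -> float:
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) after lr_decay_iters, hold at the minimum learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, cosine-decay from learning_rate down to min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)
```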
2 changes: 1 addition & 1 deletion pretrain/openwebtext_trainer.py
@@ -175,7 +175,7 @@ def __iter__(self):


 # learning rate decay scheduler (cosine with warmup)
-def get_lr(it):
+def get_lr(it: int) -> float:
     # 1) linear warmup for warmup_iters steps
     if it < warmup_iters:
         return learning_rate * it / warmup_iters
18 changes: 12 additions & 6 deletions pretrain/redpajama.py
@@ -17,8 +17,8 @@

 from lit_gpt.model import GPT, Block, Config
 from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
+from lit_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops
 from lit_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor
-from lit_gpt.speed_monitor import estimate_flops, measure_flops
 from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters

 model_name = "Llama-2-7b-hf"
@@ -86,7 +86,7 @@ def setup(
     fabric.launch(main, train_data_dir, val_data_dir, resume)


-def main(fabric, train_data_dir, val_data_dir, resume):
+def main(fabric: L.Fabric, train_data_dir: Path, val_data_dir: Path, resume: bool) -> None:
     speed_monitor = SpeedMonitor(fabric, window_size=50, time_unit="seconds")

     if fabric.global_rank == 0:
@@ -139,7 +139,13 @@ def main(fabric, train_data_dir, val_data_dir, resume):
     fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


-def train(fabric, state, train_dataloader, val_dataloader, speed_monitor):
+def train(
+    fabric: L.Fabric,
+    state: dict,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    speed_monitor: SpeedMonitorBase,
+) -> None:
     model = state["model"]
     optimizer = state["optimizer"]

@@ -234,7 +240,7 @@ def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoade


 def create_dataloader(
-    batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345
+    batch_size: int, block_size: int, data_dir: Path, fabric: L.Fabric, shuffle: bool = True, seed: int = 12345
 ) -> DataLoader:
     datasets = []
     for prefix, _ in data_config:
@@ -267,7 +273,7 @@ def create_dataloader(
 def create_dataloaders(
     batch_size: int,
     block_size: int,
-    fabric,
+    fabric: L.Fabric,
     train_data_dir: Path = Path("data/redpajama_sample"),
     val_data_dir: Optional[Path] = None,
     seed: int = 12345,
@@ -298,7 +304,7 @@ def create_dataloaders(


 # learning rate decay scheduler (cosine with warmup)
-def get_lr(it):
+def get_lr(it: int) -> float:
     # 1) linear warmup for warmup_iters steps
     if it < warmup_iters:
         return learning_rate * it / warmup_iters
2 changes: 1 addition & 1 deletion scripts/convert_lit_checkpoint.py
@@ -105,7 +105,7 @@ def copy_weights_llama(
     state_dict: Dict[str, torch.Tensor],
     lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
     saver: Optional[incremental_save] = None,
-):
+) -> None:
     weight_map = {
         "transformer.wte.weight": "model.embed_tokens.weight",
         "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",