Add general purpose LitData streaming data module (Lightning-AI#1118)
awaelchli committed Mar 15, 2024
1 parent 29aaf15 commit b6c97e4
Showing 19 changed files with 157 additions and 51 deletions.
6 changes: 4 additions & 2 deletions litgpt/data/__init__.py
@@ -1,6 +1,6 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from litgpt.data.base import LitDataModule, SFTDataset, get_sft_collate_fn
from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.data.alpaca import Alpaca
from litgpt.data.alpaca_2k import Alpaca2k
from litgpt.data.alpaca_gpt4 import AlpacaGPT4
@@ -9,6 +9,7 @@
from litgpt.data.dolly import Dolly
from litgpt.data.flan import FLAN
from litgpt.data.lima import LIMA
from litgpt.data.lit_data import LitData
from litgpt.data.longform import LongForm
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
@@ -24,7 +25,8 @@
"FLAN",
"JSON",
"LIMA",
"LitDataModule",
"LitData",
"DataModule",
"LongForm",
"OpenWebText",
"SFTDataset",
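The net effect of this file's diff is that the base class is now exported under its new name, DataModule, and the new LitData module joins the public package namespace. A quick illustrative check (not part of the commit):

# Illustrative check of the renamed public API after this commit (not part of the diff).
from litgpt.data import DataModule, LitData

assert issubclass(LitData, DataModule)  # the new streaming module is a DataModule like the rest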
4 changes: 2 additions & 2 deletions litgpt/data/alpaca.py
@@ -9,15 +9,15 @@
import torch
from torch.utils.data import random_split, DataLoader
from lightning_utilities.core.imports import RequirementCache
from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer

_URL = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"


@dataclass
class Alpaca(LitDataModule):
class Alpaca(DataModule):
"""Alpaca data module for supervised finetuning."""

mask_prompt: bool = False
2 changes: 1 addition & 1 deletion litgpt/data/base.py
@@ -12,7 +12,7 @@
from litgpt.prompts import PromptStyle


class LitDataModule(LightningDataModule):
class DataModule(LightningDataModule):
"""Base class for all data modules in LitGPT."""

@abstractmethod
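For context, every concrete data module in LitGPT subclasses the DataModule base shown above and implements the `connect` hook plus the standard dataloader methods. A minimal, illustrative subclass sketch; the class name and toy dataset are hypothetical and not part of this commit:

# Illustrative sketch only -- shows the DataModule contract, not code from this commit.
from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

from litgpt.data import DataModule
from litgpt.tokenizer import Tokenizer


class ToyTokens(DataModule):
    """A toy data module that serves random token blocks."""

    def connect(
        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
    ) -> None:
        # Called by the training scripts to pass runtime settings before setup().
        self.batch_size = batch_size
        self.seq_length = (max_seq_length or 128) + 1

    def setup(self, stage: str = "") -> None:
        # Random token ids stand in for real preprocessed data.
        self.dataset = TensorDataset(torch.randint(0, 100, (64, self.seq_length)))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset, batch_size=self.batch_size)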
4 changes: 2 additions & 2 deletions litgpt/data/deita.py
@@ -9,12 +9,12 @@
from torch.utils.data import DataLoader

from litgpt import PromptStyle
from litgpt.data import LitDataModule, SFTDataset, get_sft_collate_fn
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.tokenizer import Tokenizer


@dataclass
class Deita(LitDataModule):
class Deita(DataModule):
"""Deita data module for supervised finetuning."""

mask_prompt: bool = False
4 changes: 2 additions & 2 deletions litgpt/data/flan.py
@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader

from litgpt import PromptStyle
from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
from litgpt.data.alpaca import download_if_missing
from litgpt.tokenizer import Tokenizer

@@ -19,7 +19,7 @@
# TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache
# on disk or use Lightning Data
@dataclass
class FLAN(LitDataModule):
class FLAN(DataModule):
"""FLAN data module for supervised finetuning."""

mask_prompt: bool = False
4 changes: 2 additions & 2 deletions litgpt/data/json_data.py
@@ -9,12 +9,12 @@
from torch.utils.data import random_split, DataLoader

from litgpt import PromptStyle
from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
from litgpt.tokenizer import Tokenizer


@dataclass
class JSON(LitDataModule):
class JSON(DataModule):
"""Loads JSON or JSONL data for supervised finetuning."""

json_path: Path
4 changes: 2 additions & 2 deletions litgpt/data/lima.py
@@ -9,12 +9,12 @@
from torch.utils.data import random_split, DataLoader

from litgpt import PromptStyle
from litgpt.data import LitDataModule, SFTDataset, get_sft_collate_fn
from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
from litgpt.tokenizer import Tokenizer


@dataclass
class LIMA(LitDataModule):
class LIMA(DataModule):
"""LIMA data module for supervised finetuning."""

mask_prompt: bool = False
68 changes: 68 additions & 0 deletions litgpt/data/lit_data.py
@@ -0,0 +1,68 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union, Optional, Tuple

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class LitData(DataModule):
"""Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks)."""

data_path: Union[str, Path] = Path("data/")
"""The path to the data directory containing the preprocessed chunks for the streaming dataset
The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders
for training- and validation splits."""
split_names: Optional[Tuple[str, str]] = None
"""Optional tuple for names of subfolders for training and validation under ``data_path``. If not provided,
all data under data_path will be used for training, and the validation dataloader will be identical to the
train dataloader."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""

batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)

def __post_init__(self) -> None:
if self.split_names is not None and len(self.split_names) != 2:
raise ValueError(
"If provided `split_names` must be a tuple of two strings, for example: ('train', 'val')."
)

def connect(
self,
tokenizer: Optional[Tokenizer] = None,
batch_size: int = 1,
max_seq_length: Optional[int] = None
) -> None:
self.batch_size = batch_size
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well

def train_dataloader(self) -> DataLoader:
input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path)
return self._dataloader(input_dir=input_dir, train=True)

def val_dataloader(self) -> DataLoader:
input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path)
return self._dataloader(input_dir=input_dir, train=False)

def _dataloader(self, input_dir: str, train: bool):
from litdata.streaming import StreamingDataset, TokensLoader

dataset = StreamingDataset(
input_dir=input_dir,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=train,
drop_last=True,
)
dataloader = DataLoader(
dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return dataloader
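The dataloaders above yield flat blocks of `max_seq_length + 1` token IDs, which a pretraining loop then shifts by one position to form model inputs and next-token targets. A usage sketch, assuming litdata is installed and `data/custom/` contains LitData-preprocessed chunks in `train/` and `val/` subfolders (the paths are placeholders):

# Hedged usage sketch for the new LitData module; the paths below are hypothetical.
from litgpt.data import LitData

data = LitData(data_path="data/custom/", split_names=("train", "val"), num_workers=4)
data.connect(batch_size=8, max_seq_length=2048)  # dataloaders will yield blocks of 2049 tokens

train_loader = data.train_dataloader()
batch = next(iter(train_loader))          # shape: (8, 2049)
input_ids = batch[..., :-1].contiguous()  # tokens the model sees
targets = batch[..., 1:].contiguous()     # next-token labels, offset by one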
4 changes: 2 additions & 2 deletions litgpt/data/longform.py
@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader

from litgpt import PromptStyle
from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
from litgpt.data.alpaca import download_if_missing
from litgpt.tokenizer import Tokenizer

@@ -18,7 +18,7 @@


@dataclass
class LongForm(LitDataModule):
class LongForm(DataModule):
"""LongForm data module for supervised finetuning."""

mask_prompt: bool = False
4 changes: 2 additions & 2 deletions litgpt/data/openwebtext.py
@@ -8,11 +8,11 @@
from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import LitDataModule
from litgpt.data import DataModule


@dataclass
class OpenWebText(LitDataModule):
class OpenWebText(DataModule):
"""The OpenWebText data module for pretraining."""

data_path: Union[str, Path] = Path("data/openwebtext")
4 changes: 2 additions & 2 deletions litgpt/data/tinyllama.py
@@ -6,11 +6,11 @@
from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import LitDataModule
from litgpt.data import DataModule


@dataclass
class TinyLlama(LitDataModule):
class TinyLlama(DataModule):
"""The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data.
Provides training and validation streaming dataloaders that return batches of tokens.
4 changes: 2 additions & 2 deletions litgpt/data/tinystories.py
@@ -16,12 +16,12 @@
from tqdm import tqdm

from litgpt.data.alpaca import download_if_missing
from litgpt.data.base import LitDataModule
from litgpt.data.base import DataModule
from litgpt.tokenizer import Tokenizer


@dataclass
class TinyStories(LitDataModule):
class TinyStories(DataModule):
"""The TinyStories data module: https://huggingface.co/datasets/roneneldan/TinyStories
Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length.
12 changes: 6 additions & 6 deletions litgpt/finetune/adapter.py
@@ -16,7 +16,7 @@

from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca, LitDataModule
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
@@ -39,7 +39,7 @@ def setup(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
seed: int = 1337,
data: Optional[LitDataModule] = None,
data: Optional[DataModule] = None,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/finetune/adapter"),
train: TrainArgs = TrainArgs(
@@ -90,7 +90,7 @@ def setup(
fabric.launch(main, devices, seed, Config.from_name(name=checkpoint_dir.name), data, checkpoint_dir, out_dir, train, eval)


def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(train, eval)

check_valid_checkpoint_dir(checkpoint_dir)
@@ -160,7 +160,7 @@ def fit(
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: LitDataModule,
data: DataModule,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
@@ -239,7 +239,7 @@ def fit(
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: LitDataModule,
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule,
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
@@ -278,7 +278,7 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
def get_dataloaders(fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
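As the diff shows, the finetune scripts now type their `data` argument as DataModule and drive it through `connect`, then `prepare_data`, then dataloader creation inside `get_dataloaders`. A standalone sketch of that same lifecycle without Fabric; the checkpoint path is only an example and assumes the checkpoint has already been downloaded:

# Sketch of the data-module lifecycle used by the finetune scripts, outside Fabric.
from pathlib import Path

from litgpt.data import Alpaca
from litgpt.tokenizer import Tokenizer

checkpoint_dir = Path("checkpoints/stabilityai/stablelm-base-alpha-3b")  # example path, checkpoint assumed present
tokenizer = Tokenizer(checkpoint_dir)

data = Alpaca()
data.connect(tokenizer=tokenizer, batch_size=4, max_seq_length=512)
data.prepare_data()   # downloads/caches the raw dataset (guarded by rank_zero_first in the scripts)
data.setup()
train_dataloader = data.train_dataloader()
val_dataloader = data.val_dataloader()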
12 changes: 6 additions & 6 deletions litgpt/finetune/adapter_v2.py
@@ -16,7 +16,7 @@

from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca, LitDataModule
from litgpt.data import Alpaca, DataModule
from litgpt.generate.base import generate
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
@@ -39,7 +39,7 @@ def setup(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
seed: int = 1337,
data: Optional[LitDataModule] = None,
data: Optional[DataModule] = None,
checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
out_dir: Path = Path("out/adapter_v2"),
train: TrainArgs = TrainArgs(
@@ -90,7 +90,7 @@ def setup(
fabric.launch(main, devices, seed, Config.from_name(name=checkpoint_dir.name), data, checkpoint_dir, out_dir, train, eval)


def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
validate_args(train, eval)

check_valid_checkpoint_dir(checkpoint_dir)
@@ -160,7 +160,7 @@ def fit(
out_dir: Path,
train: TrainArgs,
eval: EvalArgs,
data: LitDataModule,
data: DataModule,
) -> None:
tokenizer = Tokenizer(checkpoint_dir)
longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
@@ -239,7 +239,7 @@ def fit(
# the adapter "kv cache" cannot be initialized under `inference_mode`
@torch.no_grad()
def validate(
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: LitDataModule,
fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule,
) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
@@ -278,7 +278,7 @@ def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):
return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])


def get_dataloaders(fabric: L.Fabric, data: LitDataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
def get_dataloaders(fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs) -> Tuple[DataLoader, DataLoader]:
data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)
with fabric.rank_zero_first():
data.prepare_data()
(Diffs for the remaining changed files are not shown.)