Data Refactor - Data Modules as Dataclasses (Lightning-AI#975)
awaelchli committed Mar 15, 2024
1 parent c5f1f83 commit 57036f0
Showing 14 changed files with 184 additions and 298 deletions.
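
The refactor below replaces each data module's hand-written `__init__` and `__repr__` with a `@dataclass` declaration: user-facing options become fields with defaults, `field(repr=False)` keeps long URLs out of the generated repr, and `field(init=False, repr=False)` marks runtime state that is filled in later (e.g. by `connect`). A minimal, hypothetical sketch of that pattern (not code from the commit):

```python
# Illustrative sketch (not part of the commit): the pattern the diff below adopts.
# Names are hypothetical; in the commit the modules subclass LitDataModule.
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ExampleModule:
    # user-facing options become plain fields with defaults
    mask_prompt: bool = False
    seed: int = 42
    download_dir: Path = Path("./data/example")
    # noisy constants stay out of the generated __repr__
    file_url: str = field(repr=False, default="https://example.com/data.json")
    # runtime state is excluded from both __init__ and __repr__
    batch_size: int = field(default=1, init=False, repr=False)
    max_seq_length: int = field(default=-1, init=False, repr=False)


print(ExampleModule(seed=1))
# on POSIX: ExampleModule(mask_prompt=False, seed=1, download_dir=PosixPath('data/example'))
```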
71 changes: 28 additions & 43 deletions lit_gpt/data/alpaca.py
@@ -2,6 +2,7 @@
"""Implementation derived from https://github.com/tloen/alpaca-lora"""

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Dict

@@ -14,38 +15,32 @@
_URL = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"


@dataclass
class Alpaca(LitDataModule):
"""Alpaca data module for supervised finetuning.
Provides train- and val-dataloaders. The batches return keys "input_ids" and "labels".
"""

def __init__(
self,
mask_prompt: bool = False,
test_split_fraction: float = 0.03865, # to get exactly 2000 test samples,
ignore_index: int = -1,
seed: int = 42,
num_workers: int = 4,
data_file_url: str = _URL,
data_file_name: str = "alpaca_data_cleaned_archive.json",
download_dir: Path = Path("./data/alpaca"),
) -> None:
super().__init__()
self.mask_prompt = mask_prompt
self.test_split_fraction = test_split_fraction
self.ignore_index = ignore_index
self.seed = seed
self.num_workers = num_workers
self.data_file_url = data_file_url
self.data_file_name = data_file_name
self.download_dir = download_dir

self.tokenizer: Optional[Tokenizer] = None
self.batch_size: int = 1
self.max_seq_length: int = -1
self.train_dataset: Optional[SFTDataset] = None
self.test_dataset: Optional[SFTDataset] = None
"""Alpaca data module for supervised finetuning."""

mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
test_split_fraction: float = 0.03865 # to get exactly 2000 test samples,
"""The fraction of the dataset to use for the test/validation dataset. The rest is used for training."""
ignore_index: int = -1
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/alpaca")
"""The directory in which the downloaded dataset gets saved."""
file_url: str = field(repr=False, default=_URL)
"""The URL from where to download the dataset."""
file_name: str = field(repr=False, default="alpaca_data_cleaned_archive.json")
"""The name of the dataset file to download."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)

def connect(
self,
@@ -59,10 +54,10 @@ def connect(

def prepare_data(self) -> None:
self.download_dir.mkdir(parents=True, exist_ok=True)
download_if_missing(self.download_dir / self.data_file_name, self.data_file_url)
download_if_missing(self.download_dir / self.file_name, self.file_url)

def setup(self, stage: str = "") -> None:
with open(self.download_dir / self.data_file_name, "r", encoding="utf-8") as file:
with open(self.download_dir / self.file_name, "r", encoding="utf-8") as file:
data = json.load(file)

# Partition the dataset into train and test
@@ -109,16 +104,6 @@ def val_dataloader(self) -> DataLoader:
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"mask_prompt={self.mask_prompt}, "
f"test_split_fraction={self.test_split_fraction}, "
f"seed={self.seed}, "
f"num_workers={self.num_workers}, "
"...)"
)


def download_if_missing(file_path: Path, file_url: str) -> None:
"""Downloads the raw json data file and saves it in the given destination."""
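
For orientation, a hedged sketch of how the refactored Alpaca module would be used after this change; the import path follows the file location above, and the snippet is illustrative rather than part of the diff:

```python
# Sketch (not from the commit): constructing the refactored Alpaca module.
# Assumes lit_gpt is installed and importable from the working directory.
from pathlib import Path

from lit_gpt.data.alpaca import Alpaca

data = Alpaca(test_split_fraction=0.05, download_dir=Path("./data/alpaca"))
print(data)            # auto-generated repr; file_url/file_name hidden via field(repr=False)
data.prepare_data()    # downloads alpaca_data_cleaned_archive.json if missing
# data.connect(...)    # attach tokenizer/batch size (signature truncated in this diff), then setup()
```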
56 changes: 20 additions & 36 deletions lit_gpt/data/dolly.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
from dataclasses import dataclass, field
from pathlib import Path

import torch
@@ -11,36 +12,29 @@
_URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"


@dataclass
class Dolly(Alpaca):
"""Dolly data module for supervised finetuning.
"""Dolly data module for supervised finetuning."""

Provides train- and val-dataloaders. The batches return keys "input_ids" and "labels".
"""

def __init__(
self,
mask_prompt: bool = False,
test_split_fraction: float = 0.1,
ignore_index: int = -1,
seed: int = 42,
num_workers: int = 4,
data_file_url: str = _URL,
data_file_name: str = "dolly_data_cleaned.json",
download_dir: Path = Path("./data/dolly"),
) -> None:
super().__init__(
mask_prompt=mask_prompt,
test_split_fraction=test_split_fraction,
ignore_index=ignore_index,
seed=seed,
num_workers=num_workers,
data_file_url=data_file_url,
data_file_name=data_file_name,
download_dir=download_dir,
)
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
test_split_fraction: float = 0.1
"""The fraction of the dataset to use for the test/validation dataset. The rest is used for training."""
ignore_index: int = -1
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/dolly")
"""The directory in which the downloaded dataset gets saved."""
file_url: str = field(repr=False, default=_URL)
"""The URL from where to download the dataset."""
file_name: str = field(repr=False, default="dolly_data_cleaned.json")
"""The name of the dataset file to download."""

def setup(self, stage: str = "") -> None:
with open(self.download_dir / self.data_file_name, "r", encoding="utf-8") as file:
with open(self.download_dir / self.file_name, "r", encoding="utf-8") as file:
data = file.readlines()
data = [json.loads(line) for line in data]
for item in data:
@@ -71,13 +65,3 @@ def setup(self, stage: str = "") -> None:
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"mask_prompt={self.mask_prompt}, "
f"test_split_fraction={self.test_split_fraction}, "
f"seed={self.seed}, "
f"num_workers={self.num_workers}, "
"...)"
)
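
Since `Dolly` now subclasses the `Alpaca` dataclass, it only re-declares the fields whose defaults differ (split fraction, download directory, file URL and name) and inherits everything else, including `prepare_data` and the dataloaders. A standalone sketch of that inheritance pattern, with hypothetical names:

```python
# Sketch only: dataclass inheritance with overridden defaults, as used by Dolly(Alpaca).
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class Base:
    test_split_fraction: float = 0.03865
    download_dir: Path = Path("./data/base")
    file_url: str = field(repr=False, default="https://example.com/base.json")


@dataclass
class Child(Base):
    # re-declared fields simply replace the parent's defaults
    test_split_fraction: float = 0.1
    download_dir: Path = Path("./data/child")
    file_url: str = field(repr=False, default="https://example.com/child.jsonl")


print(Child())  # Child(test_split_fraction=0.1, download_dir=PosixPath('data/child'))
```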
73 changes: 29 additions & 44 deletions lit_gpt/data/flan.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Dict, List, Set

@@ -15,47 +16,41 @@

# TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache
# on disk or use Lightning Data
@dataclass
class FLAN(LitDataModule):
"""FLAN data module for supervised finetuning.
Provides train- and val-dataloaders. The batches return keys "input_ids" and "labels".
"""

def __init__(
self,
mask_prompt: bool = False,
test_split_fraction: float = 0.03865, # to get exactly 2000 test samples,
ignore_index: int = -1,
seed: int = 42,
num_workers: int = 4,
data_url: str = _URL,
download_dir: Path = Path("./data/flan"),
subsets: Optional[str] = None,
) -> None:
super().__init__()
self.mask_prompt = mask_prompt
self.test_split_fraction = test_split_fraction
self.ignore_index = ignore_index
self.seed = seed
self.num_workers = num_workers
self.data_url = data_url
self.download_dir = download_dir

"""FLAN data module for supervised finetuning."""

mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
ignore_index: int = -1
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""
download_dir: Path = Path("./data/flan")
"""The directory in which the downloaded dataset gets saved."""
url: str = _URL
"""The URL from where to download the dataset."""
subsets: Optional[str] = None
"""A comma separated list of subsets to use. If None, all subsets are used."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)

def __post_init__(self):
supported_subsets = _supported_subsets()
if subsets is not None:
self.subsets = subsets.split(",")
if self.subsets is not None:
self.subsets = self.subsets.split(",")
for subset in self.subsets:
if subset not in supported_subsets:
raise ValueError(f"{subset} not in {supported_subsets}")
else:
self.subsets = list(supported_subsets)

self.tokenizer: Optional[Tokenizer] = None
self.batch_size: int = 1
self.max_seq_length: int = -1
self.train_dataset: Optional[SFTDataset] = None
self.test_dataset: Optional[SFTDataset] = None

def connect(
self,
tokenizer: Optional[Tokenizer] = None,
@@ -71,7 +66,7 @@ def prepare_data(self) -> None:
for subset in self.subsets:
for split in ("train", "test"):
data_file_path = self.download_dir / f"{subset}_{split}.jsonl"
data_file_url = f"{self.data_url}/{split}/{subset}_{split}.jsonl"
data_file_url = f"{self.url}/{split}/{subset}_{split}.jsonl"
download_if_missing(data_file_path, data_file_url)

def train_dataloader(self):
@@ -106,16 +101,6 @@ def _dataloader(self, split: str) -> DataLoader:
collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index)
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"mask_prompt={self.mask_prompt}, "
f"test_split_fraction={self.test_split_fraction}, "
f"seed={self.seed}, "
f"num_workers={self.num_workers}, "
"...)"
)


def load_jsonl(filename: Path) -> List[Dict[str, str]]:
data = []
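
The subset validation that previously ran in `__init__` now lives in `__post_init__`, which the generated `__init__` calls automatically. A simplified sketch of that behaviour (subset names here are placeholders, not the real FLAN subset list):

```python
# Sketch: __post_init__ turning a comma-separated string into a validated list,
# mirroring the refactored FLAN module's handling of its `subsets` field.
from dataclasses import dataclass
from typing import Optional

SUPPORTED = {"subset_a", "subset_b"}  # placeholder names, not the real FLAN subsets


@dataclass
class FlanLike:
    subsets: Optional[str] = None

    def __post_init__(self):
        if self.subsets is not None:
            self.subsets = self.subsets.split(",")
            for subset in self.subsets:
                if subset not in SUPPORTED:
                    raise ValueError(f"{subset} not in {SUPPORTED}")
        else:
            self.subsets = list(SUPPORTED)


print(FlanLike("subset_a").subsets)  # ['subset_a']
# FlanLike("bogus")                  # would raise ValueError
```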
76 changes: 26 additions & 50 deletions lit_gpt/data/json.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

@@ -11,46 +12,32 @@
from lit_gpt.tokenizer import Tokenizer


@dataclass
class JSON(LitDataModule):
"""Loads JSON data for supervised finetuning.
Provides train- and val-dataloaders. The batches return keys "input_ids" and "labels".
Args:
json_path: A path to a JSON file containing the data. The file should contain a list of samples (dicts).
Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input'
(see Alpaca).
mask_prompt: Whether to mask the prompt section from the label (with ``ignore_index``).
test_split_fraction: A number in the range [0, 1] that determines the fraction of the dataset
to use for testing.
ignore_index: The index to use for elements to be ignored in the label.
seed: The random seed for creating the train/val splits and shuffling the dataset.
num_workers: How many DataLoader processes to use for loading.
"""

def __init__(
self,
json_path: Path,
mask_prompt: bool = False,
test_split_fraction: float = 0.1,
ignore_index: int = -1,
seed: int = 42,
num_workers: int = 4,
) -> None:
super().__init__()
self.json_path = json_path
self.mask_prompt = mask_prompt
self.test_split_fraction = test_split_fraction
self.ignore_index = ignore_index
self.seed = seed
self.num_workers = num_workers

self.tokenizer: Optional[Tokenizer] = None
self.batch_size: int = 1
self.max_seq_length: int = -1
self.train_dataset: Optional[SFTDataset] = None
self.test_dataset: Optional[SFTDataset] = None

"""Loads JSON data for supervised finetuning."""

json_path: Path
"""A path to a JSON file containing the data. The file should contain a list of samples (dicts).
Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input'
(see Alpaca)."""
mask_prompt: bool = False
"""Whether to mask the prompt section from the label (with ``ignore_index``)."""
test_split_fraction: float = 0.1
"""The fraction of the dataset to use for the test/validation dataset. The rest is used for training."""
ignore_index: int = -1
"""The index to use for elements to be ignored in the label."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
num_workers: int = 4
"""How many DataLoader processes to use for loading."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)

def __post_init__(self):
if not self.json_path.is_file():
raise FileNotFoundError(f"The file {self.json_path} does not exist.")

@@ -93,17 +80,6 @@ def setup(self, stage: str = "") -> None:
ignore_index=self.ignore_index,
)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"json_path={self.json_path},"
f"mask_prompt={self.mask_prompt}, "
f"test_split_fraction={self.test_split_fraction}, "
f"seed={self.seed}, "
f"num_workers={self.num_workers}, "
"...)"
)

def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
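
The `JSON` module keeps `json_path` as a required, default-less field and checks it in `__post_init__`, so a wrong path fails at construction time rather than in `setup`. A sketch of the expected file format and usage (illustrative; the commented import path is an assumption):

```python
# Sketch: the list-of-dicts file the JSON data module expects, per the docstrings above,
# and a construction that fails fast when the path does not exist.
import json
from pathlib import Path

samples = [
    {"instruction": "Name the capital of France.", "input": "", "output": "Paris"},
]
Path("my_data.json").write_text(json.dumps(samples))

# from lit_gpt.data.json import JSON        # assumed import path
# data = JSON(json_path=Path("my_data.json"), test_split_fraction=0.1)
# JSON(json_path=Path("missing.json"))      # __post_init__ raises FileNotFoundError
```

Remaining changed files in this commit are not shown above.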