forked from Lightning-AI/litgpt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add general purpose LitData streaming data module (Lightning-AI#1118)
- Loading branch information
Showing
19 changed files
with
157 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. | ||
import os | ||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from typing import Union, Optional, Tuple | ||
|
||
from torch.utils.data import DataLoader | ||
|
||
from litgpt import Tokenizer | ||
from litgpt.data import DataModule | ||
|
||
|
||
@dataclass | ||
class LitData(DataModule): | ||
"""Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks).""" | ||
|
||
data_path: Union[str, Path] = Path("data/") | ||
"""The path to the data directory containing the preprocessed chunks for the streaming dataset | ||
The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders | ||
for training- and validation splits.""" | ||
split_names: Optional[Tuple[str, str]] = None | ||
"""Optional tuple for names of subfolders for training and validation under ``data_path``. If not provided, | ||
all data under data_path will be used for training, and the validation dataloader will be identical to the | ||
train dataloader.""" | ||
seed: int = 42 | ||
"""The random seed for shuffling the dataset.""" | ||
num_workers: int = 8 | ||
"""How many DataLoader processes to use for loading.""" | ||
|
||
batch_size: int = field(init=False, repr=False, default=1) | ||
seq_length: int = field(init=False, repr=False, default=2048) | ||
|
||
def __post_init__(self) -> None: | ||
if self.split_names is not None and len(self.split_names) != 2: | ||
raise ValueError( | ||
"If provided `split_names` must be a tuple of two strings, for example: ('train', 'val')." | ||
) | ||
|
||
def connect( | ||
self, | ||
tokenizer: Optional[Tokenizer] = None, | ||
batch_size: int = 1, | ||
max_seq_length: Optional[int] = None | ||
) -> None: | ||
self.batch_size = batch_size | ||
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well | ||
|
||
def train_dataloader(self) -> DataLoader: | ||
input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path) | ||
return self._dataloader(input_dir=input_dir, train=True) | ||
|
||
def val_dataloader(self) -> DataLoader: | ||
input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path) | ||
return self._dataloader(input_dir=input_dir, train=False) | ||
|
||
def _dataloader(self, input_dir: str, train: bool): | ||
from litdata.streaming import StreamingDataset, TokensLoader | ||
|
||
dataset = StreamingDataset( | ||
input_dir=input_dir, | ||
item_loader=TokensLoader(block_size=self.seq_length), | ||
shuffle=train, | ||
drop_last=True, | ||
) | ||
dataloader = DataLoader( | ||
dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True | ||
) | ||
return dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.