Commit 4717427

Refactor GPT data (ServiceNow#35)
1 parent da0eff0 commit 4717427

21 files changed: +492, -451 lines

fast_llm/data/dataset.py → fast_llm/data/blended.py (+9, -65)
@@ -1,12 +1,11 @@
-import abc
 import logging
 import pathlib
 import time
 
 import numpy as np
-import torch.utils.data
 
 from fast_llm.core.distributed import ProcessGroup, safe_barrier
+from fast_llm.data.config import SampledDataset
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import Assert
 

@@ -20,43 +19,6 @@
 logger = logging.getLogger(__name__)
 
 
-class Dataset(abc.ABC):
-    """
-    A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
-    """
-
-    @abc.abstractmethod
-    def __getitem__(self, index: int):
-        pass
-
-    @abc.abstractmethod
-    def __len__(self):
-        pass
-
-    @property
-    @abc.abstractmethod
-    def name(self):
-        """
-        A name for the dataset to facilitate identification and debugging.
-        """
-
-
-class RawDataset(Dataset):  # noqa
-    """
-    A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk.
-    (Excluding off-line processing prior to training.)
-    Functionally identical to a `Dataset`, but renamed for clarity.
-    """
-
-
-class SampledDataset(Dataset):  # noqa
-    """
-    A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
-    (See the `Sampler` class below.)
-    Functionally identical to a `Dataset`, but renamed for clarity.
-    """
-
-
 class BlendedDataset(SampledDataset):
     """
     A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability.
@@ -72,7 +34,7 @@ def __init__(
         *,
         name: str = "blended",
         num_samples: int,
-        cache_dir: pathlib.Path | None = None,
+        cache_directory: pathlib.Path | None = None,
         group: ProcessGroup | None = None,
         verbose: bool = True,
         data_sample_warn_time_ms: float = 1000,
@@ -83,19 +45,20 @@ def __init__(
         self._weights = weights
         self._data_sample_warn_time_ms = data_sample_warn_time_ms
 
-        if cache_dir is None:
+        if cache_directory is None:
             self._dataset_idx_filename, self._sample_idx_filename = None, None
             self._dataset_index, self._sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
         else:
-            self._dataset_idx_filename = cache_dir / (self._name + "_blending_dataset_idx.npy")
-            self._sample_idx_filename = cache_dir / (self._name + "_blending_sample_idx.npy")
+            self._dataset_idx_filename = cache_directory / (self._name + "_blending_dataset_idx.npy")
+            self._sample_idx_filename = cache_directory / (self._name + "_blending_sample_idx.npy")
 
             # Build the indexed mapping if it doesn't exist.
             # TODO: This only works if the dataset location is accessible by all jobs.
             if (group is None or group.rank() == 0) and not (
                 self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
             ):
                 dataset_index, sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
+                cache_directory.mkdir(exist_ok=True, parents=True)
                 np.save(self._dataset_idx_filename, dataset_index)
                 np.save(self._sample_idx_filename, sample_index)
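The cached branch follows a build-once, load-everywhere pattern: only rank 0 materializes the indices, and the `mkdir` added by this commit guarantees the cache directory exists before `np.save` writes into it. A minimal standalone sketch of the pattern (the load side and the barrier are assumptions based on the imported `safe_barrier`, and `group` is a hypothetical stand-in for a `ProcessGroup`):

import pathlib

import numpy as np


def get_or_build_index(cache_directory: pathlib.Path, name: str, num_samples: int, group=None):
    # Hypothetical sketch of BlendedDataset's caching; not the actual Fast-LLM code.
    path = cache_directory / f"{name}_blending_dataset_idx.npy"
    if (group is None or group.rank() == 0) and not path.is_file():
        # Stand-in for the real index construction (_build_blending_indices).
        index = np.zeros(num_samples, dtype=np.int16)
        # The fix added in this commit: create the directory before saving.
        cache_directory.mkdir(exist_ok=True, parents=True)
        np.save(path, index)
    # In the real class, all ranks would synchronize (e.g. via safe_barrier) before loading.
    return np.load(path, mmap_mode="r")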

@@ -140,7 +103,9 @@ def __len__(self):
         return self._num_samples
 
     def _build_blending_indices(self, verbose: bool):
-        assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first."
+        assert _extension_available, (
+            "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
+        )
         Assert.lt(len(self._datasets), 32767)
         dataset_index = np.zeros(self._num_samples, dtype=np.int16)
         dataset_sample_index = np.zeros(self._num_samples, dtype=np.int64)
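The index construction itself lives in the C++ extension and is not shown in this diff; the `int16` dtype of `dataset_index` is what motivates the `Assert.lt(len(self._datasets), 32767)` bound. As a rough illustration of what the two arrays mean, here is a greedy pure-NumPy sketch (an assumption about the algorithm, not the extension's actual code): `dataset_index[i]` names the dataset that provides output sample `i`, and `dataset_sample_index[i]` is the position within that dataset.

import numpy as np


def build_blending_indices(weights, num_samples: int):
    # Illustrative only: greedily pick the dataset that is currently most
    # under-represented relative to its target blending weight.
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()
    dataset_index = np.zeros(num_samples, dtype=np.int16)
    dataset_sample_index = np.zeros(num_samples, dtype=np.int64)
    counts = np.zeros(len(weights), dtype=np.int64)
    for i in range(num_samples):
        # Error between the target count and the realized count so far.
        chosen = int(np.argmax(weights * (i + 1) - counts))
        dataset_index[i] = chosen
        dataset_sample_index[i] = counts[chosen]
        counts[chosen] += 1
    return dataset_index, dataset_sample_index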
@@ -191,24 +156,3 @@ def __getitem__(self, idx):
     @property
     def name(self):
         return self._name
-
-
-class Sampler(torch.utils.data.Sampler):
-    """
-    A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
-    To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
-    """
-
-    def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
-        self._total_samples = total_samples
-        self._begin_index = begin_index
-        self._batch_size = micro_batch_size * data_parallel
-        self._start_idx = data_rank * micro_batch_size
-        self._end_idx = (data_rank + 1) * micro_batch_size
-
-    def __len__(self):
-        return self._total_samples
-
-    def __iter__(self):
-        for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
-            yield list(range(idx + self._start_idx, idx + self._end_idx))
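For reference, the removed `Sampler` carved the global sample sequence into fixed-size blocks of `micro_batch_size * data_parallel` indices and gave each data-parallel rank a contiguous micro-batch within every block. A tiny worked example of that arithmetic, with hypothetical values:

# micro_batch_size=2, data_parallel=2 -> batch_size=4; total_samples=12, begin_index=0.
# Rank 0 yields [0, 1], [4, 5], [8, 9]; rank 1 yields [2, 3], [6, 7], [10, 11].
for data_rank in range(2):
    batches = [
        list(range(idx + data_rank * 2, idx + (data_rank + 1) * 2))
        for idx in range(0, 12 - 4 + 1, 4)
    ]
    print(data_rank, batches)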

fast_llm/data/config.py (+33, -47)
@@ -123,11 +123,11 @@ class TokenizerConfig(Config):
 
 
 @config_class()
-class AbstractDataConfig(Config):
+class DataConfig(Config):
     _abstract = True
 
 
-class AbstractData(abc.ABC):
+class Data(abc.ABC):
     # TODO: Improve interface
     @abc.abstractmethod
     def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]):
@@ -146,52 +146,38 @@ def get_iterator(
         pass
 
 
-@config_class()
-class DataConfig(AbstractDataConfig):
+class Dataset(abc.ABC):
     """
-    Configuration for the dataset(s), split and sampling.
-    Currently hard-coded to a GPT dataset.
-    TODO: Extract generalizable content.
+    A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
     """
 
-    _abstract = False
+    @abc.abstractmethod
+    def __getitem__(self, index: int):
+        pass
 
-    tokenizer: TokenizerConfig = Field(
-        default_factory=TokenizerConfig,
-        desc="Configuration for the tokenizer (for FIM).",
-        hint=FieldHint.feature,
-    )
-    fim: FimConfig = Field(
-        default_factory=FimConfig,
-        desc="Configuration for Fill In the Middle (FIM).",
-        hint=FieldHint.feature,
-    )
-    # TODO: set default to [1,0,0]?
-    split: list[float] = Field(
-        default_factory=lambda: [969, 30, 1],
-        desc="Split ratio for train, valid and test datasets.",
-        hint=FieldHint.core,
-        valid=_validate_split,
-    )
-    format: DatasetSource = Field(
-        default=DatasetSource.list,
-        desc="Format for the dataset definition.",
-        hint=FieldHint.core,
-    )
-    path: list[str] = Field(
-        default_factory=list,
-        desc="Path or list of paths and weights.",
-        hint=FieldHint.core,
-        valid=_validate_path,
-    )
-    data_sample_warn_time_ms: float = Field(
-        default=1000,
-        desc="Warn if a sample takes too long to load.",
-        hint=FieldHint.feature,
-        valid=check_field(Assert.gt, 0),
-    )
-    multiprocessing_context: MultiprocessingContext = Field(
-        default=MultiprocessingContext.spawn,
-        desc="Multiprocessing context. Do not touch.",
-        hint=FieldHint.expert,
-    )
+    @abc.abstractmethod
+    def __len__(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def name(self):
+        """
+        A name for the dataset to facilitate identification and debugging.
+        """
+
+
+class RawDataset(Dataset):  # noqa
+    """
+    A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk.
+    (Excluding off-line processing prior to training.)
+    Functionally identical to a `Dataset`, but renamed for clarity.
+    """
+
+
+class SampledDataset(Dataset):  # noqa
+    """
+    A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
+    (See the `Sampler` class below.)
+    Functionally identical to a `Dataset`, but renamed for clarity.
+    """
