1
- import abc
2
1
import logging
3
2
import pathlib
4
3
import time
5
4
6
5
import numpy as np
7
- import torch .utils .data
8
6
9
7
from fast_llm .core .distributed import ProcessGroup , safe_barrier
8
+ from fast_llm .data .config import SampledDataset
10
9
from fast_llm .engine .config_utils .run import log_main_rank
11
10
from fast_llm .utils import Assert
12
11
20
19
logger = logging .getLogger (__name__ )
21
20
22
21
23
class Dataset(abc.ABC):
    """
    Abstract base for all datasets.

    Mirrors the `torch.utils.data.Dataset` interface (indexing and length) with a
    slightly different signature, and additionally requires a `name` used for
    identification and debugging.
    """

    @abc.abstractmethod
    def __len__(self):
        """Number of samples contained in the dataset."""
        ...

    @abc.abstractmethod
    def __getitem__(self, index: int):
        """Return the sample stored at position `index`."""
        ...

    @property
    @abc.abstractmethod
    def name(self):
        """
        A name for the dataset to facilitate identification and debugging.
        """
43
-
44
class RawDataset(Dataset):  # noqa
    """
    A dataset holding raw, unsampled and unprocessed samples, i.e., samples as
    stored on disk (excluding any off-line processing done prior to training).
    Behaves exactly like a `Dataset`; the subclass exists purely for clarity.
    """
51
-
52
class SampledDataset(Dataset):  # noqa
    """
    A dataset holding a prepared list of samples, meant to be indexed
    sequentially (as-is) during training (see the `Sampler` class below).
    Behaves exactly like a `Dataset`; the subclass exists purely for clarity.
    """
59
-
60
22
class BlendedDataset (SampledDataset ):
61
23
"""
62
24
A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability.
@@ -72,7 +34,7 @@ def __init__(
72
34
* ,
73
35
name : str = "blended" ,
74
36
num_samples : int ,
75
- cache_dir : pathlib .Path | None = None ,
37
+ cache_directory : pathlib .Path | None = None ,
76
38
group : ProcessGroup | None = None ,
77
39
verbose : bool = True ,
78
40
data_sample_warn_time_ms : float = 1000 ,
@@ -83,19 +45,20 @@ def __init__(
83
45
self ._weights = weights
84
46
self ._data_sample_warn_time_ms = data_sample_warn_time_ms
85
47
86
- if cache_dir is None :
48
+ if cache_directory is None :
87
49
self ._dataset_idx_filename , self ._sample_idx_filename = None , None
88
50
self ._dataset_index , self ._sample_index = self ._build_blending_indices (verbose and len (datasets ) <= 20 )
89
51
else :
90
- self ._dataset_idx_filename = cache_dir / (self ._name + "_blending_dataset_idx.npy" )
91
- self ._sample_idx_filename = cache_dir / (self ._name + "_blending_sample_idx.npy" )
52
+ self ._dataset_idx_filename = cache_directory / (self ._name + "_blending_dataset_idx.npy" )
53
+ self ._sample_idx_filename = cache_directory / (self ._name + "_blending_sample_idx.npy" )
92
54
93
55
# Build the indexed mapping if it doesn't exist.
94
56
# TODO: This only works if the dataset location is accessible by all job.
95
57
if (group is None or group .rank () == 0 ) and not (
96
58
self ._dataset_idx_filename .is_file () and self ._sample_idx_filename .is_file ()
97
59
):
98
60
dataset_index , sample_index = self ._build_blending_indices (verbose and len (datasets ) <= 20 )
61
+ cache_directory .mkdir (exist_ok = True , parents = True )
99
62
np .save (self ._dataset_idx_filename , dataset_index )
100
63
np .save (self ._sample_idx_filename , sample_index )
101
64
@@ -140,7 +103,9 @@ def __len__(self):
140
103
return self ._num_samples
141
104
142
105
def _build_blending_indices (self , verbose : bool ):
143
- assert _extension_available , "Please run `make -C ./fast_llm/csrc/` first."
106
+ assert _extension_available , (
107
+ "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
108
+ )
144
109
Assert .lt (len (self ._datasets ), 32767 )
145
110
dataset_index = np .zeros (self ._num_samples , dtype = np .int16 )
146
111
dataset_sample_index = np .zeros (self ._num_samples , dtype = np .int64 )
@@ -191,24 +156,3 @@ def __getitem__(self, idx):
191
156
@property
def name(self):
    """Identifier of this blended dataset, for debugging/logging."""
    return self._name
194
-
195
-
196
class Sampler(torch.utils.data.Sampler):
    """
    A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
    To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
    """

    def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
        self._total_samples = total_samples
        self._begin_index = begin_index
        # One global step consumes `micro_batch_size` samples on each of the
        # `data_parallel` ranks.
        self._batch_size = data_parallel * micro_batch_size
        # This rank's contiguous slice within each global batch.
        self._start_idx = micro_batch_size * data_rank
        self._end_idx = self._start_idx + micro_batch_size

    def __len__(self):
        return self._total_samples

    def __iter__(self):
        # Walk the sample range one global batch at a time, yielding only the
        # slice belonging to this data-parallel rank. Trailing samples that do
        # not fill a complete global batch are dropped.
        batch_start = self._begin_index
        last_start = self._total_samples - self._batch_size
        while batch_start <= last_start:
            yield list(range(batch_start + self._start_idx, batch_start + self._end_idx))
            batch_start += self._batch_size
0 commit comments