 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import annotations
+
 import glob
 import hashlib
 import logging
 import os
 import shutil
 import urllib
 from enum import Enum
-from typing import Dict, List, Optional, Set, Tuple, Union
 from urllib.parse import urlparse

 import pandas as pd
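
The hunks below replace `typing` generics with built-in generics and PEP 604 unions (`str | list[str]`). On Python 3.9 and earlier that union syntax is only legal inside annotations when annotation evaluation is postponed, which is what the new `from __future__ import annotations` line provides. A minimal illustration of the interaction; the function below is a stand-in, not part of this PR:

```python
from __future__ import annotations  # PEP 563: annotations become lazily-evaluated strings


# Without the future import, `str | list[str]` raises TypeError at definition
# time on Python 3.9 and earlier; with it, the annotation is never evaluated.
def first_string(value: str | list[str]) -> str:
    return value if isinstance(value, str) else value[0]


print(first_string("a"))         # a
print(first_string(["b", "c"]))  # b
```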
@@ -54,16 +55,16 @@ def update_to(self, b=1, bsize=1, tsize=None):
             Total size (in tqdm units). If [default: None] remains unchanged.
         """
         if tsize is not None:
-            self.total = tsize
+            self.total = tsize  # noqa W0201
         self.update(b * bsize - self.n)  # will also set self.n = b * bsize


-def _list_of_strings(list_or_string: Union[str, List[str]]) -> List[str]:
+def _list_of_strings(list_or_string: str | list[str]) -> list[str]:
     """Helper function to accept single string or lists in config."""
     return [list_or_string] if isinstance(list_or_string, str) else list_or_string


-def _glob_multiple(pathnames: List[str], root_dir: str = None, recursive: bool = True) -> Set[str]:
+def _glob_multiple(pathnames: list[str], root_dir: str = None, recursive: bool = True) -> set[str]:
     """Recursive glob multiple patterns, returns set of matches.

     Note: glob's root_dir argument was added in python 3.10, not using it for compatibility.
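
The note above about `glob`'s `root_dir` parameter (added in Python 3.10) suggests the helper joins the root onto each pattern itself. A rough sketch of that compatibility approach, assuming matches are returned relative to `root_dir` the way `glob(..., root_dir=...)` would return them; this is an illustration, not the PR's implementation:

```python
import glob
import os


def glob_multiple_compat(pathnames, root_dir=None, recursive=True):
    """Emulate glob's root_dir= (3.10+) by prefixing each pattern with the root."""
    matches = set()
    for pattern in pathnames:
        full_pattern = os.path.join(root_dir, pattern) if root_dir else pattern
        for match in glob.glob(full_pattern, recursive=recursive):
            # Strip the root again so results look like glob(..., root_dir=root_dir).
            matches.add(os.path.relpath(match, root_dir) if root_dir else match)
    return matches
```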
@@ -117,7 +118,7 @@ class DatasetLoader:
     training.
     """

-    def __init__(self, config: DatasetConfig, cache_dir: Optional[str] = None):
+    def __init__(self, config: DatasetConfig, cache_dir: str | None = None):
         """Constructor."""
         self.config = config
         self.cache_dir = cache_dir if cache_dir else get_default_cache_location()
@@ -187,17 +188,18 @@ def state(self) -> DatasetState:
         return DatasetState.NOT_LOADED

     @property
-    def download_urls(self) -> List[str]:
+    def download_urls(self) -> list[str]:
         return _list_of_strings(self.config.download_urls)

     @property
-    def download_filenames(self) -> List[str]:
+    def download_filenames(self) -> list[str]:
         """Filenames for downloaded files inferred from download_urls."""
         if self.config.archive_filenames:
             return _list_of_strings(self.config.archive_filenames)
         return [os.path.basename(urlparse(url).path) for url in self.download_urls]

-    def get_mirror_download_paths(self, mirror: DatasetFallbackMirror):
+    @staticmethod
+    def get_mirror_download_paths(mirror: DatasetFallbackMirror):
         """Filenames for downloaded files inferred from mirror download_paths."""
         return _list_of_strings(mirror.download_paths)

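The `download_filenames` fallback above derives archive names from the URL path. A quick example of that `urlparse` + `basename` pattern; the URL is made up:

```python
import os
from urllib.parse import urlparse

url = "https://example.com/datasets/archive.tar.gz?token=abc123"
# Query string and host are dropped; only the last path component survives.
print(os.path.basename(urlparse(url).path))  # archive.tar.gz
```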
@@ -212,17 +214,17 @@ def description(self) -> str:
         return f"{self.config.name} {self.config.version}\n{self.config.description}"

     @property
-    def model_configs(self) -> Dict[str, Dict]:
+    def model_configs(self) -> dict[str, dict]:
         """Returns a dictionary of built-in model configs for this dataset."""
         return model_configs_for_dataset(self.config.name)

     @property
-    def best_model_config(self) -> Optional[Dict]:
+    def best_model_config(self) -> dict | None:
         """Returns the best built-in model config for this dataset, or None."""
         return self.model_configs.get("best")

     @property
-    def default_model_config(self) -> Optional[Dict]:
+    def default_model_config(self) -> dict | None:
         """Returns the default built-in model config for this dataset.

         This is a good first model which should train in under 10m on a current laptop without GPU acceleration.
@@ -252,7 +254,7 @@ def export(self, output_directory: str) -> None:
         else:
             shutil.copy2(source, destination)

-    def _download_and_process(self, kaggle_username=None, kaggle_key=None):
+    def _download_and_process(self, kaggle_username: str | None = None, kaggle_key: str | None = None):
         """Loads the dataset, downloading and processing it if needed.

         If dataset is already processed, does nothing.
@@ -283,14 +285,18 @@ def _download_and_process(self, kaggle_username=None, kaggle_key=None):
         except Exception:
             logger.exception("Failed to transform dataset")

-    def load(self, split=False, kaggle_username=None, kaggle_key=None) -> pd.DataFrame:
+    def load(
+        self, kaggle_username: str | None = None, kaggle_key: str | None = None, split: bool = False
+    ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """Loads the dataset, downloading and processing it if needed.

         Note: This method is also responsible for splitting the data, returning a single dataframe if split=False, and a
         3-tuple of train, val, test if split=True.

+        :param kaggle_username: (str) username on Kaggle platform
+        :param kaggle_key: (str) API key on Kaggle platform
         :param split: (bool) splits dataset along 'split' column if present. The split column should always have values
-        0: train, 1: validation, 2: test.
+            0: train, 1: validation, 2: test.
         """
         self._download_and_process(kaggle_username=kaggle_username, kaggle_key=kaggle_key)
         if self.state == DatasetState.TRANSFORMED:
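
For context on the reworked `load()` signature, a hedged usage sketch; the config object, cache directory, and variable names are placeholders, not values from this diff:

```python
# Hypothetical caller; `config` is assumed to be a DatasetConfig built elsewhere.
loader = DatasetLoader(config, cache_dir="/tmp/dataset_cache")

# split=False (default): one dataframe, optionally carrying a numeric `split` column.
df = loader.load()

# split=True: partitioned on the split column, 0 = train, 1 = validation, 2 = test.
train_df, val_df, test_df = loader.load(split=True)
```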
@@ -300,7 +306,7 @@ def load(self, split=False, kaggle_username=None, kaggle_key=None) -> pd.DataFra
         else:
             return dataset_df

-    def download(self, kaggle_username=None, kaggle_key=None) -> List[str]:
+    def download(self, kaggle_username: str | None = None, kaggle_key: str | None = None) -> list[str]:
         if not os.path.exists(self.raw_dataset_dir):
             os.makedirs(self.raw_dataset_dir)
         if self.is_kaggle_dataset:
@@ -347,7 +353,7 @@ def verify(self) -> None:
             digest = _sha256_digest(path)
             logger.info(f"  {filename}: {digest}")

-    def extract(self) -> List[str]:
+    def extract(self) -> list[str]:
         extracted_files = set()
         for download_filename in self.download_filenames:
             download_path = os.path.join(self.raw_dataset_dir, download_filename)
@@ -373,7 +379,7 @@ def transform(self) -> None:
         transformed_dataframe = self.transform_dataframe(unprocessed_dataframe)
         self.save_processed(transformed_dataframe)

-    def transform_files(self, file_paths: List[str]) -> List[str]:
+    def transform_files(self, file_paths: list[str]) -> list[str]:
         """Transform data files before loading to dataframe.

         Subclasses should override this method to process files before loading dataframe, calling the base class
@@ -409,7 +415,7 @@ def load_file_to_dataframe(self, file_path: str) -> pd.DataFrame:
         else:
             raise ValueError(f"Unsupported dataset file type: {file_extension}")

-    def load_files_to_dataframe(self, file_paths: List[str], root_dir=None) -> pd.DataFrame:
+    def load_files_to_dataframe(self, file_paths: list[str], root_dir=None) -> pd.DataFrame:
         """Loads a file or list of files and returns a dataframe.

         Subclasses may override this method to change the loader's behavior for groups of files.
@@ -439,7 +445,7 @@ def load_files_to_dataframe(self, file_paths: List[str], root_dir=None) -> pd.Da
             logger.warning(f"Error setting column names: {e}")
         return pd.concat(dataframes, ignore_index=True)

-    def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
+    def load_unprocessed_dataframe(self, file_paths: list[str]) -> pd.DataFrame:
         """Load dataset files into a dataframe.

         Will use the list of data files in the dataset directory as a default if all of config's dataset_filenames,
@@ -451,7 +457,6 @@ def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
             _list_of_strings(self.config.validation_filenames), root_dir=self.raw_dataset_dir
         )
         test_paths = _glob_multiple(_list_of_strings(self.config.test_filenames), root_dir=self.raw_dataset_dir)
-        dataframes = []
         if self.config.name == "hugging_face":
             dataframes = self._get_dataframe_with_fixed_splits_from_hf()
         else:
@@ -519,7 +524,8 @@ def get_mtime(self) -> float:
         """Last modified time of the processed dataset after downloading successfully."""
         return os.path.getmtime(self.processed_dataset_path)

-    def split(self, dataset: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    @staticmethod
+    def split(dataset: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         if SPLIT in dataset:
             dataset[SPLIT] = pd.to_numeric(dataset[SPLIT])
             training_set = dataset[dataset[SPLIT] == 0].drop(columns=[SPLIT])
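
The `split()` logic above partitions on a `SPLIT` column coerced to numeric. A self-contained sketch of that convention, with toy data standing in for a real dataset:

```python
import pandas as pd

SPLIT = "split"  # assumed to match the SPLIT constant used by the loader

df = pd.DataFrame({"feature": [10, 20, 30, 40], SPLIT: ["0", "0", "1", "2"]})
df[SPLIT] = pd.to_numeric(df[SPLIT])

# 0 = train, 1 = validation, 2 = test; the split column is dropped from each part.
training_set = df[df[SPLIT] == 0].drop(columns=[SPLIT])
validation_set = df[df[SPLIT] == 1].drop(columns=[SPLIT])
test_set = df[df[SPLIT] == 2].drop(columns=[SPLIT])
print(len(training_set), len(validation_set), len(test_set))  # 2 1 1
```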