Commit 31db203

[MAINTENANCE] Fixing method overload violations; providing typehints and method argument documentation (ludwig-ai#3753)
Parent: 8b423f1

File tree: 2 files changed, +46 −28 lines


ludwig/datasets/loaders/dataset_loader.py

+27 −21
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import annotations
+
 import glob
 import hashlib
 import logging
 import os
 import shutil
 import urllib
 from enum import Enum
-from typing import Dict, List, Optional, Set, Tuple, Union
 from urllib.parse import urlparse
 
 import pandas as pd
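The new `from __future__ import annotations` line is what lets the PEP 604 unions (`str | None`) and builtin generics (`list[str]`) introduced throughout this commit run on Pythons older than 3.10. A minimal sketch of the effect; the demo calls at the bottom are illustrative and not part of the commit:

from __future__ import annotations


def _list_of_strings(list_or_string: str | list[str]) -> list[str]:
    """Accept a single string or a list of strings; always return a list."""
    # With postponed evaluation (PEP 563), the annotation above is stored as a
    # string and never evaluated, so `str | list[str]` does not raise TypeError
    # on Python < 3.10.
    return [list_or_string] if isinstance(list_or_string, str) else list_or_string


print(_list_of_strings("a"))         # ['a']
print(_list_of_strings(["a", "b"]))  # ['a', 'b']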
@@ -54,16 +55,16 @@ def update_to(self, b=1, bsize=1, tsize=None):
             Total size (in tqdm units). If [default: None] remains unchanged.
         """
         if tsize is not None:
-            self.total = tsize
+            self.total = tsize  # noqa W0201
         self.update(b * bsize - self.n)  # will also set self.n = b * bsize
 
 
-def _list_of_strings(list_or_string: Union[str, List[str]]) -> List[str]:
+def _list_of_strings(list_or_string: str | list[str]) -> list[str]:
     """Helper function to accept single string or lists in config."""
     return [list_or_string] if isinstance(list_or_string, str) else list_or_string
 
 
-def _glob_multiple(pathnames: List[str], root_dir: str = None, recursive: bool = True) -> Set[str]:
+def _glob_multiple(pathnames: list[str], root_dir: str = None, recursive: bool = True) -> set[str]:
     """Recursive glob multiple patterns, returns set of matches.
 
     Note: glob's root_dir argument was added in python 3.10, not using it for compatibility.
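The note about `root_dir` explains why the helper cannot delegate to glob's own `root_dir=` keyword. A plausible standalone reconstruction consistent with the signature above; the body shown is a sketch, not the file's actual implementation:

import glob
import os


def _glob_multiple(pathnames: list[str], root_dir: str = None, recursive: bool = True) -> set[str]:
    """Union of glob matches for several patterns, searched under root_dir."""
    matches: set[str] = set()
    for pattern in pathnames:
        # Join manually instead of using glob's root_dir= argument (3.10+),
        # so the helper keeps working on older interpreters.
        full_pattern = os.path.join(root_dir, pattern) if root_dir else pattern
        matches.update(glob.glob(full_pattern, recursive=recursive))
    return matches


# e.g. _glob_multiple(["*.csv", "**/*.parquet"], root_dir="/tmp/data")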
@@ -117,7 +118,7 @@ class DatasetLoader:
     training.
     """
 
-    def __init__(self, config: DatasetConfig, cache_dir: Optional[str] = None):
+    def __init__(self, config: DatasetConfig, cache_dir: str | None = None):
         """Constructor."""
         self.config = config
         self.cache_dir = cache_dir if cache_dir else get_default_cache_location()
@@ -187,17 +188,18 @@ def state(self) -> DatasetState:
         return DatasetState.NOT_LOADED
 
     @property
-    def download_urls(self) -> List[str]:
+    def download_urls(self) -> list[str]:
         return _list_of_strings(self.config.download_urls)
 
     @property
-    def download_filenames(self) -> List[str]:
+    def download_filenames(self) -> list[str]:
         """Filenames for downloaded files inferred from download_urls."""
         if self.config.archive_filenames:
             return _list_of_strings(self.config.archive_filenames)
         return [os.path.basename(urlparse(url).path) for url in self.download_urls]
 
-    def get_mirror_download_paths(self, mirror: DatasetFallbackMirror):
+    @staticmethod
+    def get_mirror_download_paths(mirror: DatasetFallbackMirror):
         """Filenames for downloaded files inferred from mirror download_paths."""
         return _list_of_strings(mirror.download_paths)
 
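`download_filenames` falls back to deriving names from the URLs themselves. A quick illustration of that `urlparse` + `basename` inference; the URL is a made-up example:

import os
from urllib.parse import urlparse

url = "https://example.com/datasets/archive.zip?token=abc123"
# urlparse(...).path is "/datasets/archive.zip"; the query string is ignored
print(os.path.basename(urlparse(url).path))  # archive.zip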

@@ -212,17 +214,17 @@ def description(self) -> str:
         return f"{self.config.name} {self.config.version}\n{self.config.description}"
 
     @property
-    def model_configs(self) -> Dict[str, Dict]:
+    def model_configs(self) -> dict[str, dict]:
         """Returns a dictionary of built-in model configs for this dataset."""
         return model_configs_for_dataset(self.config.name)
 
     @property
-    def best_model_config(self) -> Optional[Dict]:
+    def best_model_config(self) -> dict | None:
         """Returns the best built-in model config for this dataset, or None."""
         return self.model_configs.get("best")
 
     @property
-    def default_model_config(self) -> Optional[Dict]:
+    def default_model_config(self) -> dict | None:
         """Returns the default built-in model config for this dataset.
 
         This is a good first model which should train in under 10m on a current laptop without GPU acceleration.
@@ -252,7 +254,7 @@ def export(self, output_directory: str) -> None:
         else:
             shutil.copy2(source, destination)
 
-    def _download_and_process(self, kaggle_username=None, kaggle_key=None):
+    def _download_and_process(self, kaggle_username: str | None = None, kaggle_key: str | None = None):
         """Loads the dataset, downloaded and processing it if needed.
 
         If dataset is already processed, does nothing.
@@ -283,14 +285,18 @@ def _download_and_process(self, kaggle_username=None, kaggle_key=None):
         except Exception:
             logger.exception("Failed to transform dataset")
 
-    def load(self, split=False, kaggle_username=None, kaggle_key=None) -> pd.DataFrame:
+    def load(
+        self, kaggle_username: str | None = None, kaggle_key: str | None = None, split: bool = False
+    ) -> pd.DataFrame | list[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """Loads the dataset, downloaded and processing it if needed.
 
         Note: This method is also responsible for splitting the data, returning a single dataframe if split=False, and a
         3-tuple of train, val, test if split=True.
 
+        :param kaggle_username: (str) username on Kaggle platform
+        :param kaggle_key: (str) dataset key on Kaggle platform
         :param split: (bool) splits dataset along 'split' column if present. The split column should always have values
-        0: train, 1: validation, 2: test.
+            0: train, 1: validation, 2: test.
         """
         self._download_and_process(kaggle_username=kaggle_username, kaggle_key=kaggle_key)
         if self.state == DatasetState.TRANSFORMED:
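A hedged usage sketch of the reordered signature; the `titanic` dataset module is an illustrative example of Ludwig's dataset API, not something touched by this commit:

from ludwig.datasets import titanic

# split=False (default): one dataframe, with a 'split' column if the dataset has one
df = titanic.load()

# split=True: three dataframes keyed off split values 0 (train), 1 (validation), 2 (test)
train_df, val_df, test_df = titanic.load(split=True)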
@@ -300,7 +306,7 @@ def load(self, split=False, kaggle_username=None, kaggle_key=None) -> pd.DataFrame:
         else:
             return dataset_df
 
-    def download(self, kaggle_username=None, kaggle_key=None) -> List[str]:
+    def download(self, kaggle_username: str | None = None, kaggle_key: str | None = None) -> list[str]:
         if not os.path.exists(self.raw_dataset_dir):
             os.makedirs(self.raw_dataset_dir)
         if self.is_kaggle_dataset:
@@ -347,7 +353,7 @@ def verify(self) -> None:
             digest = _sha256_digest(path)
             logger.info(f" {filename}: {digest}")
 
-    def extract(self) -> List[str]:
+    def extract(self) -> list[str]:
         extracted_files = set()
         for download_filename in self.download_filenames:
             download_path = os.path.join(self.raw_dataset_dir, download_filename)
@@ -373,7 +379,7 @@ def transform(self) -> None:
         transformed_dataframe = self.transform_dataframe(unprocessed_dataframe)
         self.save_processed(transformed_dataframe)
 
-    def transform_files(self, file_paths: List[str]) -> List[str]:
+    def transform_files(self, file_paths: list[str]) -> list[str]:
         """Transform data files before loading to dataframe.
 
         Subclasses should override this method to process files before loading dataframe, calling the base class
@@ -409,7 +415,7 @@ def load_file_to_dataframe(self, file_path: str) -> pd.DataFrame:
         else:
             raise ValueError(f"Unsupported dataset file type: {file_extension}")
 
-    def load_files_to_dataframe(self, file_paths: List[str], root_dir=None) -> pd.DataFrame:
+    def load_files_to_dataframe(self, file_paths: list[str], root_dir=None) -> pd.DataFrame:
         """Loads a file or list of files and returns a dataframe.
 
         Subclasses may override this method to change the loader's behavior for groups of files.
@@ -439,7 +445,7 @@ def load_files_to_dataframe(self, file_paths: List[str], root_dir=None) -> pd.DataFrame:
             logger.warning(f"Error setting column names: {e}")
         return pd.concat(dataframes, ignore_index=True)
 
-    def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
+    def load_unprocessed_dataframe(self, file_paths: list[str]) -> pd.DataFrame:
         """Load dataset files into a dataframe.
 
         Will use the list of data files in the dataset directory as a default if all of config's dataset_filenames,
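The `pd.concat(dataframes, ignore_index=True)` visible in the context above merges the per-file frames into one. A tiny self-contained illustration; the file names are hypothetical:

import pandas as pd

frames = [pd.read_csv(path) for path in ["train_part1.csv", "train_part2.csv"]]
# ignore_index=True discards each file's row labels and renumbers rows 0..n-1
combined = pd.concat(frames, ignore_index=True)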
@@ -451,7 +457,6 @@ def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
             _list_of_strings(self.config.validation_filenames), root_dir=self.raw_dataset_dir
         )
         test_paths = _glob_multiple(_list_of_strings(self.config.test_filenames), root_dir=self.raw_dataset_dir)
-        dataframes = []
         if self.config.name == "hugging_face":
             dataframes = self._get_dataframe_with_fixed_splits_from_hf()
         else:
@@ -519,7 +524,8 @@ def get_mtime(self) -> float:
         """Last modified time of the processed dataset after downloading successfully."""
         return os.path.getmtime(self.processed_dataset_path)
 
-    def split(self, dataset: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+    @staticmethod
+    def split(dataset: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         if SPLIT in dataset:
             dataset[SPLIT] = pd.to_numeric(dataset[SPLIT])
             training_set = dataset[dataset[SPLIT] == 0].drop(columns=[SPLIT])
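This is where the 0/1/2 split convention documented in load() is applied. A minimal standalone sketch of the same pattern; `SPLIT` resolves to "split" in Ludwig's constants, and the toy data is mine:

import pandas as pd

SPLIT = "split"  # matches ludwig.constants.SPLIT

df = pd.DataFrame({"x": [1, 2, 3, 4], SPLIT: [0, 0, 1, 2]})
train_set = df[df[SPLIT] == 0].drop(columns=[SPLIT])
val_set = df[df[SPLIT] == 1].drop(columns=[SPLIT])
test_set = df[df[SPLIT] == 2].drop(columns=[SPLIT])
print(len(train_set), len(val_set), len(test_set))  # 2 1 1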

ludwig/datasets/loaders/hugging_face.py

+19 −7
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from __future__ import annotations
+
 import logging
-from typing import Dict
 
 import datasets
 import pandas as pd
@@ -33,29 +34,40 @@ class HFLoader(DatasetLoader):
     identify which dataset and which subsample of that dataset to load in.
     """
 
-    def load_hf_to_dict(self, hf_id: str, hf_subsample: str) -> Dict[str, pd.DataFrame]:
-        """Returns a map of split -> pd.DataFrame for the given HF dataset."""
-        dataset_dict: Dict[str, "datasets.arrow_dataset.Dataset"] = datasets.load_dataset(
-            path=hf_id, name=hf_subsample
-        )  # noqa
+    @staticmethod
+    def load_hf_to_dict(hf_id: str | None = None, hf_subsample: str | None = None) -> dict[str, pd.DataFrame]:
+        """Returns a map of split -> pd.DataFrame for the given HF dataset.
+
+        :param hf_id: (str) path to dataset on HuggingFace platform
+        :param hf_subsample: (str) name of dataset configuration on HuggingFace platform
+        """
+        dataset_dict: dict[str, datasets.Dataset] = datasets.load_dataset(path=hf_id, name=hf_subsample)
         pandas_dict = {}
         for split in dataset_dict:
             # Convert from HF DatasetDict type to a dictionary of pandas dataframes
             pandas_dict[split] = dataset_dict[split].to_pandas()
         return pandas_dict
 
-    def load(self, hf_id, hf_subsample, split=False) -> pd.DataFrame:
+    # TODO(Alex): Standardize load() signature as interface method in DatasetLoader and adhere to it in all subclasses.
+    def load(
+        self, hf_id: str | None = None, hf_subsample: str | None = None, split: bool = False
+    ) -> pd.DataFrame | list[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """When load() is called, HFLoader calls the datasets API to return all of the data in a HuggingFace
         DatasetDict, converts it to a dictionary of pandas dataframes, and returns either three dataframes
         containing train, validation, and test data or one dataframe that is the concatenation of all three
         depending on whether `split` is set to True or False.
 
+        :param split: (bool) directive for how to interpret if dataset contains validation or test set (see below)
+
         Note that some datasets may not provide a validation set or a test set. In this case:
         - If split is True, the DataFrames corresponding to the missing sets are initialized to be empty
         - If split is False, the "split" column in the resulting DataFrame will reflect the fact that there is no
           validation/test split (i.e., there will be no 1s/2s)
 
         A train set should always be provided by Hugging Face.
+
+        :param hf_id: (str) path to dataset on HuggingFace platform
+        :param hf_subsample: (str) name of dataset configuration on HuggingFace platform
         """
         self.config.huggingface_dataset_id = hf_id
         self.config.huggingface_subsample = hf_subsample
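For reference, the DatasetDict-to-pandas conversion that the new annotation describes, as a standalone sketch; "imdb" is just a small public dataset used for illustration, not one referenced by this commit:

import datasets

dataset_dict = datasets.load_dataset(path="imdb")
# Each split in the DatasetDict converts independently to a pandas DataFrame
pandas_dict = {split: dataset_dict[split].to_pandas() for split in dataset_dict}
print(list(pandas_dict))  # e.g. ['train', 'test', 'unsupervised']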
