Commit a47d569

ENH improve ARFF parser using pandas (scikit-learn#21938)

Authored by Guillaume Lemaitre (glemaitre)
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Adrin Jalali <[email protected]>

1 parent 9700cf1 commit a47d569

34 files changed: +1917 -1325 lines

asv_benchmarks/benchmarks/datasets.py (+3 -1)

@@ -59,7 +59,9 @@ def _20newsgroups_lowdim_dataset(n_components=100, ngrams=(1, 1), dtype=np.float

 @M.cache
 def _mnist_dataset(dtype=np.float32):
-    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
+    X, y = fetch_openml(
+        "mnist_784", version=1, return_X_y=True, as_frame=False, parser="pandas"
+    )
     X = X.astype(dtype, copy=False)
     X = MaxAbsScaler().fit_transform(X)

benchmarks/bench_hist_gradient_boosting_adult.py (+18 -5)

@@ -2,12 +2,15 @@
 from time import time

 import numpy as np
+import pandas as pd

 from sklearn.model_selection import train_test_split
+from sklearn.compose import make_column_transformer, make_column_selector
 from sklearn.datasets import fetch_openml
 from sklearn.metrics import accuracy_score, roc_auc_score
 from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.ensemble._hist_gradient_boosting.utils import get_equivalent_estimator
+from sklearn.preprocessing import OrdinalEncoder


 parser = argparse.ArgumentParser()

@@ -47,22 +50,32 @@ def predict(est, data_test, target_test):
     print(f"predicted in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}")


-data = fetch_openml(data_id=179, as_frame=False)  # adult dataset
+data = fetch_openml(data_id=179, as_frame=True, parser="pandas")  # adult dataset
 X, y = data.data, data.target

+# Ordinal encode the categories to use the native support available in HGBDT
+cat_columns = make_column_selector(dtype_include="category")(X)
+preprocessing = make_column_transformer(
+    (OrdinalEncoder(), cat_columns),
+    remainder="passthrough",
+    verbose_feature_names_out=False,
+)
+X = pd.DataFrame(
+    preprocessing.fit_transform(X),
+    columns=preprocessing.get_feature_names_out(),
+)
+
 n_classes = len(np.unique(y))
 n_features = X.shape[1]
-n_categorical_features = len(data.categories)
+n_categorical_features = len(cat_columns)
 n_numerical_features = n_features - n_categorical_features
 print(f"Number of features: {n_features}")
 print(f"Number of categorical features: {n_categorical_features}")
 print(f"Number of numerical features: {n_numerical_features}")

 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

-# Note: no need to use an OrdinalEncoder because categorical features are
-# already clean
-is_categorical = [name in data.categories for name in data.feature_names]
+is_categorical = [True] * n_categorical_features + [False] * n_numerical_features
 est = HistGradientBoostingClassifier(
     loss="log_loss",
     learning_rate=lr,
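
The preprocessing pattern introduced above can be sketched in isolation. A minimal sketch on a synthetic frame (the column names and values are illustrative, not taken from the adult dataset):

    import pandas as pd

    from sklearn.compose import make_column_selector, make_column_transformer
    from sklearn.preprocessing import OrdinalEncoder

    # Toy frame standing in for the adult dataset loaded as a DataFrame.
    X = pd.DataFrame(
        {
            "workclass": pd.Series(["private", "state", "private"], dtype="category"),
            "age": [25, 32, 47],
        }
    )

    # Select the categorical columns by dtype, as the benchmark does.
    cat_columns = make_column_selector(dtype_include="category")(X)  # ['workclass']

    # Ordinal-encode the categorical columns and pass the rest through.
    preprocessing = make_column_transformer(
        (OrdinalEncoder(), cat_columns),
        remainder="passthrough",
        verbose_feature_names_out=False,
    )
    X_encoded = pd.DataFrame(
        preprocessing.fit_transform(X),
        columns=preprocessing.get_feature_names_out(),
    )
    print(X_encoded)  # encoded categorical columns first, then the passthrough ones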

benchmarks/bench_isolation_forest.py (+2 -2)

@@ -64,9 +64,9 @@ def print_outlier_ratio(y):
     y = dataset.target

     if dat == "shuttle":
-        dataset = fetch_openml("shuttle")
+        dataset = fetch_openml("shuttle", as_frame=False, parser="pandas")
         X = dataset.data
-        y = dataset.target
+        y = dataset.target.astype(np.int64)
         X, y = sh(X, y, random_state=random_state)
         # we remove data with label 4
         # normal data are then those of class 1

benchmarks/bench_lof.py (+2 -2)

@@ -44,9 +44,9 @@
     y = dataset.target

     if dataset_name == "shuttle":
-        dataset = fetch_openml("shuttle")
+        dataset = fetch_openml("shuttle", as_frame=False, parser="pandas")
         X = dataset.data
-        y = dataset.target
+        y = dataset.target.astype(np.int64)
         # we remove data with label 4
         # normal data are then those of class 1
         s = y != 4

benchmarks/bench_mnist.py (+1 -1)

@@ -62,7 +62,7 @@ def load_data(dtype=np.float32, order="F"):
     ######################################################################
     # Load dataset
    print("Loading dataset...")
-    data = fetch_openml("mnist_784")
+    data = fetch_openml("mnist_784", as_frame=True, parser="pandas")
     X = check_array(data["data"], dtype=dtype, order=order)
     y = data["target"]

benchmarks/bench_plot_randomized_svd.py (+4 -4)

@@ -191,7 +191,7 @@ def get_data(dataset_name):
         del row
         del col
     else:
-        X = fetch_openml(dataset_name).data
+        X = fetch_openml(dataset_name, parser="auto").data
     return X


@@ -281,9 +281,9 @@ def svd_timing(
         U, mu, V = randomized_svd(
             X,
             n_comps,
-            n_oversamples,
-            n_iter,
-            power_iteration_normalizer,
+            n_oversamples=n_oversamples,
+            n_iter=n_iter,
+            power_iteration_normalizer=power_iteration_normalizer,
             random_state=random_state,
             transpose=False,
         )
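
The second hunk spells out the optional arguments of `randomized_svd` by keyword: newer scikit-learn versions make these parameters keyword-only, so positional passing can break, and keywords are harder to misalign. A minimal sketch on a random matrix (shapes and values are illustrative):

    import numpy as np
    from sklearn.utils.extmath import randomized_svd

    # Random matrix standing in for the benchmark data.
    X = np.random.RandomState(0).rand(100, 50)

    # Optional parameters passed by keyword, as in the fixed call above.
    U, s, Vt = randomized_svd(
        X,
        n_components=5,
        n_oversamples=10,
        n_iter=4,
        power_iteration_normalizer="LU",
        random_state=0,
    )
    print(U.shape, s.shape, Vt.shape)  # (100, 5) (5,) (5, 50)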

benchmarks/bench_tsne_mnist.py (+1 -1)

@@ -35,7 +35,7 @@
 def load_data(dtype=np.float32, order="C", shuffle=True, seed=0):
     """Load the data, then cache and memmap the train/test split"""
     print("Loading dataset...")
-    data = fetch_openml("mnist_784")
+    data = fetch_openml("mnist_784", as_frame=True, parser="pandas")

     X = check_array(data["data"], dtype=dtype, order=order)
     y = data["target"]

doc/datasets/loading_other_datasets.rst (+47 -8)

@@ -99,7 +99,7 @@ from the repository using the function
 For example, to download a dataset of gene expressions in mice brains::

   >>> from sklearn.datasets import fetch_openml
-  >>> mice = fetch_openml(name='miceprotein', version=4)
+  >>> mice = fetch_openml(name='miceprotein', version=4, parser="auto")

 To fully specify a dataset, you need to provide a name and a version, though
 the version is optional, see :ref:`openml_versions` below.

@@ -147,7 +147,7 @@ dataset on the openml website::

 The ``data_id`` also uniquely identifies a dataset from OpenML::

-  >>> mice = fetch_openml(data_id=40966)
+  >>> mice = fetch_openml(data_id=40966, parser="auto")
   >>> mice.details # doctest: +SKIP
   {'id': '4550', 'name': 'MiceProtein', 'version': '1', 'format': 'ARFF',
   'creator': ...,

@@ -171,8 +171,8 @@ which can contain entirely different datasets.
 If a particular version of a dataset has been found to contain significant
 issues, it might be deactivated. Using a name to specify a dataset will yield
 the earliest version of a dataset that is still active. That means that
-``fetch_openml(name="miceprotein")`` can yield different results at different
-times if earlier versions become inactive.
+``fetch_openml(name="miceprotein", parser="auto")`` can yield different results
+at different times if earlier versions become inactive.
 You can see that the dataset with ``data_id`` 40966 that we fetched above is
 the first version of the "miceprotein" dataset::

@@ -182,19 +182,19 @@ the first version of the "miceprotein" dataset::
 In fact, this dataset only has one version. The iris dataset on the other hand
 has multiple versions::

-  >>> iris = fetch_openml(name="iris")
+  >>> iris = fetch_openml(name="iris", parser="auto")
   >>> iris.details['version'] #doctest: +SKIP
   '1'
   >>> iris.details['id'] #doctest: +SKIP
   '61'

-  >>> iris_61 = fetch_openml(data_id=61)
+  >>> iris_61 = fetch_openml(data_id=61, parser="auto")
   >>> iris_61.details['version']
   '1'
   >>> iris_61.details['id']
   '61'

-  >>> iris_969 = fetch_openml(data_id=969)
+  >>> iris_969 = fetch_openml(data_id=969, parser="auto")
   >>> iris_969.details['version']
   '3'
   >>> iris_969.details['id']

@@ -212,7 +212,7 @@ binarized version of the data::
 You can also specify both the name and the version, which also uniquely
 identifies the dataset::

-  >>> iris_version_3 = fetch_openml(name="iris", version=3)
+  >>> iris_version_3 = fetch_openml(name="iris", version=3, parser="auto")
   >>> iris_version_3.details['version']
   '3'
   >>> iris_version_3.details['id']
@@ -225,6 +225,45 @@ identifies the dataset::
   machine learning" ACM SIGKDD Explorations Newsletter, 15(2), 49-60, 2014.
   <1407.7722>`

+.. _openml_parser:
+
+ARFF parser
+~~~~~~~~~~~
+
+From version 1.2, scikit-learn provides a new keyword argument `parser` that
+offers several options to parse the ARFF files provided by OpenML. The legacy
+parser (i.e. `parser="liac-arff"`) is based on the project
+`LIAC-ARFF <https://github.com/renatopp/liac-arff>`_. This parser is, however,
+slow and consumes more memory than required. A new parser based on pandas
+(i.e. `parser="pandas"`) is both faster and more memory efficient.
+However, this parser does not support sparse data.
+Therefore, we recommend using `parser="auto"`, which will use the best parser
+available for the requested dataset.
+
+The `"pandas"` and `"liac-arff"` parsers can lead to different data types in
+the output. The notable differences are the following:
+
+- The `"liac-arff"` parser always encodes categorical features as `str`
+  objects. By contrast, the `"pandas"` parser infers the type while reading,
+  and numerical categories will be cast to integers whenever possible.
+- The `"liac-arff"` parser uses float64 to encode numerical features tagged as
+  'REAL' and 'NUMERICAL' in the metadata. The `"pandas"` parser instead infers
+  whether these numerical features correspond to integers and uses pandas'
+  Integer extension dtype.
+- In particular, classification datasets with integer categories are typically
+  loaded as such `(0, 1, ...)` with the `"pandas"` parser while `"liac-arff"`
+  will force the use of string-encoded class labels such as `"0"`, `"1"` and
+  so on.
+
+In addition, when `as_frame=False` is used, the `"liac-arff"` parser returns
+ordinally encoded data where the categories are provided in the attribute
+`categories` of the `Bunch` instance. Instead, `"pandas"` returns a NumPy
+array where the categories are not encoded. It is then up to the user to
+design a feature engineering pipeline with an instance of `OneHotEncoder` or
+`OrdinalEncoder`, typically wrapped in a `ColumnTransformer`, to preprocess
+the categorical columns explicitly. See for instance:
+:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
+
 .. _external_datasets:

 Loading from external datasets
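
The dtype differences described in the new section can be checked directly. A minimal sketch reusing the adult dataset (data_id=179) that appears elsewhere in this commit; it requires network access, and the exact dtype counts depend on the dataset version:

    from sklearn.datasets import fetch_openml

    # Same dataset, two parsers.
    adult_liac = fetch_openml(data_id=179, as_frame=True, parser="liac-arff")
    adult_pandas = fetch_openml(data_id=179, as_frame=True, parser="pandas")

    # "liac-arff" yields object/float64 columns, while "pandas" infers
    # category and Integer extension dtypes from the ARFF metadata.
    print(adult_liac.data.dtypes.value_counts())
    print(adult_pandas.data.dtypes.value_counts())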

doc/whats_new/v1.1.rst (+2 -2; whitespace-only changes, likely trailing-whitespace removal)

@@ -26,7 +26,7 @@ Changelog
   classifier that always predicts the positive class: recall=100% and
   precision=class balance.
   :pr:`23214` by :user:`Stéphane Collot <stephanecollot>` and :user:`Max Baak <mbaak>`.
-
+
 :mod:`sklearn.utils`
 ....................

@@ -208,7 +208,7 @@ Changelog
   :pr:`23194` by `Thomas Fan`_.

 - |Enhancement| Added an extension in doc/conf.py to automatically generate
-  the list of estimators that handle NaN values.
+  the list of estimators that handle NaN values.
   :pr:`23198` by `Lise Kleiber <lisekleiber>`_, :user:`Zhehao Liu <MaxwellLZH>`
   and :user:`Chiara Marmo <cmarmo>`.

doc/whats_new/v1.2.rst (+14 -0)

@@ -44,6 +44,20 @@ Changelog
 - |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32`
   inputs. :pr:`22968` by `Meekail Zain <micky774>`.

+:mod:`sklearn.datasets`
+.......................
+
+- |Enhancement| Introduce the new parameter `parser` in
+  :func:`datasets.fetch_openml`. `parser="pandas"` allows using the very CPU-
+  and memory-efficient `pandas.read_csv` parser to load dense ARFF-formatted
+  dataset files. It is still possible to pass `parser="liac-arff"` to use the
+  old LIAC parser.
+  When `parser="auto"`, dense datasets are loaded with "pandas" and sparse
+  datasets are loaded with "liac-arff".
+  Currently, `parser="liac-arff"` is the default; it will change to
+  `parser="auto"` in version 1.4.
+  :pr:`21938` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.ensemble`
 .......................
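
Usage matching the changelog entry, as a minimal sketch with the miceprotein dataset referenced in the documentation above (requires network access):

    from sklearn.datasets import fetch_openml

    # parser="auto" picks "pandas" for dense ARFF files and falls back to
    # "liac-arff" for sparse ones.
    mice = fetch_openml(name="miceprotein", version=4, parser="auto")
    print(mice.data.shape)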

examples/applications/plot_cyclical_feature_engineering.py (+3 -1)

@@ -20,7 +20,9 @@
 # We start by loading the data from the OpenML repository.
 from sklearn.datasets import fetch_openml

-bike_sharing = fetch_openml("Bike_Sharing_Demand", version=2, as_frame=True)
+bike_sharing = fetch_openml(
+    "Bike_Sharing_Demand", version=2, as_frame=True, parser="pandas"
+)
 df = bike_sharing.frame

 # %%

examples/applications/plot_digits_denoising.py (+1 -1)

@@ -36,7 +36,7 @@
 from sklearn.preprocessing import MinMaxScaler
 from sklearn.model_selection import train_test_split

-X, y = fetch_openml(data_id=41082, as_frame=False, return_X_y=True)
+X, y = fetch_openml(data_id=41082, as_frame=False, return_X_y=True, parser="pandas")
 X = MinMaxScaler().fit_transform(X)

 # %%

examples/compose/plot_column_transformer_mixed_types.py (+3 -1)

@@ -43,7 +43,9 @@

 # %%
 # Load data from https://www.openml.org/d/40945
-X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
+X, y = fetch_openml(
+    "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas"
+)

 # Alternatively X and y can be obtained directly from the frame attribute:
 # X = titanic.frame.drop('survived', axis=1)

examples/compose/plot_transformed_target.py (+1 -1)

@@ -128,7 +128,7 @@
 from sklearn.datasets import fetch_openml
 from sklearn.preprocessing import QuantileTransformer, quantile_transform

-ames = fetch_openml(name="house_prices", as_frame=True)
+ames = fetch_openml(name="house_prices", as_frame=True, parser="pandas")
 # Keep only numeric columns
 X = ames.data.select_dtypes(np.number)
 # Remove columns with NaN or Inf values

examples/ensemble/plot_gradient_boosting_categorical.py (+1 -1)

@@ -30,7 +30,7 @@
 # are either categorical or numerical:
 from sklearn.datasets import fetch_openml

-X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
+X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True, parser="pandas")

 # Select only a subset of features of X to make the example faster to run
 categorical_columns_subset = [

examples/ensemble/plot_stack_predictors.py (+4 -2)

@@ -45,7 +45,7 @@


 def load_ames_housing():
-    df = fetch_openml(name="house_prices", as_frame=True)
+    df = fetch_openml(name="house_prices", as_frame=True, parser="pandas")
     X = df.data
     y = df.target

@@ -117,7 +117,9 @@ def load_ames_housing():
 from sklearn.preprocessing import OrdinalEncoder

 cat_tree_processor = OrdinalEncoder(
-    handle_unknown="use_encoded_value", unknown_value=-1
+    handle_unknown="use_encoded_value",
+    unknown_value=-1,
+    encoded_missing_value=-2,
 )
 num_tree_processor = SimpleImputer(strategy="mean", add_indicator=True)
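
The `encoded_missing_value=-2` added in the second hunk gives missing values their own integer code, distinct from the `unknown_value=-1` used for unseen categories (available from scikit-learn 1.1). A minimal sketch on toy data:

    import numpy as np
    from sklearn.preprocessing import OrdinalEncoder

    enc = OrdinalEncoder(
        handle_unknown="use_encoded_value",
        unknown_value=-1,
        encoded_missing_value=-2,
    )
    X = np.array([["a"], ["b"], [np.nan]], dtype=object)
    print(enc.fit_transform(X).ravel())  # [ 0.  1. -2.]  (NaN -> -2)
    print(enc.transform([["c"]]).ravel())  # [-1.]  (unseen category -> -1)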

examples/gaussian_process/plot_gpr_co2.py (+1 -1)

@@ -36,7 +36,7 @@
 # in OpenML.
 from sklearn.datasets import fetch_openml

-co2 = fetch_openml(data_id=41187, as_frame=True)
+co2 = fetch_openml(data_id=41187, as_frame=True, parser="pandas")
 co2.frame.head()

 # %%

examples/inspection/plot_linear_model_coefficient_interpretation.py (+1 -1)

@@ -46,7 +46,7 @@

 from sklearn.datasets import fetch_openml

-survey = fetch_openml(data_id=534, as_frame=True)
+survey = fetch_openml(data_id=534, as_frame=True, parser="pandas")

 # %%
 # Then, we identify features `X` and targets `y`: the column WAGE is our

examples/inspection/plot_permutation_importance.py (+3 -1)

@@ -43,7 +43,9 @@
 from sklearn.datasets import fetch_openml
 from sklearn.model_selection import train_test_split

-X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
+X, y = fetch_openml(
+    "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas"
+)
 rng = np.random.RandomState(seed=42)
 X["random_cat"] = rng.randint(3, size=X.shape[0])
 X["random_num"] = rng.randn(X.shape[0])

examples/linear_model/plot_poisson_regression_non_normal_loss.py (+1 -1)

@@ -56,7 +56,7 @@
 from sklearn.datasets import fetch_openml


-df = fetch_openml(data_id=41214, as_frame=True).frame
+df = fetch_openml(data_id=41214, as_frame=True, parser="pandas").frame
 df

 # %%

examples/linear_model/plot_sgd_early_stopping.py (+1 -1)

@@ -59,7 +59,7 @@
 def load_mnist(n_samples=None, class_0="0", class_1="8"):
     """Load MNIST, select two classes, shuffle and return only n_samples."""
     # Load data from http://openml.org/d/554
-    mnist = fetch_openml("mnist_784", version=1, as_frame=False)
+    mnist = fetch_openml("mnist_784", version=1, as_frame=False, parser="pandas")

     # take only two classes for binary classification
     mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)

0 commit comments