From 51a55c2ebac031c9d0e1a6aaa77326eaf0f17608 Mon Sep 17 00:00:00 2001 From: Johann Faouzi Date: Wed, 26 May 2021 13:15:03 +0200 Subject: [PATCH] Replace scikit-learn mixin classes with pyts mixin classes (#100) --- pyts/approximation/dft.py | 5 +- pyts/approximation/mcb.py | 5 +- pyts/approximation/paa.py | 6 +- pyts/approximation/sax.py | 6 +- pyts/approximation/sfa.py | 5 +- pyts/bag_of_words/bow.py | 3 +- pyts/base.py | 128 ++++++++++++++++++ pyts/classification/bossvs.py | 5 +- pyts/classification/knn.py | 5 +- pyts/classification/learning_shapelets.py | 11 +- pyts/classification/saxvsm.py | 5 +- pyts/classification/time_series_forest.py | 9 +- pyts/classification/tsbf.py | 12 +- pyts/decomposition/ssa.py | 5 +- pyts/image/gaf.py | 5 +- pyts/image/mtf.py | 5 +- pyts/image/recurrence.py | 5 +- .../classification/multivariate.py | 5 +- pyts/multivariate/image/joint_rp.py | 5 +- .../transformation/multivariate.py | 5 +- .../transformation/weasel_muse.py | 5 +- pyts/preprocessing/discretizer.py | 5 +- pyts/preprocessing/imputer.py | 5 +- pyts/preprocessing/scaler.py | 3 +- pyts/preprocessing/transformer.py | 7 +- pyts/tests/test_base.py | 65 +++++++++ pyts/transformation/bag_of_patterns.py | 5 +- pyts/transformation/boss.py | 5 +- pyts/transformation/rocket.py | 11 +- pyts/transformation/shapelet_transform.py | 5 +- pyts/transformation/weasel.py | 5 +- 31 files changed, 292 insertions(+), 69 deletions(-) create mode 100644 pyts/base.py create mode 100644 pyts/tests/test_base.py diff --git a/pyts/approximation/dft.py b/pyts/approximation/dft.py index 449c41f..8197604 100644 --- a/pyts/approximation/dft.py +++ b/pyts/approximation/dft.py @@ -4,15 +4,16 @@ # License: BSD-3-Clause import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.feature_selection import f_classif from sklearn.utils.validation import check_array, check_is_fitted, check_X_y from math import ceil from warnings import warn +from ..base import UnivariateTransformerMixin from ..preprocessing import StandardScaler -class DiscreteFourierTransform(BaseEstimator, TransformerMixin): +class DiscreteFourierTransform(BaseEstimator, UnivariateTransformerMixin): """Discrete Fourier Transform. Parameters diff --git a/pyts/approximation/mcb.py b/pyts/approximation/mcb.py index d08468e..64aea10 100644 --- a/pyts/approximation/mcb.py +++ b/pyts/approximation/mcb.py @@ -6,7 +6,8 @@ import numpy as np from numba import njit, prange from scipy.stats import norm -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from ..base import UnivariateTransformerMixin from sklearn.tree import DecisionTreeClassifier from sklearn.utils.validation import check_array, check_is_fitted, check_X_y from sklearn.utils.multiclass import check_classification_targets @@ -47,7 +48,7 @@ def _digitize(X, bins): return X_binned.astype('int64') -class MultipleCoefficientBinning(BaseEstimator, TransformerMixin): +class MultipleCoefficientBinning(BaseEstimator, UnivariateTransformerMixin): """Bin continuous data into intervals column-wise. 
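As a quick sanity check that these approximation transformers still expose fit_transform after the mixin swap, a minimal sketch (toy data; the n_coefs value is chosen only for illustration):

    import numpy as np
    from pyts.approximation import DiscreteFourierTransform

    # Toy univariate input: (n_samples, n_timestamps), the shape expected
    # by UnivariateTransformerMixin.
    rng = np.random.RandomState(42)
    X = rng.randn(20, 48)

    # fit_transform now comes from pyts' UnivariateTransformerMixin instead
    # of sklearn's TransformerMixin; the call itself is unchanged.
    dft = DiscreteFourierTransform(n_coefs=8)
    X_dft = dft.fit_transform(X)
    print(X_dft.shape)  # expected: (20, 8)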
Parameters diff --git a/pyts/approximation/paa.py b/pyts/approximation/paa.py index d5b122f..f0340a7 100644 --- a/pyts/approximation/paa.py +++ b/pyts/approximation/paa.py @@ -6,8 +6,9 @@ import numpy as np from math import ceil from numba import njit, prange -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin from ..utils import segmentation @@ -20,7 +21,8 @@ def _paa(X, n_samples, n_timestamps, start, end, n_timestamps_new): return X_paa -class PiecewiseAggregateApproximation(BaseEstimator, TransformerMixin): +class PiecewiseAggregateApproximation(BaseEstimator, + UnivariateTransformerMixin): """Piecewise Aggregate Approximation. Parameters diff --git a/pyts/approximation/sax.py b/pyts/approximation/sax.py index 92c3b49..18ae229 100644 --- a/pyts/approximation/sax.py +++ b/pyts/approximation/sax.py @@ -4,12 +4,14 @@ # License: BSD-3-Clause import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin from ..preprocessing import KBinsDiscretizer -class SymbolicAggregateApproximation(BaseEstimator, TransformerMixin): +class SymbolicAggregateApproximation(BaseEstimator, + UnivariateTransformerMixin): """Symbolic Aggregate approXimation. Parameters diff --git a/pyts/approximation/sfa.py b/pyts/approximation/sfa.py index 360a358..4f1b32c 100644 --- a/pyts/approximation/sfa.py +++ b/pyts/approximation/sfa.py @@ -3,14 +3,15 @@ # Author: Johann Faouzi # License: BSD-3-Clause -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.pipeline import Pipeline from sklearn.utils.validation import check_is_fitted from .dft import DiscreteFourierTransform from .mcb import MultipleCoefficientBinning +from ..base import UnivariateTransformerMixin -class SymbolicFourierApproximation(BaseEstimator, TransformerMixin): +class SymbolicFourierApproximation(BaseEstimator, UnivariateTransformerMixin): """Symbolic Fourier Approximation. Parameters diff --git a/pyts/bag_of_words/bow.py b/pyts/bag_of_words/bow.py index eaee905..6085f75 100644 --- a/pyts/bag_of_words/bow.py +++ b/pyts/bag_of_words/bow.py @@ -11,11 +11,12 @@ import warnings from ..approximation import ( PiecewiseAggregateApproximation, SymbolicAggregateApproximation) +from ..base import UnivariateTransformerMixin from ..preprocessing import StandardScaler from ..utils.utils import _windowed_view -class WordExtractor(BaseEstimator, TransformerMixin): +class WordExtractor(BaseEstimator, UnivariateTransformerMixin): r"""Transform discretized time series into sequences of words. Parameters diff --git a/pyts/base.py b/pyts/base.py new file mode 100644 index 0000000..d2a17bd --- /dev/null +++ b/pyts/base.py @@ -0,0 +1,128 @@ +"""Base classes for all estimators.""" + +# Author: Johann Faouzi +# License: BSD-3-Clause + +from sklearn.metrics import accuracy_score + + +class UnivariateTransformerMixin: + """Mixin class for all univariate transformers in pyts.""" + + def fit_transform(self, X, y=None, **fit_params): + """Fit to data, then transform it. + + Fits transformer to `X` and `y` with optional parameters `fit_params` + and returns a transformed version of `X`. + + Parameters + ---------- + X : array-like, shape = (n_samples, n_timestamps) + Univariate time series. 
+ + y : None or array-like, shape = (n_samples,) (default = None) + Target values (None for unsupervised transformations). + + **fit_params : dict + Additional fit parameters. + + Returns + ------- + X_new : array + Transformed array. + + """ # noqa: E501 + if y is None: + # fit method of arity 1 (unsupervised transformation) + return self.fit(X, **fit_params).transform(X) + else: + # fit method of arity 2 (supervised transformation) + return self.fit(X, y, **fit_params).transform(X) + + +class MultivariateTransformerMixin: + """Mixin class for all multivariate transformers in pyts.""" + + def fit_transform(self, X, y=None, **fit_params): + """Fit to data, then transform it. + + Fits transformer to `X` and `y` with optional parameters `fit_params` + and returns a transformed version of `X`. + + Parameters + ---------- + X : array-like, shape = (n_samples, n_features, n_timestamps) + Multivariate time series. + + y : None or array-like, shape = (n_samples,) (default = None) + Target values (None for unsupervised transformations). + + **fit_params : dict + Additional fit parameters. + + Returns + ------- + X_new : array + Transformed array. + + """ # noqa: E501 + if y is None: + # fit method of arity 1 (unsupervised transformation) + return self.fit(X, **fit_params).transform(X) + else: + # fit method of arity 2 (supervised transformation) + return self.fit(X, y, **fit_params).transform(X) + + +class UnivariateClassifierMixin: + """Mixin class for all univariate classifiers in pyts.""" + + def score(self, X, y, sample_weight=None): + """ + Return the mean accuracy on the given test data and labels. + + Parameters + ---------- + X : array-like, shape = (n_samples, n_timestamps) + Univariate time series. + + y : array-like, shape = (n_samples,) + True labels for `X`. + + sample_weight : None or array-like, shape = (n_samples,) (default = None) + Sample weights. + + Returns + ------- + score : float + Mean accuracy of ``self.predict(X)`` with regards to `y`. + + """ # noqa: E501 + return accuracy_score(y, self.predict(X), sample_weight=sample_weight) + + +class MultivariateClassifierMixin: + """Mixin class for all multivariate classifiers in pyts.""" + + def score(self, X, y, sample_weight=None): + """ + Return the mean accuracy on the given test data and labels. + + Parameters + ---------- + X : array-like, shape = (n_samples, n_features, n_timestamps) + Multivariate time series. + + y : array-like, shape = (n_samples,) + True labels for `X`. + + sample_weight : None or array-like, shape = (n_samples,) (default = None) + Sample weights. + + Returns + ------- + score : float + Mean accuracy of ``self.predict(X)`` with regards to `y`. 
+ + """ # noqa: E501 + return accuracy_score(y, self.predict(X), sample_weight=sample_weight) diff --git a/pyts/classification/bossvs.py b/pyts/classification/bossvs.py index bcb807a..3ec796f 100644 --- a/pyts/classification/bossvs.py +++ b/pyts/classification/bossvs.py @@ -7,15 +7,16 @@ from math import ceil from sklearn.utils.validation import check_array, check_X_y, check_is_fitted from sklearn.utils.multiclass import check_classification_targets -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator from sklearn.metrics.pairwise import cosine_similarity from sklearn.preprocessing import LabelEncoder from sklearn.feature_extraction.text import TfidfVectorizer from ..approximation import SymbolicFourierApproximation +from ..base import UnivariateClassifierMixin from ..utils.utils import _windowed_view -class BOSSVS(BaseEstimator, ClassifierMixin): +class BOSSVS(BaseEstimator, UnivariateClassifierMixin): """Bag-of-SFA Symbols in Vector Space. Each time series is transformed into an histogram using the diff --git a/pyts/classification/knn.py b/pyts/classification/knn.py index dcb059a..8547861 100644 --- a/pyts/classification/knn.py +++ b/pyts/classification/knn.py @@ -3,16 +3,17 @@ # Author: Johann Faouzi # License: BSD-3-Clause -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator from sklearn.neighbors import KNeighborsClassifier as SklearnKNN from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_X_y, check_is_fitted +from ..base import UnivariateClassifierMixin from ..metrics import boss, dtw, sakoe_chiba_band, itakura_parallelogram from ..metrics.dtw import (_dtw_classic, _dtw_region, _dtw_fast, _dtw_multiscale) -class KNeighborsClassifier(BaseEstimator, ClassifierMixin): +class KNeighborsClassifier(BaseEstimator, UnivariateClassifierMixin): """k-nearest neighbors classifier. Parameters diff --git a/pyts/classification/learning_shapelets.py b/pyts/classification/learning_shapelets.py index f648108..d3aa9d4 100644 --- a/pyts/classification/learning_shapelets.py +++ b/pyts/classification/learning_shapelets.py @@ -5,11 +5,9 @@ from itertools import chain from math import ceil - from numba import njit, prange import numpy as np - -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator from sklearn.cluster import KMeans from sklearn.exceptions import ConvergenceWarning from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier @@ -19,10 +17,9 @@ from sklearn.utils.validation import ( check_is_fitted, check_random_state, check_X_y, _check_sample_weight) from sklearn.utils.multiclass import check_classification_targets - import warnings - from ..utils.utils import _windowed_view +from ..base import UnivariateClassifierMixin @njit(fastmath=True) @@ -305,7 +302,7 @@ def _grad_shapelets(X, y, n_classes, weights, shapelets, lengths, alpha, return gradients -class CrossEntropyLearningShapelets(BaseEstimator, ClassifierMixin): +class CrossEntropyLearningShapelets(BaseEstimator, UnivariateClassifierMixin): """Learning Shapelets algorithm with cross-entropy loss. Parameters @@ -811,7 +808,7 @@ def _check_params(self, X, y, y_ind, classes, sample_weight): return n_shapelets_per_size, min_shapelet_length, sample_weight, rng -class LearningShapelets(BaseEstimator, ClassifierMixin): +class LearningShapelets(BaseEstimator, UnivariateClassifierMixin): """Learning Shapelets algorithm. 
This estimator consists of two steps: computing the distances between the diff --git a/pyts/classification/saxvsm.py b/pyts/classification/saxvsm.py index 7f133c6..ae0a799 100644 --- a/pyts/classification/saxvsm.py +++ b/pyts/classification/saxvsm.py @@ -5,14 +5,15 @@ from sklearn.utils.validation import check_X_y, check_array, check_is_fitted from sklearn.utils.multiclass import check_classification_targets -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator from sklearn.metrics.pairwise import cosine_similarity from sklearn.preprocessing import LabelEncoder from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer from ..bag_of_words import BagOfWords +from ..base import UnivariateClassifierMixin -class SAXVSM(BaseEstimator, ClassifierMixin): +class SAXVSM(BaseEstimator, UnivariateClassifierMixin): """Classifier based on SAX-VSM representation and tf-idf statistics. Time series are first transformed into bag of words using Symbolic diff --git a/pyts/classification/time_series_forest.py b/pyts/classification/time_series_forest.py index 1275331..ec53b23 100644 --- a/pyts/classification/time_series_forest.py +++ b/pyts/classification/time_series_forest.py @@ -6,11 +6,12 @@ from math import ceil from numba import njit import numpy as np -from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.ensemble import RandomForestClassifier from sklearn.pipeline import Pipeline from sklearn.utils.validation import ( check_array, check_is_fitted, check_random_state) +from ..base import UnivariateClassifierMixin, UnivariateTransformerMixin @njit() @@ -35,7 +36,7 @@ def extract_features(X, n_samples, n_windows, indices): return X_new -class WindowFeatureExtractor(BaseEstimator, TransformerMixin): +class WindowFeatureExtractor(BaseEstimator, UnivariateTransformerMixin): """Feature extractor over a window. This transformer extracts 3 features from each window: the mean, the @@ -174,7 +175,7 @@ def _check_params(self, X): return n_windows, min_window_size, rng -class TimeSeriesForest(BaseEstimator, ClassifierMixin): +class TimeSeriesForest(BaseEstimator, UnivariateClassifierMixin): """A random forest classifier for time series. A random forest is a meta estimator that fits a number of decision tree @@ -376,7 +377,7 @@ class TimeSeriesForest(BaseEstimator, ClassifierMixin): Forest for Classification and Feature Extraction". Information Sciences, 239, 142-153 (2013). - .. [2] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. + .. [2] Leo Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. 
""" # noqa: E501 def __init__(self, diff --git a/pyts/classification/tsbf.py b/pyts/classification/tsbf.py index 1a337fb..626603e 100644 --- a/pyts/classification/tsbf.py +++ b/pyts/classification/tsbf.py @@ -6,12 +6,12 @@ from math import ceil from numba import njit import numpy as np - -from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.ensemble import RandomForestClassifier from sklearn.utils.validation import ( check_array, check_is_fitted, check_random_state, check_X_y ) +from ..base import UnivariateClassifierMixin, UnivariateTransformerMixin @njit @@ -67,7 +67,7 @@ def histogram(X, bins, n_bins, n_samples, n_classes): return X_new -class IntervalFeatureExtractor(BaseEstimator, TransformerMixin): +class IntervalFeatureExtractor(BaseEstimator, UnivariateTransformerMixin): """Feature extractor over the intervals of a subsequence. This transformer extracts 3 features from each interval of each @@ -283,7 +283,7 @@ def _check_params(self, X, n_timestamps): n_subsequences, rng) -class TSBF(BaseEstimator, ClassifierMixin): +class TSBF(BaseEstimator, UnivariateClassifierMixin): """Time Series Bag-of-Features algorithm. Parameters @@ -488,11 +488,11 @@ class TSBF(BaseEstimator, ClassifierMixin): References ---------- - .. [R1] M.G. Baydogan, G. Runger and E. Tuv, "A Bag-of-Features Framework + .. [1] M.G. Baydogan, G. Runger and E. Tuv, "A Bag-of-Features Framework to Classify Time Series". IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(11), 2796-2802 (2013). - .. [R2] Leo Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. + .. [2] Leo Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. """ # noqa: E501 def __init__(self, diff --git a/pyts/decomposition/ssa.py b/pyts/decomposition/ssa.py index 4541647..82cb1f7 100644 --- a/pyts/decomposition/ssa.py +++ b/pyts/decomposition/ssa.py @@ -6,8 +6,9 @@ import numpy as np from math import ceil from numba import njit, prange -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin from ..utils.utils import _windowed_view @@ -36,7 +37,7 @@ def _diagonal_averaging(X, n_samples, n_timestamps, window_size, return X_new -class SingularSpectrumAnalysis(BaseEstimator, TransformerMixin): +class SingularSpectrumAnalysis(BaseEstimator, UnivariateTransformerMixin): """Singular Spectrum Analysis. Parameters diff --git a/pyts/image/gaf.py b/pyts/image/gaf.py index cb613d2..f07ab7e 100644 --- a/pyts/image/gaf.py +++ b/pyts/image/gaf.py @@ -6,9 +6,10 @@ import numpy as np from math import ceil from numba import njit, prange -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array from ..approximation import PiecewiseAggregateApproximation +from ..base import UnivariateTransformerMixin from ..preprocessing import MinMaxScaler @@ -28,7 +29,7 @@ def _gadf(X_cos, X_sin, n_samples, image_size): return X_gadf -class GramianAngularField(BaseEstimator, TransformerMixin): +class GramianAngularField(BaseEstimator, UnivariateTransformerMixin): """Gramian Angular Field. 
Parameters diff --git a/pyts/image/mtf.py b/pyts/image/mtf.py index 9adaad3..53e2103 100644 --- a/pyts/image/mtf.py +++ b/pyts/image/mtf.py @@ -6,8 +6,9 @@ import numpy as np from math import ceil from numba import njit, prange -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin from ..preprocessing import KBinsDiscretizer from ..utils import segmentation @@ -46,7 +47,7 @@ def _aggregated_markov_transition_field(X_mtf, n_samples, image_size, return X_amtf -class MarkovTransitionField(BaseEstimator, TransformerMixin): +class MarkovTransitionField(BaseEstimator, UnivariateTransformerMixin): """Markov Transition Field. Parameters diff --git a/pyts/image/recurrence.py b/pyts/image/recurrence.py index 5404c7c..ea08c6e 100644 --- a/pyts/image/recurrence.py +++ b/pyts/image/recurrence.py @@ -6,8 +6,9 @@ import numpy as np from math import ceil from numpy.lib.stride_tricks import as_strided -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin def _trajectories(X, dimension, time_delay): @@ -20,7 +21,7 @@ def _trajectories(X, dimension, time_delay): return as_strided(X, shape=shape_new, strides=strides_new) -class RecurrencePlot(BaseEstimator, TransformerMixin): # noqa: D207 +class RecurrencePlot(BaseEstimator, UnivariateTransformerMixin): # noqa: D207 r"""Recurrence Plot. A recurrence plot is an image representing the distances between diff --git a/pyts/multivariate/classification/multivariate.py b/pyts/multivariate/classification/multivariate.py index 6f15d13..a1aea13 100644 --- a/pyts/multivariate/classification/multivariate.py +++ b/pyts/multivariate/classification/multivariate.py @@ -5,9 +5,10 @@ import numpy as np from numba import njit, prange -from sklearn.base import BaseEstimator, ClassifierMixin, clone +from sklearn.base import BaseEstimator, clone from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_is_fitted +from ...base import MultivariateClassifierMixin from ..utils import check_3d_array @@ -20,7 +21,7 @@ def _hard_vote(y_pred, weights): return maj -class MultivariateClassifier(BaseEstimator, ClassifierMixin): +class MultivariateClassifier(BaseEstimator, MultivariateClassifierMixin): """Classifier for multivariate time series. It provides a convenient class to classify multivariate time series with diff --git a/pyts/multivariate/image/joint_rp.py b/pyts/multivariate/image/joint_rp.py index b0c762c..0b49872 100644 --- a/pyts/multivariate/image/joint_rp.py +++ b/pyts/multivariate/image/joint_rp.py @@ -4,12 +4,13 @@ # License: BSD-3-Clause import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from ...base import MultivariateTransformerMixin from ...image import RecurrencePlot from ..utils import check_3d_array -class JointRecurrencePlot(BaseEstimator, TransformerMixin): +class JointRecurrencePlot(BaseEstimator, MultivariateTransformerMixin): r"""Joint Recurrence Plot. 
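The split into univariate and multivariate mixins mirrors the expected input shapes documented in their docstrings; a rough sketch with toy arrays (output sizes depend on each estimator's defaults):

    import numpy as np
    from pyts.image import GramianAngularField               # univariate transformer
    from pyts.multivariate.image import JointRecurrencePlot  # multivariate transformer

    rng = np.random.RandomState(0)
    X_uni = rng.randn(10, 64)       # (n_samples, n_timestamps)
    X_multi = rng.randn(10, 3, 64)  # (n_samples, n_features, n_timestamps)

    # Both fit_transform methods now come from the pyts mixins; the only
    # difference between the two mixins is the documented shape of X.
    img_uni = GramianAngularField().fit_transform(X_uni)
    img_multi = JointRecurrencePlot().fit_transform(X_multi)
    print(img_uni.shape, img_multi.shape)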
A recurrence plot is an image representing the distances between diff --git a/pyts/multivariate/transformation/multivariate.py b/pyts/multivariate/transformation/multivariate.py index 2144117..5c4e9a1 100644 --- a/pyts/multivariate/transformation/multivariate.py +++ b/pyts/multivariate/transformation/multivariate.py @@ -5,12 +5,13 @@ import numpy as np from scipy.sparse import csr_matrix, hstack -from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.base import BaseEstimator, clone from sklearn.utils.validation import check_is_fitted +from ...base import MultivariateTransformerMixin from ..utils import check_3d_array -class MultivariateTransformer(BaseEstimator, TransformerMixin): +class MultivariateTransformer(BaseEstimator, MultivariateTransformerMixin): r"""Transformer for multivariate time series. It provides a convenient class to transform multivariate time series with diff --git a/pyts/multivariate/transformation/weasel_muse.py b/pyts/multivariate/transformation/weasel_muse.py index 2f058ef..6eb1909 100644 --- a/pyts/multivariate/transformation/weasel_muse.py +++ b/pyts/multivariate/transformation/weasel_muse.py @@ -6,14 +6,15 @@ import numpy as np from scipy.sparse import csr_matrix, hstack -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.base import clone from sklearn.utils.validation import check_is_fitted +from ...base import MultivariateTransformerMixin from ...transformation import WEASEL from ..utils import check_3d_array -class WEASELMUSE(BaseEstimator, TransformerMixin): +class WEASELMUSE(BaseEstimator, MultivariateTransformerMixin): r"""WEASEL+MUSE algorithm. Parameters diff --git a/pyts/preprocessing/discretizer.py b/pyts/preprocessing/discretizer.py index 4cf786d..9f2d118 100644 --- a/pyts/preprocessing/discretizer.py +++ b/pyts/preprocessing/discretizer.py @@ -6,9 +6,10 @@ import numpy as np from numba import njit, prange from scipy.stats import norm -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array from warnings import warn +from ..base import UnivariateTransformerMixin @njit() @@ -48,7 +49,7 @@ def _digitize(X, bins): return X_binned.astype('int64') -class KBinsDiscretizer(BaseEstimator, TransformerMixin): +class KBinsDiscretizer(BaseEstimator, UnivariateTransformerMixin): """Bin continuous data into intervals sample-wise. Parameters diff --git a/pyts/preprocessing/imputer.py b/pyts/preprocessing/imputer.py index 3d9b399..00b06af 100644 --- a/pyts/preprocessing/imputer.py +++ b/pyts/preprocessing/imputer.py @@ -5,12 +5,13 @@ import numpy as np from scipy.interpolate import interp1d -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.impute import MissingIndicator from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin -class InterpolationImputer(BaseEstimator, TransformerMixin): +class InterpolationImputer(BaseEstimator, UnivariateTransformerMixin): """Impute missing values using interpolation. 
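A concrete example for the preprocessing side, assuming the default linear strategy of InterpolationImputer (values here are illustrative):

    import numpy as np
    from pyts.preprocessing import InterpolationImputer

    X = np.array([[1.0, np.nan, 3.0, np.nan, 5.0],
                  [2.0, 4.0, np.nan, 8.0, 10.0]])

    # fit_transform comes from the new UnivariateTransformerMixin; with
    # y=None it reduces to fit(X).transform(X), exactly as before the patch.
    X_imputed = InterpolationImputer().fit_transform(X)
    print(X_imputed)
    # expected with linear interpolation:
    # [[ 1.  2.  3.  4.  5.]
    #  [ 2.  4.  6.  8. 10.]]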
Parameters diff --git a/pyts/preprocessing/scaler.py b/pyts/preprocessing/scaler.py index fbdefb5..2c8b1c0 100644 --- a/pyts/preprocessing/scaler.py +++ b/pyts/preprocessing/scaler.py @@ -9,9 +9,10 @@ from sklearn.preprocessing import MaxAbsScaler as SklearnMaxAbsScaler from sklearn.preprocessing import RobustScaler as SklearnRobustScaler from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin -class StandardScaler(BaseEstimator, TransformerMixin): +class StandardScaler(BaseEstimator, UnivariateTransformerMixin): """Standardize time series by removing mean and scaling to unit variance. Parameters diff --git a/pyts/preprocessing/transformer.py b/pyts/preprocessing/transformer.py index 27eb114..699b05d 100644 --- a/pyts/preprocessing/transformer.py +++ b/pyts/preprocessing/transformer.py @@ -3,14 +3,15 @@ # Author: Johann Faouzi # License: BSD-3-Clause -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.preprocessing import PowerTransformer as SklearnPowerTransformer from sklearn.preprocessing import (QuantileTransformer as SklearnQuantileTransformer) from sklearn.utils.validation import check_array +from ..base import UnivariateTransformerMixin -class PowerTransformer(BaseEstimator, TransformerMixin): +class PowerTransformer(BaseEstimator, UnivariateTransformerMixin): """Apply a power transform sample-wise to make data more Gaussian-like. Power transforms are a family of parametric, monotonic transformations @@ -109,7 +110,7 @@ def transform(self, X): return X_new -class QuantileTransformer(BaseEstimator, TransformerMixin): +class QuantileTransformer(BaseEstimator, UnivariateTransformerMixin): """Transform samples using quantiles information. This method transforms the samples to follow a uniform or a normal diff --git a/pyts/tests/test_base.py b/pyts/tests/test_base.py new file mode 100644 index 0000000..9387b47 --- /dev/null +++ b/pyts/tests/test_base.py @@ -0,0 +1,65 @@ +"""Testing for base classes.""" + +# Author: Johann Faouzi +# License: BSD-3-Clause + +import numpy as np +import pytest +from sklearn.base import clone + +from pyts.classification import SAXVSM +from pyts.datasets import load_gunpoint, load_basic_motions +from pyts.multivariate.image import JointRecurrencePlot +from pyts.multivariate.classification import MultivariateClassifier +from pyts.approximation import SymbolicFourierApproximation + + +X_uni, _, y_uni, _ = load_gunpoint(return_X_y=True) +X_multi, _, y_multi, _ = load_basic_motions(return_X_y=True) + + +@pytest.mark.parametrize( + 'estimator, X, y', + [(SymbolicFourierApproximation(n_bins=2), X_uni, None), + (SymbolicFourierApproximation(n_bins=2, strategy='entropy'), + X_uni, y_uni)] +) +def test_univariate_transformer_mixin(estimator, X, y): + sfa_1 = clone(estimator) + sfa_2 = clone(estimator) + np.testing.assert_array_equal( + sfa_1.fit_transform(X, y), sfa_2.fit(X, y).transform(X) + ) + + +@pytest.mark.parametrize( + 'estimator, X, y', + [(JointRecurrencePlot(), X_multi, None), + (JointRecurrencePlot(), X_multi, y_multi)] +) +def test_multivariate_transformer_mixin(estimator, X, y): + jrp_1 = clone(estimator) + jrp_2 = clone(estimator) + np.testing.assert_allclose( + jrp_1.fit_transform(X, y), jrp_2.fit(X, y).transform(X) + ) + + +@pytest.mark.parametrize( + 'sample_weight', + [None, np.ones_like(y_uni), np.random.uniform(size=y_uni.size)] +) +def test_univariate_classifier_mixin(sample_weight): + clf = SAXVSM().fit(X_uni, y_uni) + assert 
isinstance(clf.score(X_uni, y_uni, sample_weight), + (float, np.floating)) + + +@pytest.mark.parametrize( + 'sample_weight', + [None, np.ones(y_multi.size), np.random.uniform(size=y_multi.size)] +) +def test_multivariate_classifier_mixin(sample_weight): + clf = MultivariateClassifier(SAXVSM()).fit(X_multi, y_multi) + assert isinstance(clf.score(X_multi, y_multi, sample_weight), + (float, np.floating)) diff --git a/pyts/transformation/bag_of_patterns.py b/pyts/transformation/bag_of_patterns.py index 6e15f36..0042988 100644 --- a/pyts/transformation/bag_of_patterns.py +++ b/pyts/transformation/bag_of_patterns.py @@ -4,13 +4,14 @@ # License: BSD-3-Clause from scipy.sparse import csr_matrix -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.feature_extraction.text import CountVectorizer from sklearn.utils.validation import check_array, check_is_fitted from ..bag_of_words import BagOfWords +from ..base import UnivariateTransformerMixin -class BagOfPatterns(BaseEstimator, TransformerMixin): +class BagOfPatterns(BaseEstimator, UnivariateTransformerMixin): """Bag-of-patterns representation for time series. This algorithm uses a sliding window to extract subsequences from the diff --git a/pyts/transformation/boss.py b/pyts/transformation/boss.py index 50d5fc2..0e36002 100644 --- a/pyts/transformation/boss.py +++ b/pyts/transformation/boss.py @@ -6,15 +6,16 @@ import numpy as np from math import ceil from scipy.sparse import csr_matrix -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.feature_extraction.text import CountVectorizer from sklearn.utils.validation import check_array, check_is_fitted from sklearn.utils.multiclass import check_classification_targets from ..approximation import SymbolicFourierApproximation +from ..base import UnivariateTransformerMixin from ..utils.utils import _windowed_view -class BOSS(BaseEstimator, TransformerMixin): +class BOSS(BaseEstimator, UnivariateTransformerMixin): """Bag of Symbolic Fourier Approximation Symbols. For each time series, subseries are extracted using a slidind window. diff --git a/pyts/transformation/rocket.py b/pyts/transformation/rocket.py index 7bd49a9..ee36ab1 100644 --- a/pyts/transformation/rocket.py +++ b/pyts/transformation/rocket.py @@ -1,9 +1,14 @@ +"""Code for RandOm Convolutional KErnel Transformation.""" + +# Author: Johann Faouzi +# License: BSD-3-Clause + from numba import njit, prange import numpy as np - -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.utils.validation import ( check_array, check_is_fitted, check_random_state) +from ..base import UnivariateTransformerMixin @njit() @@ -173,7 +178,7 @@ def apply_all_kernels(X, weights, lengths, biases, dilations, paddings): return X_new -class ROCKET(BaseEstimator, TransformerMixin): +class ROCKET(BaseEstimator, UnivariateTransformerMixin): """RandOm Convolutional KErnel Transformation. 
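The new tests above exercise score through real classifiers; the same pattern, roughly, outside pytest (dataset loader and estimator taken from the test file):

    import numpy as np
    from pyts.classification import SAXVSM
    from pyts.datasets import load_gunpoint

    X_train, X_test, y_train, y_test = load_gunpoint(return_X_y=True)
    clf = SAXVSM().fit(X_train, y_train)

    # score is inherited from UnivariateClassifierMixin and wraps
    # sklearn.metrics.accuracy_score, including the optional sample_weight.
    print(clf.score(X_test, y_test))
    weights = np.ones_like(y_test, dtype=float)
    print(clf.score(X_test, y_test, sample_weight=weights))  # same value here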
This algorithm randomly generates a great variety of convolutional kernels diff --git a/pyts/transformation/shapelet_transform.py b/pyts/transformation/shapelet_transform.py index 1468dfd..e583776 100644 --- a/pyts/transformation/shapelet_transform.py +++ b/pyts/transformation/shapelet_transform.py @@ -6,10 +6,11 @@ from numba.typed import List import numpy as np from numpy.lib.stride_tricks import as_strided -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.feature_selection import f_classif, mutual_info_classif from sklearn.utils.validation import (check_array, check_is_fitted, check_random_state, check_X_y) +from ..base import UnivariateTransformerMixin from ..utils.utils import _windowed_view @@ -143,7 +144,7 @@ def _remove_similar_shapelets(scores, start_idx, end_idx): return np.array(kept_idx) -class ShapeletTransform(BaseEstimator, TransformerMixin): +class ShapeletTransform(BaseEstimator, UnivariateTransformerMixin): """Shapelet Transform Algorithm. The Shapelet Transform algorithm extracts the most discriminative diff --git a/pyts/transformation/weasel.py b/pyts/transformation/weasel.py index 9ba8189..d92838e 100644 --- a/pyts/transformation/weasel.py +++ b/pyts/transformation/weasel.py @@ -7,14 +7,15 @@ from scipy.sparse import coo_matrix, csr_matrix, hstack from sklearn.utils.validation import check_array, check_X_y, check_is_fitted from sklearn.utils.multiclass import check_classification_targets -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_selection import chi2 from ..approximation import SymbolicFourierApproximation +from ..base import UnivariateTransformerMixin from ..utils.utils import _windowed_view -class WEASEL(BaseEstimator, TransformerMixin): +class WEASEL(BaseEstimator, UnivariateTransformerMixin): """Word ExtrAction for time SEries cLassification. Parameters
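Since the pyts mixins keep the fit_transform/score contract, transformers such as WEASEL should remain usable inside an sklearn Pipeline, which relies on fit_transform (SymbolicFourierApproximation itself composes a Pipeline in sfa.py above). A sketch under that assumption; the logistic-regression step and the default WEASEL parameters are only illustrative:

    from pyts.datasets import load_gunpoint
    from pyts.transformation import WEASEL
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline

    X_train, X_test, y_train, y_test = load_gunpoint(return_X_y=True)

    # Pipeline calls fit_transform on WEASEL, now provided by
    # UnivariateTransformerMixin rather than sklearn's TransformerMixin.
    clf = Pipeline([
        ('weasel', WEASEL()),
        ('logreg', LogisticRegression(max_iter=1000)),
    ])
    clf.fit(X_train, y_train)
    print(clf.score(X_test, y_test))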