forked from johannfaouzi/pyts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace scikit-learn mixin classes with pyts mixin classes (johannfao…
- Loading branch information
1 parent
a5a9407
commit 51a55c2
Showing
31 changed files
with
292 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,14 +3,15 @@ | |
# Author: Johann Faouzi <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.base import BaseEstimator | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.utils.validation import check_is_fitted | ||
from .dft import DiscreteFourierTransform | ||
from .mcb import MultipleCoefficientBinning | ||
from ..base import UnivariateTransformerMixin | ||
|
||
|
||
class SymbolicFourierApproximation(BaseEstimator, TransformerMixin): | ||
class SymbolicFourierApproximation(BaseEstimator, UnivariateTransformerMixin): | ||
"""Symbolic Fourier Approximation. | ||
Parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
"""Base classes for all estimators.""" | ||
|
||
# Author: Johann Faouzi <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
from sklearn.metrics import accuracy_score | ||
|
||
|
||
class UnivariateTransformerMixin: | ||
"""Mixin class for all univariate transformers in pyts.""" | ||
|
||
def fit_transform(self, X, y=None, **fit_params): | ||
"""Fit to data, then transform it. | ||
Fits transformer to `X` and `y` with optional parameters `fit_params` | ||
and returns a transformed version of `X`. | ||
Parameters | ||
---------- | ||
X : array-like, shape = (n_samples, n_timestamps) | ||
Univariate time series. | ||
y : None or array-like, shape = (n_samples,) (default = None) | ||
Target values (None for unsupervised transformations). | ||
**fit_params : dict | ||
Additional fit parameters. | ||
Returns | ||
------- | ||
X_new : array | ||
Transformed array. | ||
""" # noqa: E501 | ||
if y is None: | ||
# fit method of arity 1 (unsupervised transformation) | ||
return self.fit(X, **fit_params).transform(X) | ||
else: | ||
# fit method of arity 2 (supervised transformation) | ||
return self.fit(X, y, **fit_params).transform(X) | ||
|
||
|
||
class MultivariateTransformerMixin: | ||
"""Mixin class for all multivariate transformers in pyts.""" | ||
|
||
def fit_transform(self, X, y=None, **fit_params): | ||
"""Fit to data, then transform it. | ||
Fits transformer to `X` and `y` with optional parameters `fit_params` | ||
and returns a transformed version of `X`. | ||
Parameters | ||
---------- | ||
X : array-like, shape = (n_samples, n_features, n_timestamps) | ||
Multivariate time series. | ||
y : None or array-like, shape = (n_samples,) (default = None) | ||
Target values (None for unsupervised transformations). | ||
**fit_params : dict | ||
Additional fit parameters. | ||
Returns | ||
------- | ||
X_new : array | ||
Transformed array. | ||
""" # noqa: E501 | ||
if y is None: | ||
# fit method of arity 1 (unsupervised transformation) | ||
return self.fit(X, **fit_params).transform(X) | ||
else: | ||
# fit method of arity 2 (supervised transformation) | ||
return self.fit(X, y, **fit_params).transform(X) | ||
|
||
|
||
class UnivariateClassifierMixin: | ||
"""Mixin class for all univariate classifiers in pyts.""" | ||
|
||
def score(self, X, y, sample_weight=None): | ||
""" | ||
Return the mean accuracy on the given test data and labels. | ||
Parameters | ||
---------- | ||
X : array-like, shape = (n_samples, n_timestamps) | ||
Univariate time series. | ||
y : array-like, shape = (n_samples,) | ||
True labels for `X`. | ||
sample_weight : None or array-like, shape = (n_samples,) (default = None) | ||
Sample weights. | ||
Returns | ||
------- | ||
score : float | ||
Mean accuracy of ``self.predict(X)`` with regards to `y`. | ||
""" # noqa: E501 | ||
return accuracy_score(y, self.predict(X), sample_weight=sample_weight) | ||
|
||
|
||
class MultivariateClassifierMixin: | ||
"""Mixin class for all multivariate classifiers in pyts.""" | ||
|
||
def score(self, X, y, sample_weight=None): | ||
""" | ||
Return the mean accuracy on the given test data and labels. | ||
Parameters | ||
---------- | ||
X : array-like, shape = (n_samples, n_features, n_timestamps) | ||
Multivariate time series. | ||
y : array-like, shape = (n_samples,) | ||
True labels for `X`. | ||
sample_weight : None or array-like, shape = (n_samples,) (default = None) | ||
Sample weights. | ||
Returns | ||
------- | ||
score : float | ||
Mean accuracy of ``self.predict(X)`` with regards to `y`. | ||
""" # noqa: E501 | ||
return accuracy_score(y, self.predict(X), sample_weight=sample_weight) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,16 +3,17 @@ | |
# Author: Johann Faouzi <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
from sklearn.base import BaseEstimator, ClassifierMixin | ||
from sklearn.base import BaseEstimator | ||
from sklearn.neighbors import KNeighborsClassifier as SklearnKNN | ||
from sklearn.preprocessing import LabelEncoder | ||
from sklearn.utils.validation import check_X_y, check_is_fitted | ||
from ..base import UnivariateClassifierMixin | ||
from ..metrics import boss, dtw, sakoe_chiba_band, itakura_parallelogram | ||
from ..metrics.dtw import (_dtw_classic, _dtw_region, _dtw_fast, | ||
_dtw_multiscale) | ||
|
||
|
||
class KNeighborsClassifier(BaseEstimator, ClassifierMixin): | ||
class KNeighborsClassifier(BaseEstimator, UnivariateClassifierMixin): | ||
"""k-nearest neighbors classifier. | ||
Parameters | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.