Skip to content

Commit

Permalink
update TimeSeriesKNN
Browse files Browse the repository at this point in the history
transpose numpy after check_X
remove algorirthm option
include capabilities tags
  • Loading branch information
TonyBagnall committed Feb 17, 2021
1 parent bb5e216 commit cef586e
Showing 1 changed file with 10 additions and 37 deletions.
47 changes: 10 additions & 37 deletions sktime/classification/distance_based/_time_series_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class KNeighborsTimeSeriesClassifier(_KNeighborsClassifier, BaseClassifier):

# Capabilities: data types this classifier can handle
capabilities = {
"multivariate": False,
"multivariate": True,
"unequal_length": False,
"missing_values": False,
}
Expand All @@ -115,28 +115,14 @@ def __init__(
self,
n_neighbors=1,
weights="uniform",
algorithm="brute",
metric="dtw",
metric_params=None,
**kwargs
):
if algorithm == "kd_tree":
raise ValueError(
"KNeighborsTimeSeriesClassifier cannot work with kd_tree since kd_tree "
"cannot be used with a callable distance metric and we do not support "
"precalculated distances as yet."
)
if algorithm == "ball_tree":
raise ValueError(
"KNeighborsTimeSeriesClassifier cannot work with ball_tree since "
"ball_tree has a list of hard coded distances it can use, and cannot "
"work with 3-D arrays"
)

self._cv_for_params = False
if metric == "euclidean": # Euclidean will default to the base class distance
metric = euclidean_distance
if metric == "dtw":
elif metric == "dtw":
metric = dtw_distance
elif metric == "dtwcv": # special case to force loocv grid search
# cv in training
Expand Down Expand Up @@ -181,7 +167,7 @@ def __init__(

super(KNeighborsTimeSeriesClassifier, self).__init__(
n_neighbors=n_neighbors,
algorithm=algorithm,
algorithm="brute",
metric=metric,
metric_params=metric_params,
**kwargs
Expand All @@ -203,7 +189,10 @@ def fit(self, X, y):
Target values of shape = [n_samples]
"""
X, y = check_X_y(X, y, enforce_univariate=False, coerce_to_numpy=True)
X, y = check_X_y(X, y, enforce_univariate=not self.capabilities["multivariate"], coerce_to_numpy=True)
# Transpose to work correctly with distance functions
X = X.transpose((0, 2, 1))

y = np.asarray(y)
check_classification_targets(y)
# if internal cv is desired, the relevant flag forces a grid search
Expand Down Expand Up @@ -300,7 +289,9 @@ def kneighbors(self, X, n_neighbors=None, return_distance=True):
Indices of the nearest points in the population matrix.
"""
self.check_is_fitted()
X = check_X(X, enforce_univariate=False, coerce_to_numpy=True)
X = check_X(X, enforce_univariate=not self.capabilities["multivariate"], coerce_to_numpy=True)
# Transpose to work correctly with distance functions
X = X.transpose((0, 2, 1))

if n_neighbors is None:
n_neighbors = self.n_neighbors
Expand Down Expand Up @@ -356,24 +347,6 @@ def kneighbors(self, X, n_neighbors=None, return_distance=True):
n_jobs=n_jobs,
**kwds
)

elif self._fit_method in ["ball_tree", "kd_tree"]:
if issparse(X):
raise ValueError(
"%s does not work with sparse matrices. Densify the data, "
"or set algorithm='brute'" % self._fit_method
)
if LooseVersion(joblib_version) < LooseVersion("0.12"):
# Deal with change of API in joblib
delayed_query = delayed(self._tree.query, check_pickle=False)
parallel_kwargs = {"backend": "threading"}
else:
delayed_query = delayed(self._tree.query)
parallel_kwargs = {"prefer": "threads"}
result = Parallel(n_jobs, **parallel_kwargs)(
delayed_query(X[s], n_neighbors, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
else:
raise ValueError("internal: _fit_method not recognized")

Expand Down

0 comments on commit cef586e

Please sign in to comment.