Skip to content

Commit

Permalink
ENH Add metric_params parameter to TSNE (scikit-learn#22685)
Browse files Browse the repository at this point in the history
Co-authored-by: jeannedionisi <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
  • Loading branch information
3 people authored Mar 5, 2022
1 parent 269bdb9 commit 723b707
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@ Changelog
in eigen_solvers `lobpcg` and `amg` to improve their numerical stability.
:pr:`21565` by :user:`Andrew Knyazev <lobpcg>`.

- |Enhancement| Added the `metric_params` parameter to the :class:`manifold.TSNE`
constructor to pass additional keyword arguments to the distance metric used in the optimization.
:pr:`21805` by :user:`Jeanne Dionisi <jeannedionisi>` and :pr:`22685` by
:user:`Meekail Zain <micky774>`.

:mod:`sklearn.metrics`
......................

Expand Down
11 changes: 10 additions & 1 deletion sklearn/manifold/_t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,11 @@ class TSNE(BaseEstimator):
the distance between them. The default is "euclidean" which is
interpreted as squared euclidean distance.
metric_params : dict, default=None
Additional keyword arguments for the metric function.
.. versionadded:: 1.1
init : {'random', 'pca'} or ndarray of shape (n_samples, n_components), \
default='random'
Initialization of embedding. Possible options are 'random', 'pca',
Expand Down Expand Up @@ -744,6 +749,7 @@ def __init__(
n_iter_without_progress=300,
min_grad_norm=1e-7,
metric="euclidean",
metric_params=None,
init="warn",
verbose=0,
random_state=None,
Expand All @@ -760,6 +766,7 @@ def __init__(
self.n_iter_without_progress = n_iter_without_progress
self.min_grad_norm = min_grad_norm
self.metric = metric
self.metric_params = metric_params
self.init = init
self.verbose = verbose
self.random_state = random_state
Expand Down Expand Up @@ -885,8 +892,9 @@ def _fit(self, X, skip_num_points=0):
# Also, Euclidean is slower for n_jobs>1, so don't set here
distances = pairwise_distances(X, metric=self.metric, squared=True)
else:
metric_params_ = self.metric_params or {}
distances = pairwise_distances(
X, metric=self.metric, n_jobs=self.n_jobs
X, metric=self.metric, n_jobs=self.n_jobs, **metric_params_
)

if np.any(distances < 0):
Expand Down Expand Up @@ -921,6 +929,7 @@ def _fit(self, X, skip_num_points=0):
n_jobs=self.n_jobs,
n_neighbors=n_neighbors,
metric=self.metric,
metric_params=self.metric_params,
)
t0 = time()
knn.fit(X)
Expand Down
31 changes: 31 additions & 0 deletions sklearn/manifold/tests/test_t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,37 @@ def test_tsne_n_jobs(method):
assert_allclose(X_tr_ref, X_tr)


# TODO: Remove filterwarnings in 1.2
@pytest.mark.filterwarnings("ignore:.*TSNE will change.*:FutureWarning")
def test_tsne_with_mahalanobis_distance():
    """Check that `metric_params` is honoured with the mahalanobis metric.

    Fitting with ``metric="mahalanobis"`` and no extra parameters must raise,
    while passing ``V`` through ``metric_params`` must give the same embedding
    as fitting on a precomputed mahalanobis distance matrix.
    """
    rng = check_random_state(0)
    X = rng.randn(300, 10)
    tsne_kwargs = {
        "perplexity": 40,
        "n_iter": 250,
        "learning_rate": "auto",
        "n_components": 3,
        "random_state": 0,
    }

    # No V / VI supplied: fitting is expected to fail with this message.
    err_msg = "Must provide either V or VI for Mahalanobis distance"
    estimator = TSNE(metric="mahalanobis", **tsne_kwargs)
    with pytest.raises(ValueError, match=err_msg):
        estimator.fit_transform(X)

    # Reference embedding built from an explicit pairwise distance matrix.
    condensed = pdist(X, metric="mahalanobis")
    distance_matrix = squareform(condensed, checks=True)
    reference_tsne = TSNE(metric="precomputed", **tsne_kwargs)
    embedding_reference = reference_tsne.fit_transform(distance_matrix)

    # Same metric, this time configured through `metric_params`.
    covariance = np.cov(X.T)
    mahalanobis_tsne = TSNE(
        metric="mahalanobis", metric_params={"V": covariance}, **tsne_kwargs
    )
    embedding = mahalanobis_tsne.fit_transform(X)
    assert_allclose(embedding, embedding_reference)


@pytest.mark.filterwarnings("ignore:The PCA initialization in TSNE will change")
# FIXME: remove in 1.3 after deprecation of `square_distances`
def test_tsne_deprecation_square_distances():
Expand Down

0 comments on commit 723b707

Please sign in to comment.