Skip to content

Commit 87bb024

Browse files
jjerphan and jeremiedbb authored
TST Ensure that sklearn/metrics/tests/test_pairwise_distances_reduction.py is seed insensitive (scikit-learn#22862)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent 142e388 commit 87bb024

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

sklearn/metrics/tests/test_pairwise_distances_reduction.py

+23 -19
Original file line number | Diff line number | Diff line change
@@ -31,11 +31,11 @@
3131
]
3232

3333

34-
def _get_metric_params_list(metric: str, n_features: int):
34+
def _get_metric_params_list(metric: str, n_features: int, seed: int = 1):
3535
"""Return list of dummy DistanceMetric kwargs for tests."""
3636

3737
# Distinguishing on cases not to compute unneeded datastructures.
38-
rng = np.random.RandomState(1)
38+
rng = np.random.RandomState(seed)
3939

4040
if metric == "minkowski":
4141
minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
@@ -217,23 +217,22 @@ def test_radius_neighborhood_factory_method_wrong_usages():
217217
)
218218

219219

220-
@pytest.mark.parametrize("seed", range(5))
221220
@pytest.mark.parametrize("n_samples", [100, 1000])
222221
@pytest.mark.parametrize("chunk_size", [50, 512, 1024])
223222
@pytest.mark.parametrize(
224223
"PairwiseDistancesReduction",
225224
[PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
226225
)
227226
def test_chunk_size_agnosticism(
227+
global_random_seed,
228228
PairwiseDistancesReduction,
229-
seed,
230229
n_samples,
231230
chunk_size,
232231
n_features=100,
233232
dtype=np.float64,
234233
):
235234
# Results should not depend on the chunk size
236-
rng = np.random.RandomState(seed)
235+
rng = np.random.RandomState(global_random_seed)
237236
spread = 100
238237
X = rng.rand(n_samples, n_features).astype(dtype) * spread
239238
Y = rng.rand(n_samples, n_features).astype(dtype) * spread
@@ -263,23 +262,22 @@ def test_chunk_size_agnosticism(
263262
ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices)
264263

265264

266-
@pytest.mark.parametrize("seed", range(5))
267265
@pytest.mark.parametrize("n_samples", [100, 1000])
268266
@pytest.mark.parametrize("chunk_size", [50, 512, 1024])
269267
@pytest.mark.parametrize(
270268
"PairwiseDistancesReduction",
271269
[PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
272270
)
273271
def test_n_threads_agnosticism(
272+
global_random_seed,
274273
PairwiseDistancesReduction,
275-
seed,
276274
n_samples,
277275
chunk_size,
278276
n_features=100,
279277
dtype=np.float64,
280278
):
281279
# Results should not depend on the number of threads
282-
rng = np.random.RandomState(seed)
280+
rng = np.random.RandomState(global_random_seed)
283281
spread = 100
284282
X = rng.rand(n_samples, n_features).astype(dtype) * spread
285283
Y = rng.rand(n_samples, n_features).astype(dtype) * spread
@@ -308,23 +306,22 @@ def test_n_threads_agnosticism(
308306

309307
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
310308
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
311-
@pytest.mark.parametrize("seed", range(5))
312309
@pytest.mark.parametrize("n_samples", [100, 1000])
313310
@pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics())
314311
@pytest.mark.parametrize(
315312
"PairwiseDistancesReduction",
316313
[PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
317314
)
318315
def test_strategies_consistency(
316+
global_random_seed,
319317
PairwiseDistancesReduction,
320318
metric,
321319
n_samples,
322-
seed,
323320
n_features=10,
324321
dtype=np.float64,
325322
):
326323

327-
rng = np.random.RandomState(seed)
324+
rng = np.random.RandomState(global_random_seed)
328325
spread = 100
329326
X = rng.rand(n_samples, n_features).astype(dtype) * spread
330327
Y = rng.rand(n_samples, n_features).astype(dtype) * spread
@@ -347,7 +344,9 @@ def test_strategies_consistency(
347344
parameter,
348345
metric=metric,
349346
# Taking the first
350-
metric_kwargs=_get_metric_params_list(metric, n_features)[0],
347+
metric_kwargs=_get_metric_params_list(
348+
metric, n_features, seed=global_random_seed
349+
)[0],
351350
# To be sure to use parallelization
352351
chunk_size=n_samples // 4,
353352
strategy="parallel_on_X",
@@ -360,7 +359,9 @@ def test_strategies_consistency(
360359
parameter,
361360
metric=metric,
362361
# Taking the first
363-
metric_kwargs=_get_metric_params_list(metric, n_features)[0],
362+
metric_kwargs=_get_metric_params_list(
363+
metric, n_features, seed=global_random_seed
364+
)[0],
364365
# To be sure to use parallelization
365366
chunk_size=n_samples // 4,
366367
strategy="parallel_on_Y",
@@ -384,6 +385,7 @@ def test_strategies_consistency(
384385
@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS)
385386
@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y"))
386387
def test_pairwise_distances_argkmin(
388+
global_random_seed,
387389
n_features,
388390
translation,
389391
metric,
@@ -392,7 +394,7 @@ def test_pairwise_distances_argkmin(
392394
k=10,
393395
dtype=np.float64,
394396
):
395-
rng = np.random.RandomState(0)
397+
rng = np.random.RandomState(global_random_seed)
396398
spread = 1000
397399
X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
398400
Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
@@ -443,20 +445,23 @@ def test_pairwise_distances_argkmin(
443445
@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS)
444446
@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y"))
445447
def test_pairwise_distances_radius_neighbors(
448+
global_random_seed,
446449
n_features,
447450
translation,
448451
metric,
449452
strategy,
450453
n_samples=100,
451454
dtype=np.float64,
452455
):
453-
rng = np.random.RandomState(0)
456+
rng = np.random.RandomState(global_random_seed)
454457
spread = 1000
455458
radius = spread * np.log(n_features)
456459
X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
457460
Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
458461

459-
metric_kwargs = _get_metric_params_list(metric, n_features)[0]
462+
metric_kwargs = _get_metric_params_list(
463+
metric, n_features, seed=global_random_seed
464+
)[0]
460465

461466
# Reference for argkmin results
462467
if metric == "euclidean":
@@ -500,18 +505,17 @@ def test_pairwise_distances_radius_neighbors(
500505
)
501506

502507

503-
@pytest.mark.parametrize("seed", range(10))
504508
@pytest.mark.parametrize("n_samples", [100, 1000])
505509
@pytest.mark.parametrize("n_features", [5, 10, 100])
506510
@pytest.mark.parametrize("num_threads", [1, 2, 8])
507511
def test_sqeuclidean_row_norms(
508-
seed,
512+
global_random_seed,
509513
n_samples,
510514
n_features,
511515
num_threads,
512516
dtype=np.float64,
513517
):
514-
rng = np.random.RandomState(seed)
518+
rng = np.random.RandomState(global_random_seed)
515519
spread = 100
516520
X = rng.rand(n_samples, n_features).astype(dtype) * spread
517521

0 commit comments

Comments
 (0)