PERF Implement PairwiseDistancesReduction backend for `KNeighbors.predict_proba` (scikit-learn#24076)

Signed-off-by: Julien Jerphanion <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
3 people authored Mar 14, 2023
1 parent 2c4e1d7 commit b6b6f63
Showing 12 changed files with 648 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -90,6 +90,7 @@ sklearn/metrics/_dist_metrics.pyx
sklearn/metrics/_dist_metrics.pxd
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd
sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx
sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx
sklearn/metrics/_pairwise_distances_reduction/_base.pxd
sklearn/metrics/_pairwise_distances_reduction/_base.pyx
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
5 changes: 5 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -292,6 +292,11 @@ Changelog
  dissimilarity is not a metric and cannot be supported by the BallTree.
  :pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| The performance of :meth:`neighbors.KNeighborsClassifier.predict`
  and of :meth:`neighbors.KNeighborsClassifier.predict_proba` has been improved
  when `n_neighbors` is large and `algorithm="brute"` with non-Euclidean metrics.
  :pr:`24076` by :user:`Meekail Zain <micky774>` and :user:`Julien Jerphanion <jjerphan>`.

:mod:`sklearn.neural_network`
.............................

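The enhancement entry above describes the user-facing effect of this commit. Below is a minimal usage sketch of the accelerated code path; the dataset sizes, the `manhattan` metric and the hyper-parameter values are illustrative choices, not taken from the pull request:

    # Illustrative only: exercises the brute-force, non-Euclidean path that the
    # new `ArgKminClassMode` backend accelerates.
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.neighbors import KNeighborsClassifier

    X, y = make_classification(n_samples=5_000, n_features=30, random_state=0)

    clf = KNeighborsClassifier(
        n_neighbors=100,      # large `n_neighbors` benefits the most
        algorithm="brute",    # the brute-force backend dispatches to PairwiseDistancesReduction
        metric="manhattan",   # a non-Euclidean metric
        weights="uniform",
    )
    clf.fit(X, y)
    probabilities = clf.predict_proba(X[:10])

Previously, `predict_proba` computed the neighbor indices with `kneighbors` and then derived class probabilities in a separate NumPy pass; the new reduction performs both within a single Cython pass.
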
1 change: 1 addition & 0 deletions setup.cfg
@@ -90,6 +90,7 @@ ignore =
    sklearn/metrics/_dist_metrics.pxd
    sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd
    sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx
    sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx
    sklearn/metrics/_pairwise_distances_reduction/_base.pxd
    sklearn/metrics/_pairwise_distances_reduction/_base.pyx
    sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
6 changes: 6 additions & 0 deletions setup.py
@@ -280,6 +280,12 @@ def check_package_status(package, min_version):
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_argkmin_classmode.pyx.tp"],
"language": "c++",
"include_np": True,
"extra_compile_args": ["-std=c++11"],
},
{
"sources": ["_radius_neighbors.pyx.tp", "_radius_neighbors.pxd.tp"],
"language": "c++",
2 changes: 2 additions & 0 deletions sklearn/metrics/_pairwise_distances_reduction/__init__.py
@@ -90,12 +90,14 @@
    BaseDistancesReductionDispatcher,
    ArgKmin,
    RadiusNeighbors,
    ArgKminClassMode,
    sqeuclidean_row_norms,
)

__all__ = [
"BaseDistancesReductionDispatcher",
"ArgKmin",
"RadiusNeighbors",
"ArgKminClassMode",
"sqeuclidean_row_norms",
]
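
`ArgKminClassMode` is the new dispatcher exported from the private `sklearn.metrics._pairwise_distances_reduction` module. A hedged sketch of driving it directly (internal API; the argument order follows the `compute` classmethod of the template added below, the public dispatcher's exact keyword names may differ, and the data is made up):

    # Internal/private API -- sketch only, based on the `compute` signature below.
    import numpy as np
    from sklearn.metrics._pairwise_distances_reduction import ArgKminClassMode

    rng = np.random.default_rng(0)
    X = rng.standard_normal((20, 5))                       # query rows
    Y = rng.standard_normal((100, 5))                      # training rows
    labels = rng.integers(0, 3, size=100).astype(np.intp)  # class of each row of Y

    probabilities = ArgKminClassMode.compute(
        X,
        Y,
        5,                   # k
        "uniform",           # weights
        labels,              # class membership of Y's rows
        np.unique(labels),   # unique labels
        metric="manhattan",
    )
    # `probabilities` has shape (20, 3): one row per query, one column per label.
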
@@ -178,7 +178,6 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
        self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0]
        self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0]

    @final
    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        ITYPE_t thread_num,
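Note on the hunk above: this commit drops the `@final` qualifier from `ArgKmin{{name_suffix}}._parallel_on_X_prange_iter_finalize` (hence the 7-to-6 line count in the hunk header) so that the new `ArgKminClassMode{{name_suffix}}` subclass defined in the `_argkmin_classmode.pyx.tp` template below can override it.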
@@ -0,0 +1,208 @@
{{py:

implementation_specific_values = [
    # Values are the following ones:
    #
    # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
    #
    # We also use the float64 dtype and C-type names as defined in
    # `sklearn.utils._typedefs` to maintain consistency.
    #
    ('64', 'DTYPE_t', 'DTYPE'),
    ('32', 'cnp.float32_t', 'np.float32')
]

}}

from cython cimport floating, integral
from cython.parallel cimport parallel, prange
from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

cimport numpy as cnp

cnp.import_array()

from ...utils._typedefs cimport ITYPE_t, DTYPE_t
from ...utils._typedefs import ITYPE, DTYPE
import numpy as np
from scipy.sparse import issparse
from sklearn.utils.fixes import threadpool_limits

cpdef enum WeightingStrategy:
    uniform = 0
    # TODO: Implement the following options, most likely in
    # `weighted_histogram_mode`
    distance = 1
    callable = 2

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
from ._argkmin cimport ArgKmin{{name_suffix}}
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
"""
{{name_suffix}}bit implementation of ArgKminClassMode.
"""
cdef:
const ITYPE_t[:] class_membership,
const ITYPE_t[:] unique_labels
DTYPE_t[:, :] class_scores
cpp_map[ITYPE_t, ITYPE_t] labels_to_index
WeightingStrategy weight_type

    @classmethod
    def compute(
        cls,
        X,
        Y,
        ITYPE_t k,
        weights,
        class_membership,
        unique_labels,
        str metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
    ):
        """Compute the argkmin reduction with class_membership.

        This classmethod is responsible for introspecting the argument
        values to dispatch to the most appropriate implementation of
        :class:`ArgKminClassMode{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated data structures
        necessary for the concrete implementation are therefore freed when this
        classmethod returns.

        No instance should be created directly outside of this class method.
        """
        # Use a generic implementation that handles most scipy
        # metrics by computing the distances between 2 vectors at a time.
        pda = ArgKminClassMode{{name_suffix}}(
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
            k=k,
            chunk_size=chunk_size,
            strategy=strategy,
            weights=weights,
            class_membership=class_membership,
            unique_labels=unique_labels,
        )

        # Limit the number of threads in the second level of nested parallelism for BLAS
        # to avoid thread over-subscription (in GEMM for instance).
        with threadpool_limits(limits=1, user_api="blas"):
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results()

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        const ITYPE_t[:] class_membership,
        const ITYPE_t[:] unique_labels,
        chunk_size=None,
        strategy=None,
        ITYPE_t k=1,
        weights=None,
    ):
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )

        if weights == "uniform":
            self.weight_type = WeightingStrategy.uniform
        elif weights == "distance":
            self.weight_type = WeightingStrategy.distance
        else:
            self.weight_type = WeightingStrategy.callable
        self.class_membership = class_membership

        self.unique_labels = unique_labels

        cdef ITYPE_t idx, neighbor_class_idx
        # `class_scores` is the buffer used to build a histogram for the
        # one-pass weighted mode: one row per sample of X, one column per
        # entry of `unique_labels`.
        self.class_scores = np.zeros(
            (self.n_samples_X, unique_labels.shape[0]), dtype=DTYPE,
        )

    def _finalize_results(self):
        probabilities = np.asarray(self.class_scores)
        probabilities /= probabilities.sum(axis=1, keepdims=True)
        return probabilities

    cdef inline void weighted_histogram_mode(
        self,
        ITYPE_t sample_index,
        ITYPE_t* indices,
        DTYPE_t* distances,
    ) noexcept nogil:
        cdef:
            ITYPE_t neighbor_rank, neighbor_idx, neighbor_class_idx, label_index, multi_output_index
            DTYPE_t score_incr = 1
            # TODO: Implement other WeightingStrategy values
            bint use_distance_weighting = (
                self.weight_type == WeightingStrategy.distance
            )

        # Iterate through the sample's k nearest neighbours
        for neighbor_rank in range(self.k):
            # TODO: inspect whether it is worth permuting this condition
            # and the for-loop above for improved branching.
            if use_distance_weighting:
                score_incr = 1 / distances[neighbor_rank]
            # Absolute index of the neighbor_rank-th nearest neighbor,
            # in range [0, n_samples_Y)
            neighbor_idx = indices[neighbor_rank]
            neighbor_class_idx = self.class_membership[neighbor_idx]
            self.class_scores[sample_index][neighbor_class_idx] += score_incr
        return

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        ITYPE_t thread_num,
        ITYPE_t X_start,
        ITYPE_t X_end,
    ) noexcept nogil:
        cdef:
            ITYPE_t idx, sample_index
        for idx in range(X_end - X_start):
            # One-pass top-one weighted mode
            # Compute the absolute index in [0, n_samples_X)
            sample_index = X_start + idx
            self.weighted_histogram_mode(
                sample_index,
                &self.heaps_indices_chunks[thread_num][idx * self.k],
                &self.heaps_r_distances_chunks[thread_num][idx * self.k],
            )
        return

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        cdef:
            ITYPE_t sample_index, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary data structures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            for sample_index in prange(self.n_samples_X, schedule='static'):
                self.weighted_histogram_mode(
                    sample_index,
                    &self.argkmin_indices[sample_index][0],
                    &self.argkmin_distances[sample_index][0],
                )
        return

{{endfor}}
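
For reference, the reduction implemented by `weighted_histogram_mode` and `_finalize_results` above amounts to the following pure-NumPy sketch (function and variable names here are illustrative, not part of the commit):

    import numpy as np

    def weighted_histogram_mode(knn_indices, knn_distances, class_membership,
                                n_classes, weights="uniform"):
        """One-pass weighted mode, mirroring ArgKminClassMode's per-sample reduction.

        knn_indices      : (n_samples_X, k) indices of the k nearest neighbors in Y
        knn_distances    : (n_samples_X, k) corresponding distances
        class_membership : (n_samples_Y,) integer-encoded class of each row of Y
        """
        n_samples_X, k = knn_indices.shape
        class_scores = np.zeros((n_samples_X, n_classes))

        for sample_index in range(n_samples_X):
            for neighbor_rank in range(k):
                # "uniform" adds 1 per neighbor, "distance" adds 1 / distance.
                score_incr = 1.0
                if weights == "distance":
                    score_incr = 1.0 / knn_distances[sample_index, neighbor_rank]
                neighbor_idx = knn_indices[sample_index, neighbor_rank]
                neighbor_class_idx = class_membership[neighbor_idx]
                class_scores[sample_index, neighbor_class_idx] += score_incr

        # Mirrors `_finalize_results`: normalize each row into probabilities.
        return class_scores / class_scores.sum(axis=1, keepdims=True)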