forked from scikit-learn/scikit-learn
Commit
PERF Implement `PairwiseDistancesReduction` backend for `KNeighbors.predict_proba` (scikit-learn#24076)
Signed-off-by: Julien Jerphanion <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 2c4e1d7 · commit b6b6f63
Showing 12 changed files with 648 additions and 14 deletions.
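For context, the user-facing entry point this backend accelerates is `KNeighborsClassifier.predict_proba`. The sketch below is illustrative only: the dataset, the dtype, and the assumption that this exact configuration is dispatched to the new `ArgKminClassMode` path are not stated in this diff.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

# Synthetic data; C-contiguous float64 inputs are the typical fast path.
X, y = make_classification(
    n_samples=10_000, n_features=50, n_classes=3, n_informative=10, random_state=0
)
X = X.astype(np.float64)

clf = KNeighborsClassifier(n_neighbors=10, weights="uniform").fit(X, y)

# predict_proba is the call whose k-NN class-mode reduction this commit
# reimplements on top of PairwiseDistancesReduction.
proba = clf.predict_proba(X[:100])
print(proba.shape)  # (100, 3); each row sums to 1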
sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp
208 changes: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
{{py:

implementation_specific_values = [
    # Values are the following ones:
    #
    #   name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
    #
    # We also use the float64 dtype and C-type names as defined in
    # `sklearn.utils._typedefs` to maintain consistency.
    #
    ('64', 'DTYPE_t', 'DTYPE'),
    ('32', 'cnp.float32_t', 'np.float32')
]

}}

from cython cimport floating, integral
from cython.parallel cimport parallel, prange
from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

cimport numpy as cnp

cnp.import_array()

from ...utils._typedefs cimport ITYPE_t, DTYPE_t
from ...utils._typedefs import ITYPE, DTYPE
import numpy as np
from scipy.sparse import issparse
from sklearn.utils.fixes import threadpool_limits

cpdef enum WeightingStrategy:
    uniform = 0
    # TODO: Implement the following options, most likely in
    # `weighted_histogram_mode`
    distance = 1
    callable = 2

{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
from ._argkmin cimport ArgKmin{{name_suffix}}
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
    """
    {{name_suffix}}bit implementation of ArgKminClassMode.
    """
    cdef:
        const ITYPE_t[:] class_membership,
        const ITYPE_t[:] unique_labels
        DTYPE_t[:, :] class_scores
        cpp_map[ITYPE_t, ITYPE_t] labels_to_index
        WeightingStrategy weight_type

    @classmethod
    def compute(
        cls,
        X,
        Y,
        ITYPE_t k,
        weights,
        class_membership,
        unique_labels,
        str metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
    ):
        """Compute the argkmin reduction with class_membership.

        This classmethod is responsible for introspecting the argument
        values to dispatch to the most appropriate implementation of
        :class:`ArgKminClassMode{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated datastructures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        No instance should be created directly outside of this class method.
        """
        # Use a generic implementation that handles most scipy
        # metrics by computing the distances between 2 vectors at a time.
        pda = ArgKminClassMode{{name_suffix}}(
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
            k=k,
            chunk_size=chunk_size,
            strategy=strategy,
            weights=weights,
            class_membership=class_membership,
            unique_labels=unique_labels,
        )

        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in GEMM for instance).
        with threadpool_limits(limits=1, user_api="blas"):
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results()

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        const ITYPE_t[:] class_membership,
        const ITYPE_t[:] unique_labels,
        chunk_size=None,
        strategy=None,
        ITYPE_t k=1,
        weights=None,
    ):
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )

        if weights == "uniform":
            self.weight_type = WeightingStrategy.uniform
        elif weights == "distance":
            self.weight_type = WeightingStrategy.distance
        else:
            self.weight_type = WeightingStrategy.callable
        self.class_membership = class_membership

        self.unique_labels = unique_labels

        cdef ITYPE_t idx, neighbor_class_idx
        # Map from set of unique labels to their indices in `class_scores`
        # Buffer used in building a histogram for one-pass weighted mode
        self.class_scores = np.zeros(
            (self.n_samples_X, unique_labels.shape[0]), dtype=DTYPE,
        )

    def _finalize_results(self):
        probabilities = np.asarray(self.class_scores)
        probabilities /= probabilities.sum(axis=1, keepdims=True)
        return probabilities

    cdef inline void weighted_histogram_mode(
        self,
        ITYPE_t sample_index,
        ITYPE_t* indices,
        DTYPE_t* distances,
    ) noexcept nogil:
        cdef:
            ITYPE_t neighbor_idx, neighbor_class_idx, label_index, multi_output_index
            DTYPE_t score_incr = 1
            # TODO: Implement other WeightingStrategy values
            bint use_distance_weighting = (
                self.weight_type == WeightingStrategy.distance
            )

        # Iterate through the sample's k nearest neighbours
        for neighbor_rank in range(self.k):
            # Absolute index of the neighbor_rank-th nearest neighbor,
            # in range [0, n_samples_Y)
            # TODO: inspect whether it is worth permuting this condition
            # and the for-loop above for improved branching.
            if use_distance_weighting:
                score_incr = 1 / distances[neighbor_rank]
            neighbor_idx = indices[neighbor_rank]
            neighbor_class_idx = self.class_membership[neighbor_idx]
            self.class_scores[sample_index][neighbor_class_idx] += score_incr
        return

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        ITYPE_t thread_num,
        ITYPE_t X_start,
        ITYPE_t X_end,
    ) noexcept nogil:
        cdef:
            ITYPE_t idx, sample_index
        for idx in range(X_end - X_start):
            # One-pass top-one weighted mode
            # Compute the absolute index in [0, n_samples_X)
            sample_index = X_start + idx
            self.weighted_histogram_mode(
                sample_index,
                &self.heaps_indices_chunks[thread_num][idx * self.k],
                &self.heaps_r_distances_chunks[thread_num][idx * self.k],
            )
        return

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        cdef:
            ITYPE_t sample_index, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            for sample_index in prange(self.n_samples_X, schedule='static'):
                self.weighted_histogram_mode(
                    sample_index,
                    &self.argkmin_indices[sample_index][0],
                    &self.argkmin_distances[sample_index][0],
                )
        return

{{endfor}}
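Beyond the generated sources above, here is a rough sketch of how the `compute` classmethod defined in this template could be exercised directly. The import path of the generated extension module and the use of integer-encoded labels are assumptions made for illustration; in scikit-learn this class is normally reached through the private dispatching layer rather than imported by user code.

import numpy as np
# Assumed module path for the extension generated from this .pyx.tp file.
from sklearn.metrics._pairwise_distances_reduction._argkmin_classmode import (
    ArgKminClassMode64,
)

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))    # query vectors (float64, C-contiguous)
Y = rng.standard_normal((20, 3))   # training vectors
y = rng.integers(0, 3, size=20)    # class label of each row of Y

# Encode labels as indices, since `class_membership` and `unique_labels`
# are consumed as ITYPE (intp) memoryviews by the constructor.
unique_labels, class_membership = np.unique(y, return_inverse=True)

proba = ArgKminClassMode64.compute(
    X,
    Y,
    k=3,
    weights="uniform",
    class_membership=class_membership.astype(np.intp),
    unique_labels=np.arange(unique_labels.shape[0], dtype=np.intp),
    metric="euclidean",
)
print(proba.shape)  # (5, n_classes); rows are normalized in _finalize_results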