Skip to content

Commit

Permalink
MAINT PairwiseDistancesReduction: Rename some symbols and files (sc…
Browse files Browse the repository at this point in the history
…ikit-learn#24623)

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
  • Loading branch information
3 people authored Oct 13, 2022
1 parent c0b3385 commit 5b45d1f
Show file tree
Hide file tree
Showing 17 changed files with 128 additions and 96 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pxd
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ ignore =
sklearn/metrics/_pairwise_distances_reduction/_datasets_pair.pyx
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pxd
sklearn/metrics/_pairwise_distances_reduction/_gemm_term_computer.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighborhood.pyx
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd
sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx


[codespell]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"sklearn.metrics._pairwise_distances_reduction._gemm_term_computer",
"sklearn.metrics._pairwise_distances_reduction._base",
"sklearn.metrics._pairwise_distances_reduction._argkmin",
"sklearn.metrics._pairwise_distances_reduction._radius_neighborhood",
"sklearn.metrics._pairwise_distances_reduction._radius_neighbors",
"sklearn.metrics._pairwise_fast",
"sklearn.neighbors._partition_nodes",
"sklearn.tree._splitter",
Expand Down
18 changes: 9 additions & 9 deletions sklearn/metrics/_pairwise_distances_reduction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#
# Dispatchers are meant to be used in the Python code. Under the hood, a
# dispatcher must only define the logic to choose at runtime to the correct
# dtype-specialized :class:`BaseDistanceReductionDispatcher` implementation based
# dtype-specialized :class:`BaseDistancesReductionDispatcher` implementation based
# on the dtype of X and of Y.
#
#
Expand All @@ -46,7 +46,7 @@
#
#
# (base dispatcher)
# BaseDistanceReductionDispatcher
# BaseDistancesReductionDispatcher
# ∆
# |
# |
Expand All @@ -56,8 +56,8 @@
# ArgKmin RadiusNeighbors
# | |
# | |
# | (64bit implem.) |
# | BaseDistanceReducer{32,64} |
# | (float{32,64} implem.) |
# | BaseDistancesReduction{32,64} |
# | ∆ |
# | | |
# | | |
Expand All @@ -74,9 +74,9 @@
# x | | x
# EuclideanArgKmin{32,64} EuclideanRadiusNeighbors{32,64}
#
# For instance :class:`ArgKmin`, dispatches to both :class:`ArgKmin64`
# and :class:`ArgKmin32` if X and Y are both dense NumPy arrays with a `float64`
# or `float32` dtype respectively.
# For instance :class:`ArgKmin` dispatches to:
# - :class:`ArgKmin64` if X and Y are two `float64` array-likes
# - :class:`ArgKmin32` if X and Y are two `float32` array-likes
#
# In addition, if the metric parameter is set to "euclidean" or "sqeuclidean",
# then `ArgKmin{32,64}` further dispatches to `EuclideanArgKmin{32,64}`. For
Expand All @@ -87,14 +87,14 @@


from ._dispatcher import (
BaseDistanceReductionDispatcher,
BaseDistancesReductionDispatcher,
ArgKmin,
RadiusNeighbors,
sqeuclidean_row_norms,
)

__all__ = [
"BaseDistanceReductionDispatcher",
"BaseDistancesReductionDispatcher",
"ArgKmin",
"RadiusNeighbors",
"sqeuclidean_row_norms",
Expand Down
8 changes: 4 additions & 4 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ cnp.import_array()

{{for name_suffix in ['64', '32']}}

from ._base cimport BaseDistanceReducer{{name_suffix}}
from ._base cimport BaseDistancesReduction{{name_suffix}}
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}

cdef class ArgKmin{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):
"""{{name_suffix}}bit implementation of BaseDistanceReducer{{name_suffix}} for the `ArgKmin` reduction."""
cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""float{{name_suffix}} implementation of the ArgKmin."""

cdef:
ITYPE_t k
Expand All @@ -23,7 +23,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):


cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
"""EuclideanDistance-specialized {{name_suffix}}bit implementation of ArgKmin{{name_suffix}}."""
"""EuclideanDistance-specialisation of ArgKmin{{name_suffix}}."""
cdef:
GEMMTermComputer{{name_suffix}} gemm_term_computer
const DTYPE_t[::1] X_norm_squared
Expand Down
8 changes: 4 additions & 4 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cnp.import_array()
{{for name_suffix in ['64', '32']}}

from ._base cimport (
BaseDistanceReducer{{name_suffix}},
BaseDistancesReduction{{name_suffix}},
_sqeuclidean_row_norms{{name_suffix}},
)

Expand All @@ -36,8 +36,8 @@ from ._datasets_pair cimport (
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}


cdef class ArgKmin{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):
"""{{name_suffix}}bit implementation of the pairwise-distance reduction BaseDistanceReducer{{name_suffix}}."""
cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""float{{name_suffix}} implementation of the ArgKmin."""

@classmethod
def compute(
Expand Down Expand Up @@ -311,7 +311,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):


cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
"""EuclideanDistance-specialized implementation for ArgKmin{{name_suffix}}."""
"""EuclideanDistance-specialisation of ArgKmin{{name_suffix}}."""

@classmethod
def is_usable_for(cls, X, Y, metric) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
from ._datasets_pair cimport DatasetsPair{{name_suffix}}


cdef class BaseDistanceReducer{{name_suffix}}:
cdef class BaseDistancesReduction{{name_suffix}}:
"""
Base {{name_suffix}}bit implementation template of the pairwise-distances reduction
backend.
Base float{{name_suffix}} implementation template of the pairwise-distances
reduction backends.

Implementations inherit from this template and may override the several
defined hooks as needed in order to easily extend functionality with
Expand Down
66 changes: 52 additions & 14 deletions sklearn/metrics/_pairwise_distances_reduction/_base.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(
ITYPE_t d = X.shape[1]
DTYPE_t[::1] squared_row_norms = np.empty(n, dtype=DTYPE)

# To upcast the i-th row of X from 32bit to 64bit
# To upcast the i-th row of X from float32 to float64
vector[vector[DTYPE_t]] X_i_upcast = vector[vector[DTYPE_t]](
num_threads, vector[DTYPE_t](d)
)

with nogil, parallel(num_threads=num_threads):
thread_num = openmp.omp_get_thread_num()
for i in prange(n, schedule='static'):
# Upcasting the i-th row of X from 32bit to 64bit
# Upcasting the i-th row of X from float32 to float64
for j in range(d):
X_i_upcast[thread_num][j] = <DTYPE_t> deref(X_ptr + i * d + j)

Expand All @@ -90,10 +90,10 @@ cpdef DTYPE_t[::1] _sqeuclidean_row_norms32(

from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class BaseDistanceReducer{{name_suffix}}:
cdef class BaseDistancesReduction{{name_suffix}}:
"""
Base {{name_suffix}}bit implementation template of the pairwise-distances reduction
backend.
Base float{{name_suffix}} implementation template of the pairwise-distances
reduction backends.

Implementations inherit from this template and may override the several
defined hooks as needed in order to easily extend functionality with
Expand Down Expand Up @@ -209,7 +209,6 @@ cdef class BaseDistanceReducer{{name_suffix}}:
X_end = X_start + self.X_n_samples_chunk

# Reinitializing thread datastructures for the new X chunk
# If necessary, upcast X[X_start:X_end] to 64bit
self._parallel_on_X_init_chunk(thread_num, X_start, X_end)

for Y_chunk_idx in range(self.Y_n_chunks):
Expand All @@ -219,7 +218,6 @@ cdef class BaseDistanceReducer{{name_suffix}}:
else:
Y_end = Y_start + self.Y_n_samples_chunk

# If necessary, upcast Y[Y_start:Y_end] to 64bit
self._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end,
Y_start, Y_end,
Expand Down Expand Up @@ -280,7 +278,6 @@ cdef class BaseDistanceReducer{{name_suffix}}:
thread_num = _openmp_thread_num()

# Initializing datastructures used in this thread
# If necessary, upcast X[X_start:X_end] to 64bit
self._parallel_on_Y_parallel_init(thread_num, X_start, X_end)

for Y_chunk_idx in prange(self.Y_n_chunks, schedule='static'):
Expand All @@ -290,7 +287,6 @@ cdef class BaseDistanceReducer{{name_suffix}}:
else:
Y_end = Y_start + self.Y_n_samples_chunk

# If necessary, upcast Y[Y_start:Y_end] to 64bit
self._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end,
Y_start, Y_end,
Expand Down Expand Up @@ -326,7 +322,7 @@ cdef class BaseDistanceReducer{{name_suffix}}:
) nogil:
"""Compute the pairwise distances on two chunks of X and Y and reduce them.

This is THE core computational method of BaseDistanceReducer{{name_suffix}}.
This is THE core computational method of BaseDistancesReduction{{name_suffix}}.
This must be implemented in subclasses agnostically from the parallelization
strategies.
"""
Expand Down Expand Up @@ -358,7 +354,19 @@ cdef class BaseDistanceReducer{{name_suffix}}:
ITYPE_t X_start,
ITYPE_t X_end,
) nogil:
"""Initialize datastructures used in a thread given its number."""
"""Initialize datastructures used in a thread given its number.

In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_X_init_chunk(
thread_num, X_start, X_end,
)

to ensure the proper upcast of X[X_start:X_end] to float64 prior
to the reduction with float64 accumulator buffers when X.dtype is
float32.
"""
return

cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
Expand All @@ -371,7 +379,16 @@ cdef class BaseDistanceReducer{{name_suffix}}:
) nogil:
"""Initialize datastructures just before the _compute_and_reduce_distances_on_chunks.

This is eventually used to upcast X[X_start:X_end] to 64bit.
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num,
)

to ensure the proper upcast of Y[Y_start:Y_end] to float64 prior
to the reduction with float64 accumulator buffers when Y.dtype is
float32.
"""
return

Expand Down Expand Up @@ -403,7 +420,19 @@ cdef class BaseDistanceReducer{{name_suffix}}:
ITYPE_t X_start,
ITYPE_t X_end,
) nogil:
"""Initialize datastructures used in a thread given its number."""
"""Initialize datastructures used in a thread given its number.

In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_Y_parallel_init(
thread_num, X_start, X_end,
)

to ensure the proper upcast of X[X_start:X_end] to float64 prior
to the reduction with float64 accumulator buffers when X.dtype is
float32.
"""
return

cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
Expand All @@ -416,7 +445,16 @@ cdef class BaseDistanceReducer{{name_suffix}}:
) nogil:
"""Initialize datastructures just before the _compute_and_reduce_distances_on_chunks.

This is eventually used to upcast Y[Y_start:Y_end] to 64bit.
In this method, EuclideanDistance specialisations of subclass of
BaseDistancesReduction _must_ call:

self.gemm_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
X_start, X_end, Y_start, Y_end, thread_num,
)

to ensure the proper upcast of Y[Y_start:Y_end] to float64 prior
to the reduction with float64 accumulator buffers when Y.dtype is
float32.
"""
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cdef class DatasetsPair{{name_suffix}}:

The handling of parallelization over chunks to compute the distances
and aggregation for several rows at a time is done in dedicated
subclasses of :class:`BaseDistanceReductionDispatcher` that in-turn rely on
subclasses of :class:`BaseDistancesReductionDispatcher` that in-turn rely on
subclasses of :class:`DatasetsPair` for each pair of rows in the data. The
goal is to make it possible to decouple the generic parallelization and
aggregation logic from metric-specific computation as much as possible.
Expand Down
10 changes: 5 additions & 5 deletions sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ArgKmin64,
ArgKmin32,
)
from ._radius_neighborhood import (
from ._radius_neighbors import (
RadiusNeighbors64,
RadiusNeighbors32,
)
Expand Down Expand Up @@ -51,10 +51,10 @@ def sqeuclidean_row_norms(X, num_threads):
)


class BaseDistanceReductionDispatcher:
class BaseDistancesReductionDispatcher:
"""Abstract base dispatcher for pairwise distance computation & reduction.
Each dispatcher extending the base :class:`BaseDistanceReductionDispatcher`
Each dispatcher extending the base :class:`BaseDistancesReductionDispatcher`
dispatcher must implement the :meth:`compute` classmethod.
"""

Expand Down Expand Up @@ -168,7 +168,7 @@ def compute(
"""


class ArgKmin(BaseDistanceReductionDispatcher):
class ArgKmin(BaseDistancesReductionDispatcher):
"""Compute the argkmin of row vectors of X on the ones of Y.
For each row vector of X, computes the indices of k first the rows
Expand Down Expand Up @@ -304,7 +304,7 @@ def compute(
)


class RadiusNeighbors(BaseDistanceReductionDispatcher):
class RadiusNeighbors(BaseDistancesReductionDispatcher):
"""Compute radius-based neighbors for two sets of vectors.
For each row-vector X[i] of the queries X, find all the indices j of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ cdef class GEMMTermComputer{{name_suffix}}:
vector[vector[DTYPE_t]] dist_middle_terms_chunks

{{if upcast_to_float64}}
# Buffers for upcasting chunks of X and Y from 32bit to 64bit
# Buffers for upcasting chunks of X and Y from float32 to float64
vector[vector[DTYPE_t]] X_c_upcast
vector[vector[DTYPE_t]] Y_c_upcast
{{endif}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ cdef class GEMMTermComputer{{name_suffix}}:
self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads)

{{if upcast_to_float64}}
# We populate the buffer for upcasting chunks of X and Y from 32bit to 64bit.
# We populate the buffer for upcasting chunks of X and Y from float32 to float64.
self.X_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
self.Y_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ cdef cnp.ndarray[object, ndim=1] coerce_vectors_to_nd_arrays(
#####################
{{for name_suffix in ['64', '32']}}

from ._base cimport BaseDistanceReducer{{name_suffix}}
from ._base cimport BaseDistancesReduction{{name_suffix}}
from ._gemm_term_computer cimport GEMMTermComputer{{name_suffix}}

cdef class RadiusNeighbors{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):
"""
{{name_suffix}}bit implementation of BaseDistanceReducer{{name_suffix}} for the
`RadiusNeighbors` reduction.
"""
cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""float{{name_suffix}} implementation of the RadiusNeighbors."""

cdef:
DTYPE_t radius
Expand Down Expand Up @@ -82,7 +79,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistanceReducer{{name_suffix}}):


cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix}}):
"""EuclideanDistance-specialized {{name_suffix}}bit implementation for RadiusNeighbors{{name_suffix}}."""
"""EuclideanDistance-specialisation of RadiusNeighbors{{name_suffix}}."""
cdef:
GEMMTermComputer{{name_suffix}} gemm_term_computer
const DTYPE_t[::1] X_norm_squared
Expand Down
Loading

0 comments on commit 5b45d1f

Please sign in to comment.