10000 POC 32bit datasets support for `PairwiseDistancesReduction` by jjerphan · Pull Request #22590 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

POC 32bit datasets support for PairwiseDistancesReduction #22590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8ae1cb6
MAINT Generate DistanceMetrics for 32bit vectors
jjerphan Feb 23, 2022
487c0f1
MAINT Generate PairwiseDistancesReduction for 32bit and 64bit
jjerphan Feb 23, 2022
451f3d5
MAINT Make PairwiseDistances{Reduction,ArgKmin} facades
jjerphan Feb 24, 2022
2c78948
Merge branch 'main' into distance-metrics-32bit
jjerphan Feb 25, 2022
327d249
TST Fix test_sqeuclidean_row_norms
jjerphan Feb 25, 2022
baf2fc6
TST Add fixture to test quasi-equality for 32bit
jjerphan Feb 25, 2022
3c80b40
MAINT Do not route 32bit specialize implementation yet
jjerphan Feb 25, 2022
3d9d565
TST Adapt DistanceMetrics tests for 32bit
jjerphan Feb 25, 2022
eb0c65e
fixup! TST Adapt DistanceMetrics tests for 32bit
jjerphan Feb 25, 2022 8000
42eba72
MAINT Upcast buffers to 64bit when and where needed
jjerphan Feb 25, 2022
b41c8aa
MAINT Improve upcast
jjerphan Feb 26, 2022
65ebc92
CLN Improve imports and fix duplicated ignored files
jjerphan Feb 26, 2022
ffb08ac
MAINT Do not warn if Y_norm_squared is passed via metric_kwargs
jjerphan Feb 28, 2022
059128f
DOC Update whats_new entry
jjerphan Feb 28, 2022
7ccd37e
TST Add a test for dtype agnosticism
jjerphan Feb 28, 2022
7645ba3
MAINT Accumulate using float64
jjerphan Mar 2, 2022
2f56a49
MAINT 32bit support for DistanceMetric
jjerphan Mar 11, 2022
dd50629
MAINT Remove generated pyx file
jjerphan Mar 17, 2022
545326f
MAINT Do not generate 32bit version of DatasetsPair for now
jjerphan Mar 17, 2022
5da38e8
Merge branch 'main' into maint/distance-metrics-32bit
jjerphan Mar 30, 2022
e8b8344
Fix typo
jjerphan Mar 30, 2022
2d104a4
TST Adapt error messages
jjerphan Apr 1, 2022
857e20f
Merge branch 'maint/distance-metrics-32bit' into distance-metrics-32bit
jjerphan Apr 4, 2022
5d261b1
MAINT Reorganise upcast w.r.t GEMMTermComputer introduction
jjerphan Apr 4, 2022
26b3839
MAINT Correctly allocate buffer for upcasting
jjerphan Apr 14, 2022
3f6f2c6
Merge branch 'main' into distance-metrics-32bit
jjerphan May 24, 2022
7b0bcd3
TST Update tests
jjerphan May 29, 2022
f0fc839
MAINT Correctly resize buffers for upcasting
jjerphan May 29, 2022
12771ed
Merge branch 'main' into distance-metrics-32bit
jjerphan May 29, 2022
5fc225e
Merge branch 'main' into distance-metrics-32bit
jjerphan Jun 1, 2022
cef57b1
MAINT Document and reduce diff but not the logic
jjerphan Jun 1, 2022
cbef7f1
DEBUG Propagate sort_results
jjerphan Jun 7, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
10000 Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ sklearn/utils/_weight_vector.pxd
sklearn/linear_model/_sag_fast.pyx
sklearn/metrics/_dist_metrics.pyx
sklearn/metrics/_dist_metrics.pxd
sklearn/metrics/_pairwise_distances_reduction.pyx
3 changes: 3 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ Changelog


- |Efficiency| Low-level routines for reductions on pairwise distances
for dense float32 and float64 datasets have been refactored.
The following functions and estimators now benefit from improved performances,
in particular on multi-cores machines:
for dense float64 datasets have been refactored. The following functions
and estimators now benefit from improved performances in terms of hardware
scalability and speed-ups:
Expand Down
13 changes: 7 additions & 6 deletions sklearn/metrics/_dist_metrics.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ cdef class DistanceMetric{{name_suffix}}:

cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1

{{endfor}}

######################################################################
# DatasetsPair base class
cdef class DatasetsPair:
cdef DistanceMetric distance_metric
cdef class DatasetsPair{{name_suffix}}:
cdef DistanceMetric{{name_suffix}} distance_metric

cdef ITYPE_t n_samples_X(self) nogil

Expand All @@ -116,8 +115,10 @@ cdef class DatasetsPair:
cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil


cdef class DenseDenseDatasetsPair(DatasetsPair):
cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
cdef:
const DTYPE_t[:, ::1] X
const DTYPE_t[:, ::1] Y
const {{DTYPE_t}}[:, ::1] X
const {{DTYPE_t}}[:, ::1] Y
ITYPE_t d

{{endfor}}
25 changes: 15 additions & 10 deletions sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,11 +1170,10 @@ cdef class PyFuncDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
raise TypeError("Custom distance function must accept two "
"vectors and return a float.")

{{endfor}}

######################################################################
# Datasets Pair Classes
cdef class DatasetsPair:
cdef class DatasetsPair{{name_suffix}}:
"""Abstract class which wraps a pair of datasets (X, Y).

This class allows computing distances between a single pair of rows of
Expand Down Expand Up @@ -1211,7 +1210,7 @@ cdef class DatasetsPair:
Y,
str metric="euclidean",
dict metric_kwargs=None,
) -> DatasetsPair:
) -> DatasetsPair{{name_suffix}}:
"""Return the DatasetsPair implementation for the given arguments.

Parameters
Expand Down Expand Up @@ -1241,14 +1240,14 @@ cdef class DatasetsPair:
The suited DatasetsPair implementation.
"""
cdef:
DistanceMetric distance_metric = DistanceMetric.get_metric(
DistanceMetric{{name_suffix}} distance_metric = DistanceMetric{{name_suffix}}.get_metric(
metric,
**(metric_kwargs or {})
)

if not(X.dtype == Y.dtype == np.float64):
if not(X.dtype == Y.dtype and X.dtype in DatasetsPair{{name_suffix}}.valid_dtypes()):
raise ValueError(
f"Only 64bit float datasets are supported at this time, "
f"Only np.float64 and np.float32 datasets are supported at this time, "
f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}."
)

Expand All @@ -1260,11 +1259,15 @@ cdef class DatasetsPair:
if issparse(X) or issparse(Y):
raise ValueError("Only dense datasets are supported for X and Y.")

return DenseDenseDatasetsPair(X, Y, distance_metric)
return DenseDenseDatasetsPair{{name_suffix}}(X, Y, distance_metric)

def __init__(self, DistanceMetric distance_metric):
def __init__(self, DistanceMetric{{name_suffix}} distance_metric):
self.distance_metric = distance_metric

@classmethod
def valid_dtypes(cls):
return (np.float64, np.float32)

cdef ITYPE_t n_samples_X(self) nogil:
"""Number of samples in X."""
# This is a abstract method.
Expand All @@ -1289,7 +1292,7 @@ cdef class DatasetsPair:
return -1

@final
cdef class DenseDenseDatasetsPair(DatasetsPair):
cdef class DenseDenseDatasetsPair{{name_suffix}}(DatasetsPair{{name_suffix}}):
"""Compute distances between row vectors of two arrays.

Parameters
Expand All @@ -1305,7 +1308,7 @@ cdef class DenseDenseDatasetsPair(DatasetsPair):
between two row vectors of (X, Y).
"""

def __init__(self, X, Y, DistanceMetric distance_metric):
def __init__(self, X, Y, DistanceMetric{{name_suffix}} distance_metric):
super().__init__(distance_metric)
# Arrays have already been checked
self.X = X
Expand All @@ -1331,3 +1334,5 @@ cdef class DenseDenseDatasetsPair(DatasetsPair):
return self.distance_metric.dist(&self.X[i, 0],
&self.Y[j, 0],
self.d)

{{endfor}}
Loading
0