8000 MAINT Parameter validation for sklearn.metrics.pairwise_distances (#2… · scikit-learn/scikit-learn@1584ec2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1584ec2

Browse files
ashah002glemaitrejeremiedbb
authored
MAINT Parameter validation for sklearn.metrics.pairwise_distances (#25515)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 371c921 commit 1584ec2

File tree

3 files changed

+13
-16
lines changed

3 files changed

+13
-16
lines changed

sklearn/metrics/pairwise.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,6 +2034,16 @@ def pairwise_distances_chunked(
20342034
yield D_chunk
20352035

20362036

2037+
@validate_params(
2038+
{
2039+
"X": ["array-like", "sparse matrix"],
2040+
"Y": ["array-like", "sparse matrix", None],
2041+
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
2042+
"n_jobs": [Integral, None],
2043+
"force_all_finite": ["boolean", StrOptions({"allow-nan"})],
2044+
},
2045+
prefer_skip_nested_validation=True,
2046+
)
20372047
def pairwise_distances(
20382048
X, Y=None, metric="euclidean", *, n_jobs=None, force_all_finite=True, **kwds
20392049
):
@@ -2081,13 +2091,13 @@ def pairwise_distances(
20812091
20822092
Parameters
20832093
----------
2084-
X : ndarray of shape (n_samples_X, n_samples_X) or \
2094+
X : {array-like, sparse matrix} of shape (n_samples_X, n_samples_X) or \
20852095
(n_samples_X, n_features)
20862096
Array of pairwise distances between samples, or a feature array.
20872097
The shape of the array should be (n_samples_X, n_samples_X) if
20882098
metric == "precomputed" and (n_samples_X, n_features) otherwise.
20892099
2090-
Y : ndarray of shape (n_samples_Y, n_features), default=None
2100+
Y : {array-like, sparse matrix} of shape (n_samples_Y, n_features), default=None
20912101
An optional second feature array. Only allowed if
20922102
metric != "precomputed".
20932103
@@ -2149,16 +2159,6 @@ def pairwise_distances(
21492159
paired_distances : Computes the distances between corresponding elements
21502160
of two arrays.
21512161
"""
2152-
if (
2153-
metric not in _VALID_METRICS
2154-
and not callable(metric)
2155-
and metric != "precomputed"
2156-
):
2157-
raise ValueError(
2158-
"Unknown metric %s. Valid metrics are %s, or 'precomputed', or a callable"
2159-
% (metric, _VALID_METRICS)
2160-
)
2161-
21622162
if metric == "precomputed":
21632163
X, _ = check_pairwise_arrays(
21642164
X, Y, precomputed=True, force_all_finite=force_all_finite

sklearn/metrics/tests/test_pairwise.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,6 @@ def test_pairwise_distances(global_dtype):
200200
with pytest.raises(TypeError):
201201
pairwise_distances(X, Y_sparse, metric="minkowski")
202202

203-
# Test that a value error is raised if the metric is unknown
204-
with pytest.raises(ValueError):
205-
pairwise_distances(X, Y, metric="blah")
206-
207203

208204
# TODO(1.4): Remove test when `sum_over_features` parameter is removed
209205
@pytest.mark.parametrize("sum_over_features", [True, False])

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def _check_function_param_validation(
250250
"sklearn.metrics.pairwise.polynomial_kernel",
251251
"sklearn.metrics.pairwise.rbf_kernel",
252252
"sklearn.metrics.pairwise.sigmoid_kernel",
253+
"sklearn.metrics.pairwise_distances",
253254
"sklearn.metrics.pairwise_distances_argmin",
254255
"sklearn.metrics.precision_recall_curve",
255256
"sklearn.metrics.precision_recall_fscore_support",

0 commit comments

Comments
 (0)
0