From 258dff1276e20c283f3aa3e2d19f88f2112351a5 Mon Sep 17 00:00:00 2001 From: anadege Date: Thu, 24 Nov 2022 20:25:59 +0100 Subject: [PATCH 1/4] Double validation for affinity_propagation public function --- sklearn/cluster/_affinity_propagation.py | 8 +++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 180e37996aa07..dc715107787e2 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -13,7 +13,7 @@ from ..exceptions import ConvergenceWarning from ..base import BaseEstimator, ClusterMixin from ..utils import as_float_array, check_random_state -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.validation import check_is_fitted from ..metrics import euclidean_distances from ..metrics import pairwise_distances_argmin @@ -178,6 +178,12 @@ def _affinity_propagation( # Public API +@validate_params( + { + "S": ["array-like"], + "return_n_iter": ["boolean"], + } +) def affinity_propagation( S, *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index ff42011427b83..63cbe5927924d 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -105,6 +105,7 @@ def _check_function_param_validation( "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", "sklearn.svm.l1_min_c", + "sklearn.cluster.affinity_propagation", ] From f2c1411852de31b129fa96dcb6fbb627339f239e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 9 Dec 2022 17:44:07 +0100 Subject: [PATCH 2/4] TST add the function/class association --- sklearn/tests/test_public_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 63cbe5927924d..17c1238ae61f3 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -105,7 +105,6 @@ def _check_function_param_validation( "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", "sklearn.svm.l1_min_c", - "sklearn.cluster.affinity_propagation", ] @@ -125,6 +124,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"), + ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), ] From ff1ca3a2f5f48a2968c318dd191a5978700b8097 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 27 Dec 2022 19:49:01 +0100 Subject: [PATCH 3/4] remove double validation --- sklearn/cluster/_affinity_propagation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index dc715107787e2..11e9b02b2062b 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -275,13 +275,11 @@ def affinity_propagation( Brendan J. Frey and Delbert Dueck, "Clustering by Passing Messages Between Data Points", Science Feb. 2007 """ - S = as_float_array(S, copy=copy) - estimator = AffinityPropagation( damping=damping, max_iter=max_iter, convergence_iter=convergence_iter, - copy=False, + copy=copy, preference=preference, affinity="precomputed", verbose=verbose, From 3db197c057d1c01042aca40952d9db447ea4e180 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Wed, 28 Dec 2022 00:46:22 +0100 Subject: [PATCH 4/4] lint --- sklearn/cluster/_affinity_propagation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 11e9b02b2062b..586b6c2c905a4 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -12,7 +12,7 @@ from ..exceptions import ConvergenceWarning from ..base import BaseEstimator, ClusterMixin -from ..utils import as_float_array, check_random_state +from ..utils import check_random_state from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.validation import check_is_fitted from ..metrics import euclidean_distances