MNT Deprecate `affinity` in `AgglomerativeClustering` (#23470) · scikit-learn/scikit-learn@a5d50cf · GitHub
[go: up one dir, main page]

Skip to content

Commit a5d50cf

Browse files
Micky774thomasjpfanglemaitre
authored
MNT Deprecate affinity in AgglomerativeClustering (#23470)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent ac52fe1 commit a5d50cf

File tree

5 files changed

+105
-18
lines changed

5 files changed

+105
-18
lines changed

doc/whats_new/v1.2.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ Changelog
9191
`eigen_tol="auto"` in version 1.3.
9292
:pr:`23210` by :user:`Meekail Zain <micky774>`.
9393

94+
- |API| The `affinity` attribute is now deprecated for
95+
:class:`cluster.AgglomerativeClustering` and will be renamed to `metric` in v1.4.
96+
:pr:`23470` by :user:`Meekail Zain <micky774>`.
97+
9498
:mod:`sklearn.datasets`
9599
.......................
96100

examples/cluster/plot_agglomerative_clustering_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def sqr(x):
125125
# Plot clustering results
126126
for index, metric in enumerate(["cosine", "euclidean", "cityblock"]):
127127
model = AgglomerativeClustering(
128-
n_clusters=n_clusters, linkage="average", affinity=metric
128+
n_clusters=n_clusters, linkage="average", metric=metric
129129
)
130130
model.fit(X)
131131
plt.figure()
@@ -134,7 +134,7 @@ def sqr(x):
134134
plt.plot(X[model.labels_ == l].T, c=c, alpha=0.5)
135135
plt.axis("tight")
136136
plt.axis("off")
137-
plt.suptitle("AgglomerativeClustering(affinity=%s)" % metric, size=20)
137+
plt.suptitle("AgglomerativeClustering(metric=%s)" % metric, size=20)
138138

139139

140140
plt.show()

examples/cluster/plot_cluster_comparison.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@
171171
)
172172
average_linkage = cluster.AgglomerativeClustering(
173173
linkage="average",
174-
affinity="cityblock",
174+
metric="cityblock",
175175
n_clusters=params["n_clusters"],
176176
connectivity=connectivity,
177177
)

sklearn/cluster/_agglomerative.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..utils import check_array
2424
from ..utils._fast_dict import IntFloatDict
2525
from ..utils.graph import _fix_connected_components
26-
from ..utils._param_validation import Interval, StrOptions
26+
from ..utils._param_validation import Hidden, Interval, StrOptions
2727
from ..utils.validation import check_memory
2828

2929
# mypy error: Module 'sklearn.cluster' has no attribute '_hierarchical_fast'
@@ -760,6 +760,19 @@ class AgglomerativeClustering(ClusterMixin, BaseEstimator):
760760
If "precomputed", a distance matrix (instead of a similarity matrix)
761761
is needed as input for the fit method.
762762
763+
.. deprecated:: 1.2
764+
`affinity` was deprecated in version 1.2 and will be renamed to
765+
`metric` in 1.4.
766+
767+
metric : str or callable, default=None
768+
Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
769+
"manhattan", "cosine", or "precomputed". If set to `None` then
770+
"euclidean" is used. If linkage is "ward", only "euclidean" is
771+
accepted. If "precomputed", a distance matrix is needed as input for
772+
the fit method.
773+
774+
.. versionadded:: 1.2
775+
763776
memory : str or object with the joblib.Memory interface, default=None
764777
Used to cache the output of the computation of the tree.
765778
By default, no caching is done. If a string is given, it is the
@@ -880,9 +893,15 @@ class AgglomerativeClustering(ClusterMixin, BaseEstimator):
880893
_parameter_constraints = {
881894
"n_clusters": [Interval(Integral, 1, None, closed="left"), None],
882895
"affinity": [
896+
Hidden(StrOptions({"deprecated"})),
883897
StrOptions(set(_VALID_METRICS) | {"precomputed"}),
884898
callable,
885899
],
900+
"metric": [
901+
StrOptions(set(_VALID_METRICS) | {"precomputed"}),
902+
callable,
903+
None,
904+
],
886905
"memory": "no_validation", # TODO
887906
"connectivity": ["array-like", callable, None],
888907
"compute_full_tree": [StrOptions({"auto"}), "boolean"],
@@ -895,7 +914,8 @@ def __init__(
895914
self,
896915
n_clusters=2,
897916
*,
898-
affinity="euclidean",
917+
affinity="deprecated", # TODO(1.4): Remove
918+
metric=None, # TODO(1.4): Set to "euclidean"
899919
memory=None,
900920
connectivity=None,
901921
compute_full_tree="auto",
@@ -910,6 +930,7 @@ def __init__(
910930
self.compute_full_tree = compute_full_tree
911931
self.linkage = linkage
912932
self.affinity = affinity
933+
self.metric = metric
913934
self.compute_distances = compute_distances
914935

915936
def fit(self, X, y=None):
@@ -920,7 +941,7 @@ def fit(self, X, y=None):
920941
X : array-like, shape (n_samples, n_features) or \
921942
(n_samples, n_samples)
922943
Training instances to cluster, or distances between instances if
923-
``affinity='precomputed'``.
944+
``metric='precomputed'``.
924945
925946
y : Ignored
926947
Not used, present here for API consistency by convention.
@@ -950,6 +971,24 @@ def _fit(self, X):
950971
"""
951972
memory = check_memory(self.memory)
952973

974+
self._metric = self.metric
975+
# TODO(1.4): Remove
976+
if self.affinity != "deprecated":
977+
if self.metric is not None:
978+
raise ValueError(
979+
"Both `affinity` and `metric` attributes were set. Attribute"
980+
" `affinity` was deprecated in version 1.2 and will be removed in"
981+
" 1.4. To avoid this error, only set the `metric` attribute."
982+
)
983+
warnings.warn(
984+
"Attribute `affinity` was deprecated in version 1.2 and will be removed"
985+
" in 1.4. Use `metric` instead",
986+
FutureWarning,
987+
)
988+
self._metric = self.affinity
989+
elif self.metric is None:
990+
self._metric = "euclidean"
991+
953992
if not ((self.n_clusters is None) ^ (self.distance_threshold is None)):
954993
raise ValueError(
955994
"Exactly one of n_clusters and "
@@ -962,10 +1001,10 @@ def _fit(self, X):
9621001
"compute_full_tree must be True if distance_threshold is set."
9631002
)
9641003

965-
if self.linkage == "ward" and self.affinity != "euclidean":
1004+
if self.linkage == "ward" and self._metric != "euclidean":
9661005
raise ValueError(
967-
"%s was provided as affinity. Ward can only "
968-
"work with euclidean distances." % (self.affinity,)
1006+
f"{self._metric} was provided as metric. Ward can only "
1007+
"work with euclidean distances."
9691008
)
9701009

9711010
tree_builder = _TREE_BUILDERS[self.linkage]
@@ -998,7 +1037,7 @@ def _fit(self, X):
9981037
kwargs = {}
9991038
if self.linkage != "ward":
10001039
kwargs["linkage"] = self.linkage
1001-
kwargs["affinity"] = self.affinity
1040+
kwargs["affinity"] = self._metric
10021041

10031042
distance_threshold = self.distance_threshold
10041043

@@ -1084,6 +1123,19 @@ class FeatureAgglomeration(
10841123
If "precomputed", a distance matrix (instead of a similarity matrix)
10851124
is needed as input for the fit method.
10861125
1126+
.. deprecated:: 1.2
1127+
`affinity` was deprecated in version 1.2 and will be renamed to
1128+
`metric` in 1.4.
1129+
1130+
metric : str or callable, default=None
1131+
Metric used to compute the linkage. Can be "euclidean", "l1", "l2",
1132+
"manhattan", "cosine", or "precomputed". If set to `None` then
1133+
"euclidean" is used. If linkage is "ward", only "euclidean" is
1134+
accepted. If "precomputed", a distance matrix is needed as input for
1135+
the fit method.
1136+
1137+
.. versionadded:: 1.2
1138+
10871139
memory : str or object with the joblib.Memory interface, default=None
10881140
Used to cache the output of the computation of the tree.
10891141
By default, no caching is done. If a string is given, it is the
@@ -1208,8 +1260,14 @@ class FeatureAgglomeration(
12081260
_parameter_constraints = {
12091261
"n_clusters": [Interval(Integral, 1, None, closed="left"), None],
12101262
"affinity": [
1263+
Hidden(StrOptions({"deprecated"})),
1264+
StrOptions(set(_VALID_METRICS) | {"precomputed"}),
1265+
callable,
1266+
],
1267+
"metric": [
12111268
StrOptions(set(_VALID_METRICS) | {"precomputed"}),
12121269
callable,
1270+
None,
12131271
],
12141272
"memory": "no_validation", # TODO
12151273
"connectivity": ["array-like", callable, None],
@@ -1224,7 +1282,8 @@ def __init__(
12241282
self,
12251283
n_clusters=2,
12261284
*,
1227-
affinity="euclidean",
1285+
affinity="deprecated", # TODO(1.4): Remove
1286+
metric=None, # TODO(1.4): Set to "euclidean"
12281287
memory=None,
12291288
connectivity=None,
12301289
compute_full_tree="auto",
@@ -1240,6 +1299,7 @@ def __init__(
12401299
compute_full_tree=compute_full_tree,
12411300
linkage=linkage,
12421301
affinity=affinity,
1302+
metric=metric,
12431303
distance_threshold=distance_threshold,
12441304
compute_distances=compute_distances,
12451305
)

sklearn/cluster/tests/test_hierarchical.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,24 +241,24 @@ def test_agglomerative_clustering():
241241
clustering = AgglomerativeClustering(
242242
n_clusters=10,
243243
connectivity=connectivity.toarray(),
244-
affinity="manhattan",
244+
metric="manhattan",
245245
linkage="ward",
246246
)
247247
with pytest.raises(ValueError):
248248
clustering.fit(X)
249249

250250
# Test using another metric than euclidean works with linkage complete
251-
for affinity in PAIRED_DISTANCES.keys():
251+
for metric in PAIRED_DISTANCES.keys():
252252
# Compare our (structured) implementation to scipy
253253
clustering = AgglomerativeClustering(
254254
n_clusters=10,
255255
connectivity=np.ones((n_samples, n_samples)),
256-
affinity=affinity,
256+
metric=metric,
257257
linkage="complete",
258258
)
259259
clustering.fit(X)
260260
clustering2 = AgglomerativeClustering(
261-
n_clusters=10, connectivity=None, affinity=affinity, linkage="complete"
261+
n_clusters=10, connectivity=None, metric=metric, linkage="complete"
262262
)
263263
clustering2.fit(X)
264264
assert_almost_equal(
@@ -275,7 +275,7 @@ def test_agglomerative_clustering():
275275
clustering2 = AgglomerativeClustering(
276276
n_clusters=10,
277277
connectivity=connectivity,
278-
affinity="precomputed",
278+
metric="precomputed",
279279
linkage="complete",
280280
)
281281
clustering2.fit(X_dist)
@@ -289,7 +289,7 @@ def test_agglomerative_clustering_memory_mapped():
289289
"""
290290
rng = np.random.RandomState(0)
291291
Xmm = create_memmap_backed_data(rng.randn(50, 100))
292-
AgglomerativeClustering(affinity="euclidean", linkage="single").fit(Xmm)
292+
AgglomerativeClustering(metric="euclidean", linkage="single").fit(Xmm)
293293

294294

295295
def test_ward_agglomeration():
@@ -860,7 +860,7 @@ def test_invalid_shape_precomputed_dist_matrix():
860860
ValueError,
861861
match=r"Distance matrix should be square, got matrix of shape \(5, 3\)",
862862
):
863-
AgglomerativeClustering(affinity="precomputed", linkage="complete").fit(X)
863+
AgglomerativeClustering(metric="precomputed", linkage="complete").fit(X)
864864

865865

866866
def test_precomputed_connectivity_affinity_with_2_connected_components():
@@ -900,3 +900,26 @@ def test_precomputed_connectivity_affinity_with_2_connected_components():
900900

901901
assert_array_equal(clusterer.labels_, clusterer_precomputed.labels_)
902902
assert_array_equal(clusterer.children_, clusterer_precomputed.children_)
903+
904+
905+
# TODO(1.4): Remove
906+
def test_deprecate_affinity():
907+
rng = np.random.RandomState(42)
908+
X = rng.randn(50, 10)
909+
910+
af = AgglomerativeClustering(affinity="euclidean")
911+
msg = (
912+
"Attribute `affinity` was deprecated in version 1.2 and will be removed in 1.4."
913+
" Use `metric` instead"
914+
)
915+
with pytest.warns(FutureWarning, match=msg):
916+
af.fit(X)
917+
with pytest.warns(FutureWarning, match=msg):
918+
af.fit_predict(X)
919+
920+
af = AgglomerativeClustering(metric="euclidean", affinity="euclidean")
921+
msg = "Both `affinity` and `metric` attributes were set. Attribute"
922+
with pytest.raises(ValueError, match=msg):
923+
af.fit(X)
924+
with pytest.raises(ValueError, match=msg):
925+
af.fit_predict(X)

0 commit comments

Comments
 (0)
0