8000 ENH Allowing sparse inputs for prediction in AffinityPropagation (#20… · scikit-learn/scikit-learn@aa86c83 · GitHub
[go: up one dir, main page]

Skip to content

Commit aa86c83

Browse files
authored
ENH Allowing sparse inputs for prediction in AffinityPropagation (#20117)
1 parent 67f6a5c commit aa86c83

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

doc/whats_new/v1.0.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ Changelog
150150
- |Efficiency| :class:`cluster.MiniBatchKMeans` is now faster in multicore
151151
settings. :pr:`17622` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
152152

153+
- |Enhancement| The `predict` and `fit_predict` methods of
154+
:class:`cluster.AffinityPropagation` now accept sparse data type for input
155+
data.
156+
:pr:`20117` by :user:`Venkatachalam Natchiappan <venkyyuvy>`
157+
153158
- |Fix| Fixed a bug in :class:`cluster.MiniBatchKMeans` where the sample
154159
weights were partially ignored when the input is sparse. :pr:`17622` by
155160
:user:`Jérémie du Boisberranger <jeremiedbb>`.

sklearn/cluster/_affinity_propagation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def predict(self, X):
436436
Cluster labels.
437437
"""
438438
check_is_fitted(self)
439-
X = self._validate_data(X, reset=False)
439+
X = self._validate_data(X, reset=False, accept_sparse='csr')
440440
if not hasattr(self, "cluster_centers_"):
441441
raise ValueError("Predict method is not supported when "
442442
"affinity='precomputed'.")

sklearn/cluster/tests/test_affinity_propagation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,25 @@ def test_affinity_propagation_float32():
238238
assert_array_equal(afp.labels_, expected)
239239

240240

241+
def test_sparse_input_for_predict():
242+
# Test to make sure sparse inputs are accepted for predict
243+
# (non-regression test for issue #20049)
244+
af = AffinityPropagation(affinity="euclidean", random_state=42)
245+
af.fit(X)
246+
labels = af.predict(csr_matrix((2, 2)))
247+
assert_array_equal(labels, (2, 2))
248+
249+
250+
def test_sparse_input_for_fit_predict():
251+
# Test to make sure sparse inputs are accepted for fit_predict
252+
# (non-regression test for issue #20049)
253+
af = AffinityPropagation(affinity="euclidean", random_state=42)
254+
rng = np.random.RandomState(42)
255+
X = csr_matrix(rng.randint(0, 2, size=(5, 5)))
256+
labels = af.fit_predict(X)
257+
assert_array_equal(labels, (0, 1, 1, 2, 3))
258+
259+
241260
# TODO: Remove in 1.1
242261
def test_affinity_propagation_pairwise_is_deprecated():
243262
afp = AffinityPropagation(affinity='precomputed')

0 commit comments

Comments
 (0)
0