Allowing sparse inputs in MeanShift.predict · Issue #20733 · scikit-learn/scikit-learn · GitHub

Closed
milana2 opened this issue Aug 11, 2021 · 5 comments

Comments

@milana2 (Contributor) commented Aug 11, 2021

Describe the bug

According to the documentation, MeanShift.predict should accept a sparse X, but passing a sparse matrix raises a TypeError. Apologies if this is a non-issue or has already been fixed.

This is the same issue as in AffinityPropagation.predict (issue #20049) that was fixed by PR #20117.

Steps/Code to Reproduce

MeanShift.predict with a sparse X throws an exception:

>>> from sklearn.cluster import MeanShift
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
>>> clustering = MeanShift(bandwidth=2).fit(X)
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
>>> import scipy.sparse
>>> clustering.predict(scipy.sparse.csr_matrix([[0, 0], [5, 5]]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/cluster/_mean_shift.py", line 466, in predict
    X = self._validate_data(X, reset=False) #,accept_sparse=['csr', 'csc', 'coo']) #ANA
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/base.py", line 421, in _validate_data
    X = check_array(X, **check_params)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 593, in check_array
    array = _ensure_sparse_format(array, accept_sparse=accept_sparse,
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/utils/validation.py", line 360, in _ensure_sparse_format
    raise TypeError('A sparse matrix was passed, but dense '
TypeError: A sparse matrix was passed, but dense data is required. Use X.toarray() to convert to a dense numpy array.
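In the traceback, `_validate_data` hands X to `check_array`, whose `accept_sparse` parameter defaults to `False`, so any sparse matrix is rejected before prediction starts. A minimal sketch calling `check_array` directly (outside of `MeanShift`) reproduces the rejection and shows the opt-in. The exact error message varies between scikit-learn versions, so only the exception type is relied on here.

```python
import scipy.sparse as sp
from sklearn.utils import check_array

X_sparse = sp.csr_matrix([[0, 0], [5, 5]])

# Default accept_sparse=False: sparse input raises TypeError,
# the same failure that surfaces from MeanShift.predict above.
try:
    check_array(X_sparse)
    rejected = False
except TypeError:
    rejected = True

# Opting in with accept_sparse="csr" lets the CSR matrix through as CSR.
accepted = check_array(X_sparse, accept_sparse="csr")
print(rejected, accepted.format)  # True csr
```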

Expected Results

If we change line 466 of sklearn/cluster/_mean_shift.py to

X = self._validate_data(X, accept_sparse='csr', reset=False)

or

X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'], reset=False)

then it works:

>>> from sklearn.cluster import MeanShift
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
>>> clustering = MeanShift(bandwidth=2).fit(X)
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
>>> import scipy.sparse
>>> clustering.predict(scipy.sparse.csr_matrix([[0, 0], [5, 5]]))
array([1, 0])
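The two proposed changes differ only in how non-CSR sparse input is handled. With a single format string, `check_array` converts any other sparse format to that one; with a list, input already in a listed format is passed through and anything else is converted to the first entry. A small sketch with a COO matrix (again calling `check_array` directly, not `MeanShift`) shows the difference:

```python
import scipy.sparse as sp
from sklearn.utils import check_array

X_coo = sp.coo_matrix([[0, 0], [5, 5]])

# Single format: the COO input is converted to CSR.
as_csr = check_array(X_coo, accept_sparse="csr")

# List of formats: COO is already allowed, so it is kept as COO.
as_listed = check_array(X_coo, accept_sparse=["csr", "csc", "coo"])

print(as_csr.format)     # csr
print(as_listed.format)  # coo
```

Either option fixes the error; the single `'csr'` form hands downstream distance code a predictable format.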

Actual Results

The exception shown above.

Versions

System:
    python: 3.9.2 (v3.9.2:1a79785e3e, Feb 19 2021, 09:06:10)  [Clang 6.0 (clang-600.0.57)]
executable: /Library/Frameworks/Python.framework/Versions/3.9/bin/python3.9
   machine: macOS-10.11.6-x86_64-i386-64bit

Python dependencies:
          pip: 21.2.3
   setuptools: 49.2.1
      sklearn: 0.24.1
        numpy: 1.19.5
        scipy: 1.6.1
       Cython: None
       pandas: 1.2.4
   matplotlib: None
       joblib: 1.0.1
threadpoolctl: 2.1.0

Built with OpenMP: False
@thomasjpfan (Member)
I agree this should be updated. In 0.23.X, we did support sparse matrices:

def predict(self, X):
    """Predict the closest cluster each sample in X belongs to.

    Parameters
    ----------
    X : {array-like, sparse matrix}, shape=[n_samples, n_features]
        New data to predict.

    Returns
    -------
    labels : array, shape [n_samples,]
        Index of the cluster each sample belongs to.
    """
    check_is_fitted(self)
    return pairwise_distances_argmin(X, self.cluster_centers_)

where pairwise_distances_argmin checked for accept_sparse='csr'.
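Until a fix is released, the 0.23.X behavior can be approximated outside the library. The sketch below uses a hypothetical subclass, `SparsePredictMeanShift` (not part of scikit-learn), whose `predict` mirrors the 0.23.X body quoted above and leaves input validation to `pairwise_distances_argmin`. This is an assumption-laden workaround, not the upstream fix: it bypasses `_validate_data`, so the estimator-level checks recorded at fit time (such as `n_features_in_` and feature names) are not applied.

```python
import numpy as np
import scipy.sparse as sp
from sklearn.cluster import MeanShift
from sklearn.metrics import pairwise_distances_argmin
from sklearn.utils.validation import check_is_fitted


class SparsePredictMeanShift(MeanShift):
    """Hypothetical subclass restoring the 0.23.X-style predict."""

    def predict(self, X):
        check_is_fitted(self)
        # pairwise_distances_argmin accepts dense or CSR input and returns,
        # for each row of X, the index of the nearest cluster center.
        return pairwise_distances_argmin(X, self.cluster_centers_)


X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
clustering = SparsePredictMeanShift(bandwidth=2).fit(X)

X_new = [[0, 0], [5, 5]]
dense_pred = clustering.predict(np.asarray(X_new))
sparse_pred = clustering.predict(sp.csr_matrix(X_new))
```

Dense and sparse inputs should then yield the same labels: [0, 0] lands in the cluster of [1, 1] and [5, 5] in the cluster of [4, 7].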

@milana2 Would you be interested in opening a PR with your suggested fix?

@milana2 (Contributor, Author) commented Aug 16, 2021

Yes, I'll go ahead and open a PR. Thanks!

@kurchi1205
I want to make my first contribution. Can I work on this issue?

@kurchi1205

@milana2 Is this only for the predict function of MeanShift?

@glemaitre (Member)

This was fixed in the associated issue.
