8000 FIX check_array's accept_sparse param now takes true/false/str/list, … · paulha/scikit-learn@e4aa310 · GitHub
[go: up one dir, main page]

Skip to content

Commit e4aa310

Browse files
jkarnopaulha
authored andcommitted
FIX check_array's accept_sparse param now takes true/false/str/list, but not None (scikit-learn#7937)
1 parent 41e388a commit e4aa310

File tree

3 files changed

+99
-28
lines changed

3 files changed

+99
-28
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ Enhancements
9898
- Added ability to set ``n_jobs`` parameter to :func:`pipeline.make_union`.
9999
A ``TypeError`` will be raised for any other kwargs. :issue:`8028`
100100
by :user:`Alexander Booth <alexandercbooth>`.
101+
102+
- Added type checking to the ``accept_sparse`` parameter in
103+
:mod:`sklearn.utils.validation` methods. This parameter now accepts only
104+
boolean, string, or list/tuple of strings. ``accept_sparse=None`` is deprecated
105+
and should be replaced by ``accept_sparse=False``.
106+
:issue:`7880` by :user:`Josh Karnofsky <jkarno>`.
101107

102108
- :class:`model_selection.GridSearchCV`, :class:`model_selection.RandomizedSearchCV`
103109
and :func:`model_selection.cross_val_score` now allow estimators with callable

sklearn/utils/tests/test_validation.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,48 @@ def test_check_array_dtype_warning():
321321
assert_equal(X_checked.format, 'csr')
322322

323323

324+
def test_check_array_accept_sparse_type_exception():
325+
X = [[1, 2], [3, 4]]
326+
X_csr = sp.csr_matrix(X)
327+
invalid_type = SVR()
328+
329+
msg = ("A sparse matrix was passed, but dense data is required. "
330+
"Use X.toarray() to convert to a dense numpy array.")
331+
assert_raise_message(TypeError, msg,
332+
check_array, X_csr, accept_sparse=False)
333+
assert_raise_message(TypeError, msg,
334+
check_array, X_csr, accept_sparse=None)
335+
336+
msg = ("Parameter 'accept_sparse' should be a string, "
337+
"boolean or list of strings. You provided 'accept_sparse={}'.")
338+
assert_raise_message(ValueError, msg.format(invalid_type),
339+
check_array, X_csr, accept_sparse=invalid_type)
340+
341+
msg = ("When providing 'accept_sparse' as a tuple or list, "
342+
"it must contain at least one string value.")
343+
assert_raise_message(ValueError, msg.format([]),
344+
check_array, X_csr, accept_sparse=[])
345+
assert_raise_message(ValueError, msg.format(()),
346+
check_array, X_csr, accept_sparse=())
347+
348+
msg = "'SVR' object"
349+
assert_raise_message(TypeError, msg,
350+
check_array, X_csr, accept_sparse=[invalid_type])
351+
352+
# Test deprecation of 'None'
353+
assert_warns(DeprecationWarning, check_array, X, accept_sparse=None)
354+
355+
356+
def test_check_array_accept_sparse_no_exception():
357+
X = [[1, 2], [3, 4]]
358+
X_csr = sp.csr_matrix(X)
359+
360+
check_array(X_csr, accept_sparse=True)
361+
check_array(X_csr, accept_sparse='csr')
362+
check_array(X_csr, accept_sparse=['csr'])
363+
check_array(X_csr, accept_sparse=('csr',))
364+
365+
324366
def test_check_array_min_samples_and_features_messages():
325367
# empty list is considered 2D by default:
326368
msg = "0 feature(s) (shape=(1, 0)) while a minimum of 1 is required."

sklearn/utils/validation.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -200,40 +200,55 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
200200
spmatrix : scipy sparse matrix
201201
Input to validate and convert.
202202
203-
accept_sparse : string, list of string or None (default=None)
203+
accept_sparse : string, boolean or list/tuple of strings
204204
String[s] representing allowed sparse matrix formats ('csc',
205-
'csr', 'coo', 'dok', 'bsr', 'lil', 'dia'). None means that sparse
206-
matrix input will raise an error. If the input is sparse but not in
207-
the allowed format, it will be converted to the first listed format.
205+
'csr', 'coo', 'dok', 'bsr', 'lil', 'dia'). If the input is sparse but
206+
not in the allowed format, it will be converted to the first listed
207+
format. True allows the input to be any format. False means
208+
that a sparse matrix input will raise an error.
208209
209-
dtype : string, type or None (default=none)
210+
dtype : string, type or None
210211
Data type of result. If None, the dtype of the input is preserved.
211212
212-
copy : boolean (default=False)
213+
copy : boolean
213214
Whether a forced copy will be triggered. If copy=False, a copy might
214215
be triggered by a conversion.
215216
216-
force_all_finite : boolean (default=True)
217+
force_all_finite : boolean
217218
Whether to raise an error on np.inf and np.nan in X.
218219
219220
Returns
220221
-------
221222
spmatrix_converted : scipy sparse matrix.
222223
Matrix that is ensured to have an allowed type.
223224
"""
224-
if accept_sparse in [None, False]:
225-
raise TypeError('A sparse matrix was passed, but dense '
226-
'data is required. Use X.toarray() to '
227-
'convert to a dense numpy array.')
228225
if dtype is None:
229226
dtype = spmatrix.dtype
230227

231228
changed_format = False
232-
if (isinstance(accept_sparse, (list, tuple))
233-
and spmatrix.format not in accept_sparse):
234-
# create new with correct sparse
235-
spmatrix = spmatrix.asformat(accept_sparse[0])
236-
changed_format = True
229+
230+
if isinstance(accept_sparse, six.string_types):
231+
accept_sparse = [accept_sparse]
232+
233+
if accept_sparse is False:
234+
raise TypeError('A sparse matrix was passed, but dense '
235+
'data is required. Use X.toarray() to '
236+
'convert to a dense numpy array.')
237+
elif isinstance(accept_sparse, (list, tuple)):
238+
if len(accept_sparse) == 0:
239+
raise ValueError("When providing 'accept_sparse' "
240+
"as a tuple or list, it must contain at "
241+
"least one string value.")
242+
# ensure correct sparse format
243+
if spmatrix.format not in accept_sparse:
244+
# create new with correct sparse
245+
spmatrix = spmatrix.asformat(accept_sparse[0])
246+
changed_format = True
247+
elif accept_sparse is not True:
248+
# any other type
249+
raise ValueError("Parameter 'accept_sparse' should be a string, "
250+
"boolean or list of strings. You provided "
251+
"'accept_sparse={}'.".format(accept_sparse))
237252

238253
if dtype != spmatrix.dtype:
239254
# convert dtype
@@ -251,7 +266,7 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
251266
return spmatrix
252267

253268

254-
def check_array(array, accept_sparse=None, dtype="numeric", order=None,
269+
def check_array(array, accept_sparse=False, dtype="numeric", order=None,
255270
copy=False, force_all_finite=True, ensure_2d=True,
256271
allow_nd=False, ensure_min_samples=1, ensure_min_features=1,
257272
warn_on_dtype=False, estimator=None):
@@ -266,11 +281,12 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
266281
array : object
267282
Input object to check / convert.
268283
269-
accept_sparse : string, list of string or None (default=None)
284+
accept_sparse : string, boolean or list/tuple of strings (default=False)
270285
String[s] representing allowed sparse matrix formats, such as 'csc',
271-
'csr', etc. None means that sparse matrix input will raise an error.
272-
If the input is sparse but not in the allowed format, it will be
273-
converted to the first listed format.
286+
'csr', etc. If the input is sparse but not in the allowed format,
287+
it will be converted to the first listed format. True allows the input
288+
to be any format. False means that a sparse matrix input will
289+
raise an error.
274290
275291
dtype : string, type, list of types or None (default="numeric")
276292
Data type of result. If None, the dtype of the input is preserved.
@@ -321,8 +337,14 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
321337
X_converted : object
322338
The converted and validated X.
323339
"""
324-
if isinstance(accept_sparse, str):
325-
accept_sparse = [accept_sparse]
340+
# accept_sparse 'None' deprecation check
341+
if accept_sparse is None:
342+
warnings.warn(
343+
"Passing 'None' to parameter 'accept_sparse' in methods "
344+
"check_array and check_X_y is deprecated in version 0.19 "
345+
"and will be removed in 0.21. Use 'accept_sparse=False' "
346+
" instead.", DeprecationWarning)
347+
accept_sparse = False
326348

327349
# store whether originally we wanted numeric dtype
328350
dtype_numeric = dtype == "numeric"
@@ -406,7 +428,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
406428
return array
407429

408430

409-
def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
431+
def check_X_y(X, y, accept_sparse=False, dtype="numeric", order=None,
410432
copy=False, force_all_finite=True, ensure_2d=True,
411433
allow_nd=False, multi_output=False, ensure_min_samples=1,
412434
ensure_min_features=1, y_numeric=False,
@@ -427,11 +449,12 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
427449
y : nd-array, list or sparse matrix
428450
Labels.
429451
430-
accept_sparse : string, list of string or None (default=None)
452+
accept_sparse : string, boolean or list of string (default=False)
431453
String[s] representing allowed sparse matrix formats, such as 'csc',
432-
'csr', etc. None means that sparse matrix input will raise an error.
433-
If the input is sparse but not in the allowed format, it will be
434-
converted to the first listed format.
454+
'csr', etc. If the input is sparse but not in the allowed format,
455+
it will be converted to the first listed format. True allows the input
456+
to be any format. False means that a sparse matrix input will
457+
raise an error.
435458
436459
dtype : string, type, list of types or None (default="numeric")
437460
Data type of result. If None, the dtype of the input is preserved.

0 commit comments

Comments
 (0)
0