8000 fix check_array default · scikit-learn/scikit-learn@c3ecfed · GitHub
[go: up one dir, main page]

Skip to content

Commit c3ecfed

Browse files
committed
fix check_array default
1 parent 5e6a868 commit c3ecfed

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

sklearn/utils/validation.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,35 +165,35 @@ def _sparse_matrix_constructor(string_format):
165165
raise ValueError("Don't know how to construct a sparse matrix of type %s" % string_format)
166166

167167

168-
def _ensure_sparse_format(array, allowed_sparse, convert_sparse_to, dtype,
168+
def _ensure_sparse_format(spmatrix, allowed_sparse, convert_sparse_to, dtype,
169169
order, copy, force_all_finite):
170170
if allowed_sparse is None:
171171
raise TypeError('A sparse matrix was passed, but dense '
172172
'data is required. Use X.toarray() to '
173173
'convert to a dense numpy array.')
174-
sparse_type = _get_sparse_type_string(array)
174+
sparse_type = _get_sparse_type_string(spmatrix)
175175
if sparse_type in allowed_sparse:
176176
# correct type
177-
if dtype == array.dtype or dtype is None:
177+
if dtype == spmatrix.dtype or dtype is None:
178178
# correct dtype
179179
if copy:
180-
array = array.copy()
180+
spmatrix = spmatrix.copy()
181181
else:
182182
# convert dtype
183-
array = array.astype(dtype)
183+
spmatrix = spmatrix.astype(dtype)
184184
else:
185185
# create new
186-
array = _sparse_matrix_constructor(convert_sparse_to)(array, copy=copy,
187-
dtype=dtype)
186+
spmatrix = _sparse_matrix_constructor(convert_sparse_to)(
187+
spmatrix, copy=copy, dtype=dtype)
188188
if force_all_finite:
189-
_assert_all_finite(array.data)
190-
array.data = np.array(array.data, copy=False, order=order)
191-
return array
189+
_assert_all_finite(spmatrix.data)
190+
spmatrix.data = np.array(spmatrix.data, copy=False, order=order)
191+
return spmatrix
192192

193193

194194
def check_array(array, allowed_sparse=None, dtype=None, order=None, copy=False,
195195
force_all_finite=True, convert_sparse_to=None, make_2d=True,
196-
allow_nd=False):
196+
allow_nd=True):
197197
"""Check everything about an array"""
198198
if isinstance(allowed_sparse, str):
199199
allowed_sparse = [allowed_sparse]
@@ -292,7 +292,11 @@ def check_arrays(*arrays, **options):
292292

293293
checked_arrays = []
294294
for array in arrays:
295-
if (force_arrays or hasattr(array, "__array__") or hasattr(array, "shape")):
295+
if array is None:
296+
checked_arrays.append(array)
297+
continue
298+
299+
if force_arrays or sp.issparse(array):
296300
array = check_array(array, allow_sparse, dtype, order, copy=copy,
297301
make_2d=False, allow_nd=allow_nd,
298302
force_all_finite=force_finite)

0 commit comments

Comments
 (0)
0