8000 replace check_arrays implementation · scikit-learn/scikit-learn@5e6a868 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5e6a868

Browse files
committed
replace check_arrays implementation
1 parent 0a604c6 commit 5e6a868

File tree

1 file changed

+20
-53
lines changed

1 file changed

+20
-53
lines changed

sklearn/utils/validation.py

Lines changed: 20 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _num_samples(x):
135135

136136

137137
def check_consistent_length(*arrays):
138-
n_samples = [_num_samples(X) for X in arrays]
138+
n_samples = [_num_samples(X) for X in arrays if X is not None]
139139
uniques = np.unique(n_samples)
140140
if len(uniques) > 1:
141141
raise ValueError("Found arrays with inconsistent numbers of samples: %s"
@@ -192,7 +192,8 @@ def _ensure_sparse_format(array, allowed_sparse, convert_sparse_to, dtype,
192192

193193

194194
def check_array(array, allowed_sparse=None, dtype=None, order=None, copy=False,
195-
force_all_finite=True, convert_sparse_to=None, make_2d=True):
195+
force_all_finite=True, convert_sparse_to=None, make_2d=True,
196+
allow_nd=False):
196197
"""Check everything about an array"""
197198
if isinstance(allowed_sparse, str):
198199
allowed_sparse = [allowed_sparse]
@@ -207,6 +208,9 @@ def check_array(array, allowed_sparse=None, dtype=None, order=None, copy=False,
207208
if make_2d:
208209
array = np.atleast_2d(array)
209210
array = np.array(array, dtype=dtype, order=order, copy=copy)
211+
if not allow_nd and array.ndim >= 3:
212+
raise ValueError("Found array with dim %d. Expected <= 2" %
213+
array.ndim)
210214
if force_all_finite:
211215
_assert_all_finite(array)
212216

@@ -275,60 +279,23 @@ def check_arrays(*arrays, **options):
275279

276280
if len(arrays) == 0:
277281
return None
278-
279-
n_samples = _num_samples(arrays[0])
282+
check_consistent_length(*arrays)
283+
284+
order = 'C' if check_ccontiguous else None
285+
force_finite = not allow_nans
286+
if sparse_format == 'dense':
287+
allow_sparse = None
288+
elif sparse_format is None:
289+
allow_sparse = ['csr', 'csc']
290+
else:
291+
allow_sparse = sparse_format
280292

281293
checked_arrays = []
282294
for array in arrays:
283-
array_orig = array
284-
if array is None:
285-
# special case: ignore optional y=None kwarg pattern
286-
checked_arrays.append(array)
287-
continue
288-
size = _num_samples(array)
289-
290-
if size != n_samples:
291-
raise ValueError("Found array with dim %d. Expected %d"
292-
% (size, n_samples))
293-
294-
if (force_arrays or hasattr(array, "__array__")
295-
or hasattr(array, "shape")):
296-
if sp.issparse(array):
297-
if sparse_format == 'csr':
298-
array = < 10000 span class=pl-s1>array.tocsr()
299-
elif sparse_format == 'csc':
300-
array = array.tocsc()
301-
elif sparse_format == 'dense':
302-
raise TypeError('A sparse matrix was passed, but dense '
303-
'data is required. Use X.toarray() to '
304-
'convert to a dense numpy array.')
305-
if check_ccontiguous:
306-
array.data = np.ascontiguousarray(array.data, dtype=dtype)
307-
elif hasattr(array, 'data'):
308-
array.data = np.asarray(array.data, dtype=dtype)
309-
elif array.dtype != dtype:
310-
# Cast on the required dtype
311-
array = array.astype(dtype)
312-
if not allow_nans:
313-
if hasattr(array, 'data'):
314-
_assert_all_finite(array.data)
315-
else:
316-
# DOK sparse matrices
317-
_assert_all_finite(array.values())
318-
else:
319-
if check_ccontiguous:
320-
array = np.ascontiguousarray(array, dtype=dtype)
321-
elif dtype is not None or force_arrays:
322-
array = np.asarray(array, dtype=dtype)
323-
if not allow_nans:
324-
_assert_all_finite(array)
325-
326-
if force_arrays and not allow_nd and array.ndim >= 3:
327-
raise ValueError("Found array with dim %d. Expected <= 2" %
328-
array.ndim)
329-
330-
if copy and array is array_orig:
331-
array = array.copy()
295+
if (force_arrays or hasattr(array, "__array__") or hasattr(array, "shape")):
296+
array = check_array(array, allow_sparse, dtype, order, copy=copy,
297+
make_2d=False, allow_nd=allow_nd,
298+
force_all_finite=force_finite)
332299
checked_arrays.append(array)
333300

334301
return checked_arrays

0 commit comments

Comments
 (0)
0