8000 TST check_array tests: Use itertools.product, add bool and object dty… · scikit-learn/scikit-learn@95afdc7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 95afdc7

Browse files
committed
TST check_array tests: Use itertools.product, add bool and object dtypes.
1 parent 1a01ef1 commit 95afdc7

File tree

1 file changed

+49
-51
lines changed

1 file changed

+49
-51
lines changed

sklearn/utils/tests/test_validation.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import assert_array_equal
66
import scipy.sparse as sp
77
from nose.tools import assert_raises, assert_true, assert_false, assert_equal
8+
from itertools import product
89

910
from sklearn.utils import (array2d, as_float_array, atleast2d_or_csr,
1011
atleast2d_or_csc, check_arrays, safe_asarray,
@@ -259,64 +260,61 @@ def test_check_array():
259260
X_F = X_C.copy("F")
260261
X_int = X_C.astype(np.int)
261262
X_float = X_C.astype(np.float)
262-
263 8000 -
for X in [X_C, X_F, X_int, X_float]:
264-
for dtype in [np.int32, np.int, np.float, np.float32, None]:
265-
for order in ['C', 'F', None]:
266-
for copy in [True, False]:
267-
X_checked = check_array(X, dtype=dtype, order=order,
268-
copy=copy)
269-
if dtype is not None:
270-
assert_equal(X_checked.dtype, dtype)
271-
else:
272-
assert_equal(X_checked.dtype, X.dtype)
273-
if order == 'C':
274-
assert_true(X_checked.flags['C_CONTIGUOUS'])
275-
assert_false(X_checked.flags['F_CONTIGUOUS'])
276-
elif order == 'F':
277-
assert_true(X_checked.flags['F_CONTIGUOUS'])
278-
assert_false(X_checked.flags['C_CONTIGUOUS'])
279-
if copy:
280-
assert_false(X is X_checked)
281-
else:
282-
# doesn't copy if it was already good
283-
if (X.dtype == X_checked.dtype and
284-
X_checked.flags['C_CONTIGUOUS'] ==
285-
X.flags['C_CONTIGUOUS'] and
286-
X_checked.flags['F_CONTIGUOUS'] ==
287-
X.flags['F_CONTIGUOUS']):
288-
assert_true(X is X_checked)
263+
Xs = [X_C, X_F, X_int, X_float]
264+
dtypes = [np.int32, np.int, np.float, np.float32, None, np.bool, object]
265+
orders = ['C', 'F', None]
266+
copys = [True, False]
267+
268+
for X, dtype, order, copy in product(Xs, dtypes, orders, copys):
269+
X_checked = check_array(X, dtype=dtype, order=order, copy=copy)
270+
if dtype is not None:
271+
assert_equal(X_checked.dtype, dtype)
272+
else:
273+
assert_equal(X_checked.dtype, X.dtype)
274+
if order == 'C':
275+
assert_true(X_checked.flags['C_CONTIGUOUS'])
276+
assert_false(X_checked.flags['F_CONTIGUOUS'])
277+
elif order == 'F':
278+
assert_true(X_checked.flags['F_CONTIGUOUS'])
279+
assert_false(X_checked.flags['C_CONTIGUOUS'])
280+
if copy:
281+
assert_false(X is X_checked)
282+
else:
283+
# doesn't copy if it was already good
284+
if (X.dtype == X_checked.dtype and
285+
X_checked.flags['C_CONTIGUOUS'] == X.flags['C_CONTIGUOUS']
286+
and X_checked.flags['F_CONTIGUOUS'] == X.flags['F_CONTIGUOUS']):
287+
assert_true(X is X_checked)
289288

290289
# allowed sparse != None
291290
X_csc = sp.csc_matrix(X_C)
292291
X_coo = X_csc.tocoo()
293292
X_dok = X_csc.todok()
294293
X_int = X_csc.astype(np.int)
295294
X_float = X_csc.astype(np.float)
296-
for X in [X_csc, X_coo, X_dok, X_int, X_float]:
297-
for dtype in [np.int32, np.int, np.float, np.float32, None]:
298-
for allowed_sparse in [['csr', 'coo'], ['coo', 'dok']]:
299-
for copy in [True, False]:
300-
X_checked = check_array(X, dtype=dtype,
301-
allowed_sparse=allowed_sparse,
302-
copy=copy)
303-
if dtype is not None:
304-
assert_equal(X_checked.dtype, dtype)
305-
else:
306-
assert_equal(X_checked.dtype, X.dtype)
307-
if X.format in allowed_sparse:
308-
# no change if allowed
309-
assert_equal(X.format, X_checked.format)
310-
else:
311-
# got converted
312-
assert_equal(X_checked.format, allowed_sparse[0])
313-
if copy:
314-
assert_false(X is X_checked)
315-
else:
316-
# doesn't copy if it was already good
317-
if (X.dtype == X_checked.dtype and
318-
X.format == X_checked.format):
319-
assert_true(X is X_checked)
295+
296+
Xs = [X_csc, X_coo, X_dok, X_int, X_float]
297+
allowed_sparses = [['csr', 'coo'], ['coo', 'dok']]
298+
for X, dtype, allowed_sparse, copy in product(Xs, dtypes, allowed_sparses,
299+
copys):
300+
X_checked = check_array(X, dtype=dtype, allowed_sparse=allowed_sparse,
301+
copy=copy)
302+
if dtype is not None:
303+
assert_equal(X_checked.dtype, dtype)
304+
else:
305+
assert_equal(X_checked.dtype, X.dtype)
306+
if X.format in allowed_sparse:
307+
# no change if allowed
308+
assert_equal(X.format, X_checked.format)
309+
else:
310+
# got converted
311+
assert_equal(X_checked.format, allowed_sparse[0])
312+
if copy:
313+
assert_false(X is X_checked)
314+
else:
315+
# doesn't copy if it was already good
316+
if (X.dtype == X_checked.dtype and X.format == X_checked.format):
317+
assert_true(X is X_checked)
320318

321319
# other input formats
322320
# convert lists to arrays

0 commit comments

Comments
 (0)
0