8000 ENH make check_array accept several dtypes · scikit-learn/scikit-learn@0222c8b · GitHub
[go: up one dir, main page]

Skip to content

Commit 0222c8b

Browse files
committed
ENH make check_array accept several dtypes
1 parent: 417011d · commit: 0222c8b

File tree

4 files changed

+80
-47
lines changed

4 files changed

+80
-47
lines changed

sklearn/preprocessing/data.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,12 @@
1515
from ..base import BaseEstimator, TransformerMixin
1616
from ..externals import six
1717
from ..utils import check_array
18-
from ..utils import warn_if_not_float
1918
from ..utils.extmath import row_norms
20-
from ..utils.fixes import (astype,
21-
combinations_with_replacement as combinations_w_r,
22-
bincount, isclose)
19+
from ..utils.fixes import combinations_with_replacement as combinations_w_r
2320
from ..utils.sparsefuncs_fast import (inplace_csr_row_normalize_l1,
2421
inplace_csr_row_normalize_l2)
2522
from ..utils.sparsefuncs import (inplace_column_scale, mean_variance_axis)
26-
from ..utils.validation import check_is_fitted
23+
from ..utils.validation import check_is_fitted, FLOAT_DTYPES
2724

2825
zip = six.moves.zip
2926
map = six.moves.map
@@ -114,8 +111,9 @@ def scale(X, axis=0, with_mean=True, with_std=True, copy=True):
114111
scaling using the ``Transformer`` API (e.g. as part of a preprocessing
115112
:class:`sklearn.pipeline.Pipeline`)
116113
"""
117-
X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False)
118-
warn_if_not_float(X, estimator='The scale function')
114+
X = check_array(X, accept_sparse='csr', copy=copy, ensure_2d=False,
115+
warn_on_dtype=True, estimator='the scale function',
116+
dtype=FLOAT_DTYPES)
119117
if sparse.issparse(X):
120118
if with_mean:
121119
raise ValueError(
@@ -223,8 +221,8 @@ def fit(self, X, y=None):
223221
The data used to compute the per-feature minimum and maximum
224222
used for later scaling along the features axis.
225223
"""
226-
X = check_array(X, copy=self.copy, ensure_2d=False)
227-
warn_if_not_float(X, estimator=self)
224+
X = check_array(X, copy=self.copy, ensure_2d=False, warn_on_dtype=True,
225+
estimator=self, dtype=FLOAT_DTYPES)
228226
feature_range = self.feature_range
229227
if feature_range[0] >= feature_range[1]:
230228
raise ValueError("Minimum of desired feature range must be smaller"
@@ -345,10 +343,8 @@ def fit(self, X, y=None):
345343
used for later scaling along the features axis.
346344
"""
347345
X = check_array(X, accept_sparse='csr', copy=self.copy,
348-
ensure_2d=False, dtype=None)
349-
if warn_if_not_float(X, estimator=self):
350-
X = check_array(X, accept_sparse=True, copy=False,
351-
dtype=np.float)
346+
ensure_2d=False, warn_on_dtype=True,
347+
estimator=self, dtype=FLOAT_DTYPES)
352348
if sparse.issparse(X):
353349
if self.with_mean:
354350
raise ValueError(
@@ -380,10 +376,9 @@ def transform(self, X, y=None, copy=None):
380376

381377
copy = copy if copy is not None else self.copy
382378
X = check_array(X, accept_sparse='csr', copy=copy,
383-
ensure_2d=False, dtype=None)
384-
if warn_if_not_float(X, estimator=self):
385-
X = check_array(X, accept_sparse=True, copy=False,
386-
dtype=np.float)
379+
ensure_2d=False, warn_on_dtype=True,
380+
estimator=self, dtype=FLOAT_DTYPES)
381+
387382
if sparse.issparse(X):
388383
if self.with_mean:
389384
raise ValueError(
@@ -602,8 +597,8 @@ def normalize(X, norm='l2', axis=1, copy=True):
602597
else:
603598
raise ValueError("'%d' is not a supported axis" % axis)
604599

605-
X = check_array(X, sparse_format, copy=copy)
606-
warn_if_not_float(X, 'The normalize function')
600+
X = check_array(X, sparse_format, copy=copy, warn_on_dtype=True,
601+
estimator='the normalize function', dtype=FLOAT_DTYPES)
607602
if axis == 0:
608603
X = X.T
609604

sklearn/preprocessing/tests/test_data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.preprocessing.data import MinMaxScaler
3030
from sklearn.preprocessing.data import add_dummy_feature
3131
from sklearn.preprocessing.data import PolynomialFeatures
32+
from sklearn.utils.validation import DataConversionWarning
3233

3334
from sklearn import datasets
3435

@@ -499,12 +500,12 @@ def test_warning_scaling_integers():
499500
X = np.array([[1, 2, 0],
500501
[0, 0, 0]], dtype=np.uint8)
501502

502-
w = "assumes floating point values as input, got uint8"
503+
w = "Data with input dtype uint8 was converted to float64"
503504

504505
clean_warning_registry()
505-
assert_warns_message(UserWarning, w, scale, X)
506-
assert_warns_message(UserWarning, w, StandardScaler().fit, X)
507-
assert_warns_message(UserWarning, w, MinMaxScaler().fit, X)
506+
assert_warns_message(DataConversionWarning, w, scale, X)
507+
assert_warns_message(DataConversionWarning, w, StandardScaler().fit, X)
508+
assert_warns_message(DataConversionWarning, w, MinMaxScaler().fit, X)
508509

509510

510511
def test_normalizer_l1():

sklearn/utils/tests/test_validation.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,16 @@ def test_check_array_dtype_stability():
241241
def test_check_array_dtype_warning():
242242
X_int_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
243243
X_float64 = np.asarray(X_int_list, dtype=np.float64)
244+
X_float32 = np.asarray(X_int_list, dtype=np.float32)
244245
X_int64 = np.asarray(X_int_list, dtype=np.int64)
245246
X_csr_float64 = sp.csr_matrix(X_float64)
247+
X_csr_float32 = sp.csr_matrix(X_float32)
248+
X_csc_float32 = sp.csc_matrix(X_float32)
246249
X_csc_int32 = sp.csc_matrix(X_int64, dtype=np.int32)
247250
y = [0, 0, 1]
248251
integer_data = [X_int64, X_csc_int32]
249252
float64_data = [X_float64, X_csr_float64]
253+
float32_data = [X_float32, X_csr_float32, X_csc_float32]
250254
for X in integer_data:
251255
X_checked = assert_no_warnings(check_array, X, dtype=np.float64,
252256
accept_sparse=True)
@@ -260,19 +264,18 @@ def test_check_array_dtype_warning():
260264
# Check that the warning message includes the name of the Estimator
261265
X_checked = assert_warns_message(DataConversionWarning,
262266
'SomeEstimator',
263-
check_array, X, dtype=np.float64,
267+
check_array, X,
268+
dtype=[np.float64, np.float32],
264269
accept_sparse=True,
265270
warn_on_dtype=True,
266271
estimator='SomeEstimator')
267272
assert_equal(X_checked.dtype, np.float64)
268273

269-
X_checked, y_checked = assert_warns_message(DataConversionWarning,
270-
'SomeEstimator',
271-
check_X_y, X, y,
272-
dtype=np.float64,
273-
accept_sparse=True,
274-
warn_on_dtype=True,
275-
estimator='SomeEstimator')
274+
X_checked, y_checked = assert_warns_message(
275+
DataConversionWarning, 'KNeighborsClassifier',
276+
check_X_y, X, y, dtype=np.float64, accept_sparse=True,
277+
warn_on_dtype=True, estimator=KNeighborsClassifier())
278+
276279
assert_equal(X_checked.dtype, np.float64)
277280

278281
for X in float64_data:
@@ -283,7 +286,27 @@ def test_check_array_dtype_warning():
283286
accept_sparse=True, warn_on_dtype=False)
284287
assert_equal(X_checked.dtype, np.float64)
285288

286-
289+
for X in float32_data:
290+
X_checked = assert_no_warnings(check_array, X,
291+
dtype=[np.float64, np.float32],
292+
accept_sparse=True)
293+
assert_equal(X_checked.dtype, np.float32)
294+
assert_true(X_checked is X)
295+
296+
X_checked = assert_no_warnings(check_array, X,
297+
dtype=[np.float64, np.float32],
298+
accept_sparse=['csr', 'dok'],
299+
copy=True)
300+
assert_equal(X_checked.dtype, np.float32)
301+
assert_false(X_checked is X)
302+
303+
X_checked = assert_no_warnings(check_array, X_csc_float32,
304+
dtype=[np.float64, np.float32],
305+
accept_sparse=['csr', 'dok'],
306+
copy=False)
307+
assert_equal(X_checked.dtype, np.float32)
308+
assert_false(X_checked is X_csc_float32)
309+
assert_equal(X_checked.format, 'csr')
287310

288311

289312
def test_check_array_min_samples_and_features_messages():

sklearn/utils/validation.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from .fixes import astype
1717
from inspect import getargspec
1818

19+
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
20+
1921

2022
class DataConversionWarning(UserWarning):
2123
"""A warning on implicit data conversions happening in the code"""
@@ -233,26 +235,27 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
233235
spmatrix_converted : scipy sparse matrix.
234236
Matrix that is ensured to have an allowed type.
235237
"""
236-
if accept_sparse is None:
238+
if accept_sparse in [None, False]:
237239
raise TypeError('A sparse matrix was passed, but dense '
238240
'data is required. Use X.toarray() to '
239241
'convert to a dense numpy array.')
240-
sparse_type = spmatrix.format
241242
if dtype is None:
242243
dtype = spmatrix.dtype
243-
if sparse_type in accept_sparse:
244-
# correct type
245-
if dtype == spmatrix.dtype:
246-
# correct dtype
247-
if copy:
248-
spmatrix = spmatrix.copy()
249-
else:
250-
# convert dtype
251-
spmatrix = spmatrix.astype(dtype)
252-
else:
253-
# create new
244+
245+
changed_format = False
246+
if (isinstance(accept_sparse, (list, tuple))
247+
and spmatrix.format not in accept_sparse):
248+
# create new with correct sparse
254249
spmatrix = spmatrix.asformat(accept_sparse[0])
250+
changed_format = True
251+
252+
if dtype != spmatrix.dtype:
253+
# convert dtype
255254
spmatrix = spmatrix.astype(dtype)
255+
elif copy and not changed_format:
256+
# force copy
257+
spmatrix = spmatrix.copy()
258+
256259
if force_all_finite:
257260
if not hasattr(spmatrix, "data"):
258261
warnings.warn("Can't check %s sparse matrix for nan or inf."
@@ -283,9 +286,11 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
283286
If the input is sparse but not in the allowed format, it will be
284287
converted to the first listed format.
285288
286-
dtype : string, type or None (default="numeric")
289+
dtype : string, type, list of types or None (default="numeric")
287290
Data type of result. If None, the dtype of the input is preserved.
288291
If "numeric", dtype is preserved unless array.dtype is object.
292+
If dtype is a list of types, conversion on the first type is only
293+
performed if the dtype of the input is not in the list.
289294
290295
order : 'F', 'C' or None (default=None)
291296
Whether an array will be forced to be fortran or c-style.
@@ -344,6 +349,15 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
344349
else:
345350
dtype = None
346351

352+
if isinstance(dtype, (list, tuple)):
353+
if dtype_orig is not None and dtype_orig in dtype:
354+
# no dtype conversion required
355+
dtype = None
356+
else:
357+
# dtype conversion required. Let's select the first element of the
358+
# list of accepted types.
359+
dtype = dtype[0]
360+
347361
if sp.issparse(array):
348362
array = _ensure_sparse_format(array, accept_sparse, dtype, copy,
349363
force_all_finite)
@@ -382,7 +396,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
382396
if estimator is not None:
383397
if not isinstance(estimator, six.string_types):
384398
estimator = estimator.__class__.__name__
385-
msg += "by %s" % estimator
399+
msg += " by %s" % estimator
386400
warnings.warn(msg, DataConversionWarning)
387401
return array
388402

0 commit comments

Comments (0)