Make CT accept list of mixed typed objects including nans without con… · scikit-learn/scikit-learn@864c2cc · GitHub
[go: up one dir, main page]

Skip to content

Commit 864c2cc

Browse files
committed
Make CT accept list of mixed typed objects including nans without conversions
1 parent 1fc61a5 commit 864c2cc

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

sklearn/compose/_column_transformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from ..base import clone, TransformerMixin
1717
from ..utils import Parallel, delayed
1818
from ..externals import six
19-
from ..pipeline import (
20-
_fit_one_transformer, _fit_transform_one, _transform_one, _name_estimators)
19+
from ..pipeline import _fit_transform_one, _transform_one, _name_estimators
2120
from ..preprocessing import FunctionTransformer
2221
from ..utils import Bunch
2322
from ..utils.metaestimators import _BaseComposition
@@ -517,7 +516,7 @@ def _check_X(X):
517516
"""Use check_array only on lists and other non-array-likes / sparse"""
518517
if hasattr(X, '__array__') or sparse.issparse(X):
519518
return X
520-
return check_array(X)
519+
return check_array(X, force_all_finite='allow-nan', dtype=np.object)
521520

522521

523522
def _check_key_type(key, superclass):

sklearn/compose/tests/test_column_transformer.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,13 +279,24 @@ def test_column_transformer_sparse_array():
279279

280280

281281
def test_column_transformer_list():
282-
X_list = [[1, 2, 3]]
283-
X_res = np.array([[0, 0, 0]])
284-
285-
ct = ColumnTransformer([('trans1', StandardScaler(), [0, 1, 2])])
286-
287-
assert_array_equal(ct.fit_transform(X_list), X_res)
288-
assert_array_equal(ct.fit(X_list).transform(X_list), X_res)
282+
X_list = [
283+
[1, float('nan'), 'a'],
284+
[0, 0, 'b']
285+
]
286+
expected_result = np.array([
287+
[1, float('nan'), 1, 0],
288+
[-1, 0, 0, 1],
289+
])
290+
291+
ct = ColumnTransformer([
292+
('numerical', StandardScaler(), [0, 1]),
293+
('categorical', OneHotEncoder(), [2]),
294+
])
295+
296+
with pytest.warns(None) as record:
297+
assert_array_equal(ct.fit_transform(X_list), expected_result)
298+
assert_array_equal(ct.fit(X_list).transform(X_list), expected_result)
299+
assert [w.message for w in record.list] == []
289300

290301

291302
def test_column_transformer_sparse_stacking():

sklearn/utils/validation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
438438
warn_on_dtype : boolean (default=False)
439439
Raise DataConversionWarning if the dtype of the input data structure
440440
does not match the requested dtype, causing a memory copy.
441+
This warning is not raised if the original data has an object dtype.
441442
442443
estimator : str or estimator instance (default=None)
443444
If passed, include the name of the estimator in warning messages.
@@ -584,7 +585,9 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
584585
% (n_features, shape_repr, ensure_min_features,
585586
context))
586587

587-
if warn_on_dtype and dtype_orig is not None and array.dtype != dtype_orig:
588+
if (warn_on_dtype and dtype_orig is not None
589+
and dtype_orig.kind != 'O'
590+
and array.dtype != dtype_orig):
588591
msg = ("Data with input dtype %s was converted to %s%s."
589592
% (dtype_orig, array.dtype, context))
590593
warnings.warn(msg, DataConversionWarning)

0 commit comments

Comments (0)
0