8000 more robust check for dtype object · scikit-learn/scikit-learn@657a657 · GitHub
[go: up one dir, main page]

Skip to content

Commit 657a657

Browse files
committed
more robust check for dtype object
1 parent cd5166e commit 657a657

File tree

4 files changed

+8
-3
lines changed

4 files changed

+8
-3
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ API changes summary
377377
- `thresh` parameter is deprecated in favor of new `tol` parameter in
378378
:class:`GMM`. See `Enhancements` section for details. By `Hervé Bredin`_.
379379

380+
- Estimators will treat input with dtype object as numeric when possible.
381+
By `Andreas Müller`_
382+
383+
380384

381385
.. _changes_0_15_2:
382386

sklearn/ensemble/gradient_boosting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,7 @@ def feature_importances_(self):
11461146

11471147
def _validate_y(self, y):
11481148
self.n_classes_ = 1
1149-
if y.dtype is np.dtype(object):
1149+
if y.dtype.kind == 'O':
11501150
y = y.astype(np.float64)
11511151
# Default implementation
11521152
return y

sklearn/utils/estimator_checks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def check_estimator_sparse_data(name, Estimator):
151151

152152

153153
def check_dtype_object(name, Estimator):
154+
# check that estimators treat dtype object as numeric if possible
154155
rng = np.random.RandomState(0)
155156
X = rng.rand(40, 10).astype(object)
156157
y = (X[:, 0] * 4).astype(np.int)

sklearn/utils/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, copy=Fal
294294
if ensure_2d:
295295
array = np.atleast_2d(array)
296296
if dtype == "numeric":
297-
if getattr(array, "dtype", None) is np.dtype(object):
297+
if hasattr(array, "dtype") and array.dtype.kind == "O":
298298
# if input is object, convert to float.
299299
dtype = np.float64
300300
else:
@@ -398,7 +398,7 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False,
398398
else:
399399
y = column_or_1d(y, warn=True)
400400
_assert_all_finite(y)
401-
if y_numeric and y.dtype is np.dtype(object):
401+
if y_numeric and y.dtype.kind == 'O':
402402
y = y.astype(np.float64)
403403

404404
check_consistent_length(X, y)

0 commit comments

Comments
 (0)
0