8000 ENH Convert y in GradientBoosting to float64 instead of float32 (#13524) · koenvandevelde/scikit-learn@fd1559d · GitHub
[go: up one dir, main page]

Skip to content

Commit fd1559d

Browse files
adrinjalalikoenvandevelde
authored andcommitted
ENH Convert y in GradientBoosting to float64 instead of float32 (scikit-learn#13524)
1 parent 64b00f5 commit fd1559d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from time import time
4545
from ..model_selection import train_test_split
4646
from ..tree.tree import DecisionTreeRegressor
47-
from ..tree._tree import DTYPE
47+
from ..tree._tree import DTYPE, DOUBLE
4848
from ..tree._tree import TREE_LEAF
4949
from . import _gb_losses
5050

@@ -1432,7 +1432,9 @@ def fit(self, X, y, sample_weight=None, monitor=None):
14321432
self._clear_state()
14331433

14341434
# Check input
1435-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], dtype=DTYPE)
1435+
# Since check_array converts both X and y to the same dtype, but the
1436+
# trees use different types for X and y, checking them separately.
1437+
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], dtype=DTYPE)
14361438
n_samples, self.n_features_ = X.shape
14371439

14381440
sample_weight_is_none = sample_weight is None
@@ -1444,6 +1446,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
14441446

14451447
check_consistent_length(X, y, sample_weight)
14461448

1449+
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
1450+
y = column_or_1d(y, warn=True)
14471451
y = self._validate_y(y, sample_weight)
14481452

14491453
if self.n_iter_no_change is not None:
@@ -1722,7 +1726,7 @@ def _validate_y(self, y, sample_weight):
17221726
# consistency with similar method _validate_y of GBC
17231727
self.n_classes_ = 1
17241728
if y.dtype.kind == 'O':
1725-
y = y.astype(np.float64)
1729+
y = y.astype(DOUBLE)
17261730
# Default implementation
17271731
return y
17281732

0 commit comments

Comments
 (0)
{"resolvedServerColorMode":"day"}
0