FIX Use joblib Parallel for the initial binning in HGBT (#29386) · scikit-learn/scikit-learn@82404ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 82404ba

Browse files
OmarManzoor and lesteve authored
FIX Use joblib Parallel for the initial binning in HGBT (#29386)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 012de1e commit 82404ba

File tree

1 file changed

+11
-13
lines changed
  • sklearn/ensemble/_hist_gradient_boosting

1 file changed

+11
-13
lines changed

sklearn/ensemble/_hist_gradient_boosting/binning.py

+11 -13
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
"""
88

99
# Author: Nicolas Hug
10-
import concurrent.futures
1110

1211
import numpy as np
1312

1413
from ...base import BaseEstimator, TransformerMixin
1514
from ...utils import check_array, check_random_state
1615
from ...utils._openmp_helpers import _openmp_effective_n_threads
1716
from ...utils.fixes import percentile
17+
from ...utils.parallel import Parallel, delayed
1818
from ...utils.validation import check_is_fitted
1919
from ._binning import _map_to_bins
2020
from ._bitset import set_bitset_memoryview
@@ -230,19 +230,13 @@ def fit(self, X, y=None):
230230
self.bin_thresholds_ = [None] * n_features
231231
n_bins_non_missing = [None] * n_features
232232

233-
with concurrent.futures.ThreadPoolExecutor(
234-
max_workers=self.n_threads
235-
) as executor:
236-
future_to_f_idx = {
237-
executor.submit(_find_binning_thresholds, X[:, f_idx], max_bins): f_idx
238-
for f_idx in range(n_features)
239-
if not self.is_categorical_[f_idx]
240-
}
241-
for future in concurrent.futures.as_completed(future_to_f_idx):
242-
f_idx = future_to_f_idx[future]
243-
self.bin_thresholds_[f_idx] = future.result()
244-
n_bins_non_missing[f_idx] = self.bin_thresholds_[f_idx].shape[0] + 1
233+
non_cat_thresholds = Parallel(n_jobs=self.n_threads, backend="threading")(
234+
delayed(_find_binning_thresholds)(X[:, f_idx], max_bins)
235+
for f_idx in range(n_features)
236+
if not self.is_categorical_[f_idx]
237+
)
245238

239+
non_cat_idx = 0
246240
for f_idx in range(n_features):
247241
if self.is_categorical_[f_idx]:
248242
# Since categories are assumed to be encoded in
@@ -252,6 +246,10 @@ def fit(self, X, y=None):
252246
thresholds = known_categories[f_idx]
253247
n_bins_non_missing[f_idx] = thresholds.shape[0]
254248
self.bin_thresholds_[f_idx] = thresholds
249+
else:
250+
self.bin_thresholds_[f_idx] = non_cat_thresholds[non_cat_idx]
251+
n_bins_non_missing[f_idx] = self.bin_thresholds_[f_idx].shape[0] + 1
252+
non_cat_idx += 1
255253

256254
self.n_bins_non_missing_ = np.array(n_bins_non_missing, dtype=np.uint32)
257255
return self

0 commit comments

Comments
 (0)
0