99import numpy as np
1010from abc import ABCMeta , abstractmethod
1111from warnings import warn
12+ from functools import partial
1213
1314from joblib import Parallel
1415
@@ -68,7 +69,15 @@ def _generate_bagging_indices(
6869
6970
7071def _parallel_build_estimators (
71- n_estimators , ensemble , X , y , sample_weight , seeds , total_n_estimators , verbose
72+ n_estimators ,
73+ ensemble ,
74+ X ,
75+ y ,
76+ sample_weight ,
77+ seeds ,
78+ total_n_estimators ,
79+ verbose ,
80+ check_input ,
7281):
7382 """Private function used to build a batch of estimators within a job."""
7483 # Retrieve settings
@@ -78,6 +87,7 @@ def _parallel_build_estimators(
7887 bootstrap = ensemble .bootstrap
7988 bootstrap_features = ensemble .bootstrap_features
8089 support_sample_weight = has_fit_parameter (ensemble .base_estimator_ , "sample_weight" )
90+ has_check_input = has_fit_parameter (ensemble .base_estimator_ , "check_input" )
8191 if not support_sample_weight and sample_weight is not None :
8292 raise ValueError ("The base estimator doesn't support sample weight" )
8393
@@ -95,6 +105,11 @@ def _parallel_build_estimators(
95105 random_state = seeds [i ]
96106 estimator = ensemble ._make_estimator (append = False , random_state = random_state )
97107
108+ if has_check_input :
109+ estimator_fit = partial (estimator .fit , check_input = check_input )
110+ else :
111+ estimator_fit = estimator .fit
112+
98113 # Draw random feature, sample indices
99114 features , indices = _generate_bagging_indices (
100115 random_state ,
@@ -120,10 +135,10 @@ def _parallel_build_estimators(
120135 not_indices_mask = ~ indices_to_mask (indices , n_samples )
121136 curr_sample_weight [not_indices_mask ] = 0
122137
123- estimator . fit (X [:, features ], y , sample_weight = curr_sample_weight )
138+ estimator_fit (X [:, features ], y , sample_weight = curr_sample_weight )
124139
125140 else :
126- estimator . fit (( X [indices ]) [:, features ], y [indices ])
141+ estimator_fit ( X [indices ][:, features ], y [indices ])
127142
128143 estimators .append (estimator )
129144 estimators_features .append (features )
@@ -284,7 +299,15 @@ def fit(self, X, y, sample_weight=None):
284299 def _parallel_args (self ):
285300 return {}
286301
287- def _fit (self , X , y , max_samples = None , max_depth = None , sample_weight = None ):
302+ def _fit (
303+ self ,
304+ X ,
305+ y ,
306+ max_samples = None ,
307+ max_depth = None ,
308+ sample_weight = None ,
309+ check_input = True ,
310+ ):
288311 """Build a Bagging ensemble of estimators from the training
289312 set (X, y).
290313
@@ -310,6 +333,10 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
310333 Note that this is supported only if the base estimator supports
311334 sample weighting.
312335
336+ check_input : bool, default=True
337+ Override value used when fitting base estimator. Only supported
338+ if the base estimator has a check_input parameter for fit function.
339+
313340 Returns
314341 -------
315342 self : object
@@ -416,6 +443,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
416443 seeds [starts [i ] : starts [i + 1 ]],
417444 total_n_estimators ,
418445 verbose = self .verbose ,
446+ check_input = check_input ,
419447 )
420448 for i in range (n_jobs )
421449 )