@@ -257,9 +257,10 @@ def fit(self, X, y, sample_weight=None):
257257 check_is_fitted (self .base_estimator , attributes = ["classes_" ])
258258 self .classes_ = self .base_estimator .classes_
259259
260- pred_method = _get_prediction_method (base_estimator )
260+ pred_method , method_name = _get_prediction_method (base_estimator )
261261 n_classes = len (self .classes_ )
262- predictions = _compute_predictions (pred_method , X , n_classes )
262+ predictions = _compute_predictions (pred_method , method_name , X ,
263+ n_classes )
263264
264265 calibrated_classifier = _fit_calibrator (
265266 base_estimator , predictions , y , self .classes_ , self .method ,
@@ -310,12 +311,13 @@ def fit(self, X, y, sample_weight=None):
310311 )
311312 else :
312313 this_estimator = clone (base_estimator )
313- method_name = _get_prediction_method (this_estimator ). __name__
314+ _ , method_name = _get_prediction_method (this_estimator )
314315 pred_method = partial (
315316 cross_val_predict , estimator = this_estimator , X = X , y = y ,
316317 cv = cv , method = method_name , n_jobs = self .n_jobs
317318 )
318- predictions = _compute_predictions (pred_method , X , n_classes )
319+ predictions = _compute_predictions (pred_method , method_name , X ,
320+ n_classes )
319321
320322 if sample_weight is not None and supports_sw :
321323 this_estimator .fit (X , y , sample_weight )
@@ -441,8 +443,9 @@ def _fit_classifier_calibrator_pair(estimator, X, y, train, test, supports_sw,
441443 estimator .fit (X_train , y_train )
442444
443445 n_classes = len (classes )
444- pred_method = _get_prediction_method (estimator )
445- predictions = _compute_predictions (pred_method , X_test , n_classes )
446+ pred_method , method_name = _get_prediction_method (estimator )
447+ predictions = _compute_predictions (pred_method , method_name , X_test ,
448+ n_classes )
446449
447450 calibrated_classifier = _fit_calibrator (
448451 estimator , predictions , y_test , classes , method , sample_weight = sw_test
@@ -465,18 +468,21 @@ def _get_prediction_method(clf):
465468 -------
466469 prediction_method : callable
467470 The prediction method.
471+ method_name : str
472+ The name of the prediction method.
468473 """
469474 if hasattr (clf , 'decision_function' ):
470475 method = getattr (clf , 'decision_function' )
476+ return method , 'decision_function'
471477 elif hasattr (clf , 'predict_proba' ):
472478 method = getattr (clf , 'predict_proba' )
479+ return method , 'predict_proba'
473480 else :
474481 raise RuntimeError ("'base_estimator' has no 'decision_function' or "
475482 "'predict_proba' method." )
476- return method
477483
478484
479- def _compute_predictions (pred_method , X , n_classes ):
485+ def _compute_predictions (pred_method , method_name , X , n_classes ):
480486 """Return predictions for `X` and reshape binary outputs to shape
481487 (n_samples, 1).
482488
@@ -485,6 +491,9 @@ def _compute_predictions(pred_method, X, n_classes):
485491 pred_method : callable
486492 Prediction method.
487493
494+ method_name: str
495+ Name of the prediction method
496+
488497 X : array-like or None
489498 Data used to obtain predictions.
490499
@@ -498,10 +507,6 @@ def _compute_predictions(pred_method, X, n_classes):
498507 (X.shape[0], 1).
499508 """
500509 predictions = pred_method (X = X )
501- if hasattr (pred_method , '__name__' ):
502- method_name = pred_method .__name__
503- else :
504- method_name = signature (pred_method ).parameters ['method' ].default
505510
506511 if method_name == 'decision_function' :
507512 if predictions .ndim == 1 :
@@ -634,8 +639,9 @@ def predict_proba(self, X):
634639 The predicted probabilities. Can be exact zeros.
635640 """
636641 n_classes = len (self .classes )
637- pred_method = _get_prediction_method (self .base_estimator )
638- predictions = _compute_predictions (pred_method , X , n_classes )
642+ pred_method , method_name = _get_prediction_method (self .base_estimator )
643+ predictions = _compute_predictions (pred_method , method_name , X ,
644+ n_classes )
639645
640646 label_encoder = LabelEncoder ().fit (self .classes )
641647 pos_class_indices = label_encoder .transform (
0 commit comments