@@ -257,9 +257,10 @@ def fit(self, X, y, sample_weight=None):
257
257
check_is_fitted (self .base_estimator , attributes = ["classes_" ])
258
258
self .classes_ = self .base_estimator .classes_
259
259
260
- pred_method = _get_prediction_method (base_estimator )
260
+ pred_method , method_name = _get_prediction_method (base_estimator )
261
261
n_classes = len (self .classes_ )
262
- predictions = _compute_predictions (pred_method , X , n_classes )
262
+ predictions = _compute_predictions (pred_method , method_name , X ,
263
+ n_classes )
263
264
264
265
calibrated_classifier = _fit_calibrator (
265
266
base_estimator , predictions , y , self .classes_ , self .method ,
@@ -310,12 +311,13 @@ def fit(self, X, y, sample_weight=None):
310
311
)
311
312
else :
312
313
this_estimator = clone (base_estimator )
313
- method_name = _get_prediction_method (this_estimator ). __name__
314
+ _ , method_name = _get_prediction_method (this_estimator )
314
315
pred_method = partial (
315
316
cross_val_predict , estimator = this_estimator , X = X , y = y ,
316
317
cv = cv , method = method_name , n_jobs = self .n_jobs
317
318
)
318
- predictions = _compute_predictions (pred_method , X , n_classes )
319
+ predictions = _compute_predictions (pred_method , method_name , X ,
320
+ n_classes )
319
321
320
322
if sample_weight is not None and supports_sw :
321
323
this_estimator .fit (X , y , sample_weight )
@@ -441,8 +443,9 @@ def _fit_classifier_calibrator_pair(estimator, X, y, train, test, supports_sw,
441
443
estimator .fit (X_train , y_train )
442
444
443
445
n_classes = len (classes )
444
- pred_method = _get_prediction_method (estimator )
445
- predictions = _compute_predictions (pred_method , X_test , n_classes )
446
+ pred_method , method_name = _get_prediction_method (estimator )
447
+ predictions = _compute_predictions (pred_method , method_name , X_test ,
448
+ n_classes )
446
449
447
450
calibrated_classifier = _fit_calibrator (
448
451
estimator , predictions , y_test , classes , method , sample_weight = sw_test
@@ -465,18 +468,21 @@ def _get_prediction_method(clf):
465
468
-------
466
469
prediction_method : callable
467
470
The prediction method.
471
+ method_name : str
472
+ The name of the prediction method.
468
473
"""
469
474
if hasattr (clf , 'decision_function' ):
470
475
method = getattr (clf , 'decision_function' )
476
+ return method , 'decision_function'
471
477
elif hasattr (clf , 'predict_proba' ):
472
478
method = getattr (clf , 'predict_proba' )
479
+ return method , 'predict_proba'
473
480
else :
474
481
raise RuntimeError ("'base_estimator' has no 'decision_function' or "
475
482
"'predict_proba' method." )
476
- return method
477
483
478
484
479
- def _compute_predictions (pred_method , X , n_classes ):
485
+ def _compute_predictions (pred_method , method_name , X , n_classes ):
480
486
"""Return predictions for `X` and reshape binary outputs to shape
481
487
(n_samples, 1).
482
488
@@ -485,6 +491,9 @@ def _compute_predictions(pred_method, X, n_classes):
485
491
pred_method : callable
486
492
Prediction method.
487
493
494
+ method_name: str
495
+ Name of the prediction method
496
+
488
497
X : array-like or None
489
498
Data used to obtain predictions.
490
499
@@ -498,10 +507,6 @@ def _compute_predictions(pred_method, X, n_classes):
498
507
(X.shape[0], 1).
499
508
"""
500
509
predictions = pred_method (X = X )
501
- if hasattr (pred_method , '__name__' ):
502
- method_name = pred_method .__name__
503
- else :
504
- method_name = signature (pred_method ).parameters ['method' ].default
505
510
506
511
if method_name == 'decision_function' :
507
512
if predictions .ndim == 1 :
@@ -634,8 +639,9 @@ def predict_proba(self, X):
634
639
The predicted probabilities. Can be exact zeros.
635
640
"""
636
641
n_classes = len (self .classes )
637
- pred_method = _get_prediction_method (self .base_estimator )
638
- predictions = _compute_predictions (pred_method , X , n_classes )
642
+ pred_method , method_name = _get_prediction_method (self .base_estimator )
643
+ predictions = _compute_predictions (pred_method , method_name , X ,
644
+ n_classes )
639
645
640
646
label_encoder = LabelEncoder ().fit (self .classes )
641
647
pos_class_indices = label_encoder .transform (
0 commit comments