8000 MAINT Parameters validation for sklearn.calibration.calibration_curve… · scikit-learn/scikit-learn@1501237 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1501237

Browse files
MAINT Parameters validation for sklearn.calibration.calibration_curve (#26198)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent c2f9d9a commit 1501237

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sklearn/calibration.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88
# License: BSD 3 clause
99

10-
from numbers import Integral
10+
from numbers import Integral, Real
1111
import warnings
1212
from inspect import signature
1313
from functools import partial
@@ -35,7 +35,13 @@
3535

3636
from .utils.multiclass import check_classification_targets
3737
from .utils.parallel import delayed, Parallel
38-
from .utils._param_validation import StrOptions, HasMethods, Hidden
38+
from .utils._param_validation import (
39+
StrOptions,
40+
HasMethods,
41+
Hidden,
42+
validate_params,
43+
Interval,
44+
)
3945
from .utils._plotting import _BinaryClassifierCurveDisplayMixin
4046
from .utils.validation import (
4147
_check_fit_params,
@@ -903,6 +909,15 @@ def predict(self, T):
903909
return expit(-(self.a_ * T + self.b_))
904910

905911

912+
@validate_params(
913+
{
914+
"y_true": ["array-like"],
915+
"y_prob": ["array-like"],
916+
"pos_label": [Real, str, "boolean", None],
917+
"n_bins": [Interval(Integral, 1, None, closed="left")],
918+
"strategy": [StrOptions({"uniform", "quantile"})],
919+
}
920+
)
906921
def calibration_curve(
907922
y_true,
908923
y_prob,

sklearn/tests/test_public_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _check_function_param_validation(
110110

111111

112112
PARAM_VALIDATION_FUNCTION_LIST = [
113+
"sklearn.calibration.calibration_curve",
113114
"sklearn.cluster.cluster_optics_dbscan",
114115
"sklearn.cluster.compute_optics_graph",
115116
"sklearn.cluster.estimate_bandwidth",

0 commit comments

Comments
 (0)
0