8000 TST invalid init parameters for losses (#22407) · scikit-learn/scikit-learn@998e8f2 · GitHub
[go: up one dir, main page]

Skip to content
Sign in

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 998e8f2

Browse files
authored
TST invalid init parameters for losses (#22407)
1 parent 4d9e005 commit 998e8f2

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

sklearn/_loss/loss.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# - SGDRegressor, SGDClassifier
1616
# - Replace link module of GLMs.
1717

18+
import numbers
1819
import numpy as np
1920
from scipy.special import xlogy
2021
from ._loss import (
@@ -34,6 +35,7 @@
3435
LogitLink,
3536
MultinomialLogit,
3637
)
38+
from ..utils import check_scalar
3739
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
3840
from ..utils.stats import _weighted_percentile
3941

@@ -604,11 +606,14 @@ class PinballLoss(BaseLoss):
604606
need_update_leaves_values = True
605607

606608
def __init__(self, sample_weight=None, quantile=0.5):
607-
if quantile <= 0 or quantile >= 1:
608-
raise ValueError(
609-
"PinballLoss aka quantile loss only accepts "
610-
f"0 < quantile < 1; {quantile} was given."
611-
)
609+
check_scalar(
610+
quantile,
611+
"quantile",
612+
target_type=numbers.Real,
613+
min_val=0,
614+
max_val=1,
615+
include_boundaries="neither",
616+
)
612617
super().__init__(
613618
closs=CyPinballLoss(quantile=float(quantile)),
614619
link=IdentityLink(),
@@ -725,6 +730,14 @@ class HalfTweedieLoss(BaseLoss):
725730
"""
726731

727732
def __init__(self, sample_weight=None, power=1.5):
733+
check_scalar(
734+
power,
735+
"power",
736+
target_type=numbers.Real,
737+
include_boundaries="neither",
738+
min_val=-np.inf,
739+
max_val=np.inf,
740+
)
728741
super().__init__(
729742
closs=CyHalfTweedieLoss(power=float(power)),
730743
link=LogLink(),

sklearn/_loss/tests/test_loss.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,42 @@ def test_init_gradient_and_hessian_raises(loss, params, err_msg):
10481048
gradient, hessian = loss.init_gradient_and_hessian(n_samples=5, **params)
10491049

10501050

1051+
@pytest.mark.parametrize(
1052+
"loss, params, err_type, err_msg",
1053+
[
1054+
(
1055+
PinballLoss,
1056+
{"quantile": None},
1057+
TypeError,
1058+
"quantile must be an instance of float, not NoneType.",
1059+
),
1060+
(
1061+
PinballLoss,
1062+
{"quantile": 0},
1063+
ValueError,
1064+
"quantile == 0, must be > 0.",
1065+
),
1066+
(PinballLoss, {"quantile": 1.1}, ValueError, "quantile == 1.1, must be < 1."),
1067+
(
1068+
HalfTweedieLoss,
1069+
{"power": None},
1070+
TypeError,
1071+
"power must be an instance of float, not NoneType.",
1072+
),
1073+
(
1074+
HalfTweedieLoss,
1075+
{"power": np.inf},
1076+
ValueError,
1077+
"power == inf, must be < inf.",
1078+
),
1079+
],
1080+
)
1081+
def test_loss_init_parameter_validation(loss, params, err_type, err_msg):
1082+
"""Test that loss raises errors for invalid input."""
1083+
with pytest.raises(err_type, match=err_msg):
1084+
loss(**params)
1085+
1086+
10511087
@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
10521088
def test_loss_pickle(loss):
10531089
"""Test that losses can be pickled."""

0 commit comments

Comments
 (0)
0