FIX Raise an error when min_samples_split=1 in trees (#25744) · scikit-learn/scikit-learn@0cae7df · GitHub
[go: up one dir, main page]

Skip to content

Commit 0cae7df

Browse files
committed
FIX Raise an error when min_samples_split=1 in trees (#25744)
1 parent 12f1675 commit 0cae7df

File tree

5 files changed

+68
-13
lines changed

5 files changed

+68
-13
lines changed

doc/whats_new/v1.2.rst

+9
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ Changelog
7979
`encoded_missing_value` or `unknown_value` set to a categories' cardinality
8080
when there are missing values in the training data. :pr:`25704` by `Thomas Fan`_.
8181

82+
:mod:`sklearn.tree`
83+
...................
84+
85+
- |Fix| Fixed a regression in :class:`tree.DecisionTreeClassifier`,
86+
:class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and
87+
:class:`tree.ExtraTreeRegressor` where an error was no longer raised in version
88+
1.2 when `min_samples_split=1`.
89+
:pr:`25744` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
90+
8291
:mod:`sklearn.utils`
8392
....................
8493

sklearn/tree/_classes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
9999
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
100100
"min_samples_split": [
101101
Interval(Integral, 2, None, closed="left"),
102-
Interval(Real, 0.0, 1.0, closed="right"),
102+
Interval("real_not_int", 0.0, 1.0, closed="right"),
103103
],
104104
"min_samples_leaf": [
105105
Interval(Integral, 1, None, closed="left"),
106-
Interval(Real, 0.0, 1.0, closed="neither"),
106+
Interval("real_not_int", 0.0, 1.0, closed="neither"),
107107
],
108108
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
109109
"max_features": [
110110
Interval(Integral, 1, None, closed="left"),
111-
Interval(Real, 0.0, 1.0, closed="right"),
111+
Interval("real_not_int", 0.0, 1.0, closed="right"),
112112
StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}),
113113
None,
114114
],

sklearn/tree/tests/test_tree.py

+22
Original file line numberDiff line numberDiff line change
@@ -2406,3 +2406,25 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir):
24062406
clf.tree_,
24072407
"The trees of the original and loaded classifiers are not equal.",
24082408
)
2409+
2410+
2411+
@pytest.mark.parametrize("Tree", ALL_TREES.values())
2412+
def test_min_sample_split_1_error(Tree):
2413+
"""Check that an error is raised when min_samples_split=1.
2414+
2415+
non-regression test for issue gh-25481.
2416+
"""
2417+
X = np.array([[0, 0], [1, 1]])
2418+
y = np.array([0, 1])
2419+
2420+
# min_samples_split=1.0 is valid
2421+
Tree(min_samples_split=1.0).fit(X, y)
2422+
2423+
# min_samples_split=1 is invalid
2424+
tree = Tree(min_samples_split=1)
2425+
msg = (
2426+
r"'min_samples_split' .* must be an int in the range \[2, inf\) "
2427+
r"or a float in the range \(0.0, 1.0\]"
2428+
)
2429+
with pytest.raises(ValueError, match=msg):
2430+
tree.fit(X, y)

sklearn/utils/_param_validation.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,12 @@ class Interval(_Constraint):
364364
365365
Parameters
366366
----------
367-
type : {numbers.Integral, numbers.Real}
367+
type : {numbers.Integral, numbers.Real, "real_not_int"}
368368
The set of numbers in which to set the interval.
369369
370+
If "real_not_int", only reals that don't have the integer type
371+
are allowed. For example 1.0 is allowed but 1 is not.
372+
370373
left : float or int or None
371374
The left bound of the interval. None means left bound is -∞.
372375
@@ -392,14 +395,6 @@ class Interval(_Constraint):
392395
`[0, +∞) U {+∞}`.
393396
"""
394397

395-
@validate_params(
396-
{
397-
"type": [type],
398-
"left": [Integral, Real, None],
399-
"right": [Integral, Real, None],
400-
"closed": [StrOptions({"left", "right", "both", "neither"})],
401-
}
402-
)
403398
def __init__(self, type, left, right, *, closed):
404399
super().__init__()
405400
self.type = type
@@ -410,6 +405,18 @@ def __init__(self, type, left, right, *, closed):
410405
self._check_params()
411406

412407
def _check_params(self):
408+
if self.type not in (Integral, Real, "real_not_int"):
409+
raise ValueError(
410+
"type must be either numbers.Integral, numbers.Real or 'real_not_int'."
411+
f" Got {self.type} instead."
412+
)
413+
414+
if self.closed not in ("left", "right", "both", "neither"):
415+
raise ValueError(
416+
"closed must be either 'left', 'right', 'both' or 'neither'. "
417+
f"Got {self.closed} instead."
418+
)
419+
413420
if self.type is Integral:
414421
suffix = "for an interval over the integers."
415422
if self.left is not None and not isinstance(self.left, Integral):
@@ -424,6 +431,11 @@ def _check_params(self):
424431
raise ValueError(
425432
f"right can't be None when closed == {self.closed} {suffix}"
426433
)
434+
else:
435+
if self.left is not None and not isinstance(self.left, Real):
436+
raise TypeError("Expecting left to be a real number.")
437+
if self.right is not None and not isinstance(self.right, Real):
438+
raise TypeError("Expecting right to be a real number.")
427439

428440
if self.right is not None and self.left is not None and self.right <= self.left:
429441
raise ValueError(
@@ -447,8 +459,13 @@ def __contains__(self, val):
447459
return False
448460
return True
449461

462+
def _has_valid_type(self, val):
463+
if self.type == "real_not_int":
464+
return isinstance(val, Real) and not isinstance(val, Integral)
465+
return isinstance(val, self.type)
466+
450467
def is_satisfied_by(self, val):
451-
if not isinstance(val, self.type):
468+
if not self._has_valid_type(val):
452469
return False
453470

454471
return val in self

sklearn/utils/tests/test_param_validation.py

+7
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,10 @@ def fit(self, X=None, y=None):
662662
# does not raise, even though "b" is not in the constraints dict and "a" is not
663663
# a parameter of the estimator.
664664
ThirdPartyEstimator(b=0).fit()
665+
666+
667+
def test_interval_real_not_int():
668+
"""Check for the type "real_not_int" in the Interval constraint."""
669+
constraint = Interval("real_not_int", 0, 1, closed="both")
670+
assert constraint.is_satisfied_by(1.0)
671+
assert not constraint.is_satisfied_by(1)

0 commit comments

Comments
 (0)
0