From 5edd6aae318774be2d14e7433a931e405cb6d8cb Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Thu, 2 Mar 2023 17:54:27 +0100 Subject: [PATCH 1/4] add real_not_int type in Interval constraint to avoid overlaps. --- doc/whats_new/v1.2.rst | 9 ++++++ sklearn/tree/_classes.py | 6 ++-- sklearn/tree/tests/test_tree.py | 22 +++++++++++++++ sklearn/utils/_param_validation.py | 29 +++++++++++++------- sklearn/utils/tests/test_param_validation.py | 7 +++++ 5 files changed, 60 insertions(+), 13 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index c252a7c18f5c9..fbc99f3cbc80d 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -79,6 +79,15 @@ Changelog `encoded_missing_value` or `unknown_value` set to a categories' cardinality when there is missing values in the training data. :pr:`25704` by `Thomas Fan`_. +:mod:`sklearn.tree` +................... + +- |Fix| Fixed a regression in :class:`tree.DecisionTreeClassifier`, + :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and + :class:`tree.ExtraTreeRegressor` where an error was no longer raised in version + 1.2 when `min_sample_split=1`. + :pr:`25744` by :user:`Jérémie du Boisberranger `. + :mod:`sklearn.utils` .................... diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e2e41f9aea78b..6e01b8b49e594 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -99,16 +99,16 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "max_depth": [Interval(Integral, 1, None, closed="left"), None], "min_samples_split": [ Interval(Integral, 2, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="right"), + Interval("real_not_int", 0.0, 1.0, closed="right"), ], "min_samples_leaf": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="neither"), + Interval("real_not_int", 0.0, 1.0, closed="neither"), ], "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], "max_features": [ Interval(Integral, 1, None, closed="left"), - Interval(Real, 0.0, 1.0, closed="right"), + Interval("real_not_int", 0.0, 1.0, closed="right"), StrOptions({"auto", "sqrt", "log2"}, deprecated={"auto"}), None, ], diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 9b1a29f02ead7..c796177ad814c 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2425,3 +2425,25 @@ def test_tree_deserialization_from_read_only_buffer(tmpdir): clf.tree_, "The trees of the original and loaded classifiers are not equal.", ) + + +@pytest.mark.parametrize("Tree", ALL_TREES.values()) +def test_min_sample_split_1_error(Tree): + """Check that an error is raised when min_sample_split=1. + + non-regression test for issue gh-25481. + """ + X = np.array([[0, 0], [1, 1]]) + y = np.array([0, 1]) + + # min_samples_split=1.0 is valid + Tree(min_samples_split=1.0).fit(X, y) + + # min_samples_split=1 is invalid + tree = Tree(min_samples_split=1) + msg = ( + r"'min_samples_split' .* must be an int in the range \[2, inf\) " + r"or a float in the range \(0.0, 1.0\]" + ) + with pytest.raises(ValueError, match=msg): + tree.fit(X, y) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index aa8906071c6af..e1c5a43e6c6a5 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -364,7 +364,7 @@ class Interval(_Constraint): Parameters ---------- - type : {numbers.Integral, numbers.Real} + type : {numbers.Integral, numbers.Real, "real_not_int"} The set of numbers in which to set the interval. left : float or int or None @@ -392,14 +392,6 @@ class Interval(_Constraint): `[0, +∞) U {+∞}`. """ - @validate_params( - { - "type": [type], - "left": [Integral, Real, None], - "right": [Integral, Real, None], - "closed": [StrOptions({"left", "right", "both", "neither"})], - } - ) def __init__(self, type, left, right, *, closed): super().__init__() self.type = type @@ -410,6 +402,18 @@ def __init__(self, type, left, right, *, closed): self._check_params() def _check_params(self): + if self.type not in (Integral, Real, "real_not_int"): + raise ValueError( + "type must be either numbers.Integral, numbers.Real or 'real_not_int'." + f" Got {self.type} instead." + ) + + if self.closed not in ("left", "right", "both", "neither"): + raise ValueError( + "closed must be either 'left', 'right', 'both' or 'neither'. " + f"Got {self.closed} instead." + ) + if self.type is Integral: suffix = "for an interval over the integers." if self.left is not None and not isinstance(self.left, Integral): @@ -447,8 +451,13 @@ def __contains__(self, val): return False return True + def _has_valid_type(self, val): + if self.type == "real_not_int": + return isinstance(val, Real) and not isinstance(val, Integral) + return isinstance(val, self.type) + def is_satisfied_by(self, val): - if not isinstance(val, self.type): + if not self._has_valid_type(val): return False return val in self diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 85cd06d0f38b8..ce8f9cdf939fd 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -662,3 +662,10 @@ def fit(self, X=None, y=None): # does not raise, even though "b" is not in the constraints dict and "a" is not # a parameter of the estimator. ThirdPartyEstimator(b=0).fit() + + +def test_interval_real_not_int(): + """Check for the type "real_not_int" in the Interval constraint.""" + constraint = Interval("real_not_int", 0, 1, closed="both") + assert constraint.is_satisfied_by(1.0) + assert not constraint.is_satisfied_by(1) From 746b8fee29b5afbc328045240e8605bbf162ee2d Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Fri, 3 Mar 2023 11:59:32 +0100 Subject: [PATCH 2/4] improve description --- sklearn/utils/_param_validation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index e1c5a43e6c6a5..50bfd7d77ceea 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -367,6 +367,9 @@ class Interval(_Constraint): type : {numbers.Integral, numbers.Real, "real_not_int"} The set of numbers in which to set the interval. + If "real_not_int", only reals that don't have the integer type + are allowed. For example 1.0 is allowed but 1 is not. + left : float or int or None The left bound of the interval. None means left bound is -∞. From de48998957dbada7b1aff9a60cbccb3cf6f72ff4 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Mon, 6 Mar 2023 12:29:41 +0100 Subject: [PATCH 3/4] add validation for left and right --- sklearn/utils/_param_validation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 50bfd7d77ceea..43e335161e940 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -431,6 +431,11 @@ def _check_params(self): raise ValueError( f"right can't be None when closed == {self.closed} {suffix}" ) + else: + if self.left is not None and not isinstance(self.left, Real): + raise TypeError(f"Expecting left to be a real number.") + if self.right is not None and not isinstance(self.right, Real): + raise TypeError(f"Expecting right to be a real number.") if self.right is not None and self.left is not None and self.right <= self.left: raise ValueError( From a4f53ff4ef8653de64a11afdc53d9496d656446f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:37:24 +0100 Subject: [PATCH 4/4] lint --- sklearn/utils/_param_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 43e335161e940..8d23f0b23b6eb 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -433,9 +433,9 @@ def _check_params(self): ) else: if self.left is not None and not isinstance(self.left, Real): - raise TypeError(f"Expecting left to be a real number.") + raise TypeError("Expecting left to be a real number.") if self.right is not None and not isinstance(self.right, Real): - raise TypeError(f"Expecting right to be a real number.") + raise TypeError("Expecting right to be a real number.") if self.right is not None and self.left is not None and self.right <= self.left: raise ValueError(