|
9 | 9 | from sklearn.tree._splitter import Splitter
|
10 | 10 | from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
|
11 | 11 | from sklearn.tree._utils import _any_isnan_axis0
|
12 |
| -from sklearn.utils._param_validation import Interval, StrOptions |
| 12 | +from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions |
13 | 13 | from sklearn.utils.validation import (
|
14 | 14 | _assert_all_finite_element_wise,
|
15 | 15 | _check_n_features,
|
@@ -161,16 +161,16 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
161 | 161 | "max_depth": [Interval(Integral, 1, None, closed="left"), None],
|
162 | 162 | "min_samples_split": [
|
163 | 163 | Interval(Integral, 2, None, closed="left"),
|
164 |
| - Interval(Real, 0.0, 1.0, closed="neither"), |
| 164 | + Interval(RealNotInt, 0.0, 1.0, closed="neither"), |
165 | 165 | ],
|
166 | 166 | "min_samples_leaf": [
|
167 | 167 | Interval(Integral, 1, None, closed="left"),
|
168 |
| - Interval(Real, 0.0, 0.5, closed="right"), |
| 168 | + Interval(RealNotInt, 0.0, 0.5, closed="right"), |
169 | 169 | ],
|
170 | 170 | "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
|
171 | 171 | "max_features": [
|
172 | 172 | Interval(Integral, 1, None, closed="left"),
|
173 |
| - Interval(Real, 0.0, 1.0, closed="right"), |
| 173 | + Interval(RealNotInt, 0.0, 1.0, closed="right"), |
174 | 174 | StrOptions({"sqrt", "log2"}),
|
175 | 175 | None,
|
176 | 176 | ],
|
@@ -363,7 +363,7 @@ def _check_params(self, n_samples):
|
363 | 363 |
|
364 | 364 | max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
|
365 | 365 |
|
366 |
| - if isinstance(self.min_samples_leaf, (Integral, np.integer)): |
| 366 | + if isinstance(self.min_samples_leaf, Integral): |
367 | 367 | min_samples_leaf = self.min_samples_leaf
|
368 | 368 | else: # float
|
369 | 369 | min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
|
@@ -397,7 +397,7 @@ def _check_max_features(self):
|
397 | 397 |
|
398 | 398 | elif self.max_features is None:
|
399 | 399 | max_features = self.n_features_in_
|
400 |
| - elif isinstance(self.max_features, (Integral, np.integer)): |
| 400 | + elif isinstance(self.max_features, Integral): |
401 | 401 | max_features = self.max_features
|
402 | 402 | else: # float
|
403 | 403 | if self.max_features > 0.0:
|
|
0 commit comments