8000 Use RealNotInt for parameters that accept ints and floats · sebp/scikit-survival@0e9ca7e · GitHub
[go: up one dir, main page]

Skip to content

Commit 0e9ca7e

Browse files
committed
Use RealNotInt for parameters that accept ints and floats
See scikit-learn/scikit-learn#25797
1 parent e128071 commit 0e9ca7e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sksurv/tree/tree.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.tree._splitter import Splitter
1010
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
1111
from sklearn.tree._utils import _any_isnan_axis0
12-
from sklearn.utils._param_validation import Interval, StrOptions
12+
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
1313
from sklearn.utils.validation import (
1414
_assert_all_finite_element_wise,
1515
_check_n_features,
@@ -161,16 +161,16 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
161161
"max_depth": [Interval(Integral, 1, None, closed="left"), None],
162162
"min_samples_split": [
163163
Interval(Integral, 2, None, closed="left"),
164-
Interval(Real, 0.0, 1.0, closed="neither"),
164+
Interval(RealNotInt, 0.0, 1.0, closed="neither"),
165165
],
166166
"min_samples_leaf": [
167167
Interval(Integral, 1, None, closed="left"),
168-
Interval(Real, 0.0, 0.5, closed="right"),
168+
Interval(RealNotInt, 0.0, 0.5, closed="right"),
169169
],
170170
"min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
171171
"max_features": [
172172
Interval(Integral, 1, None, closed="left"),
173-
Interval(Real, 0.0, 1.0, closed="right"),
173+
Interval(RealNotInt, 0.0, 1.0, closed="right"),
174174
StrOptions({"sqrt", "log2"}),
175175
None,
176176
],
@@ -363,7 +363,7 @@ def _check_params(self, n_samples):
363363

364364
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
365365

366-
if isinstance(self.min_samples_leaf, (Integral, np.integer)):
366+
if isinstance(self.min_samples_leaf, Integral):
367367
min_samples_leaf = self.min_samples_leaf
368368
else: # float
369369
min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
@@ -397,7 +397,7 @@ def _check_max_features(self):
397397

398398
elif self.max_features is None:
399399
max_features = self.n_features_in_
400-
elif isinstance(self.max_features, (Integral, np.integer)):
400+
elif isinstance(self.max_features, Integral):
401401
max_features = self.max_features
402402
else: # float
403403
if self.max_features > 0.0:

0 commit comments

Comments
 (0)
0