8000 MAINT Parameter validation for tree.export_text (#25867) · scikit-learn/scikit-learn@54108d9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 54108d9

Browse files
MAINT Parameter validation for tree.export_text (#25867)
1 parent 9260f51 commit 54108d9

File tree

3 files changed

+15
-20
lines changed

3 files changed

+15
-20
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def _check_function_param_validation(
191191
"sklearn.model_selection.train_test_split",
192192
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
193193
"sklearn.svm.l1_min_c",
194+
"sklearn.tree.export_text",
194195
"sklearn.utils.gen_batches",
195196
]
196197

sklearn/tree/_export.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted
20+
from ..utils._param_validation import Interval, validate_params
21+
2022
from ..base import is_classifier
2123

2224
from . import _criterion
2325
from . import _tree
2426
from ._reingold_tilford import buchheim, Tree
25-
from . import DecisionTreeClassifier
27+
from . import DecisionTreeClassifier, DecisionTreeRegressor
2628

2729

2830
def _color_brew(n):
@@ -919,6 +921,17 @@ def compute_depth_(
919921
return max(depths)
920922

921923

924+
@validate_params(
925+
{
926+
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
927+
"feature_names": [list, None],
928+
"class_names": [list, None],
929+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
930+
"spacing": [Interval(Integral, 1, None, closed="left"), None],
931+
"decimals": [Interval(Integral, 0, None, closed="left"), None],
932+
"show_weights": ["boolean"],
933+
}
934+
)
922935
def export_text(
923936
decision_tree,
924937
*,
@@ -1011,21 +1024,12 @@ def export_text(
10111024
left_child_fmt = "{} {} > {}\n"
10121025
truncation_fmt = "{} {}\n"
10131026

1014-
if max_depth < 0:
1015-
raise ValueError("max_depth bust be >= 0, given %d" % max_depth)
1016-
10171027
if feature_names is not None and len(feature_names) != tree_.n_features:
10181028
raise ValueError(
10191029
"feature_names must contain %d elements, got %d"
10201030
% (tree_.n_features, len(feature_names))
10211031
)
10221032

1023-
if spacing <= 0:
1024-
raise ValueError("spacing must be > 0, given %d" % spacing)
1025-
1026-
if decimals < 0:
1027-
raise ValueError("decimals must be >= 0, given %d" % decimals)
1028-
10291033
if isinstance(decision_tree, DecisionTreeClassifier):
10301034
value_fmt = "{}{} weights: {}\n"
10311035
if not show_weights:

sklearn/tree/tests/test_export.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,6 @@ def test_precision():
350350
def test_export_text_errors():
351351
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
352352
clf.fit(X, y)
353-
354-
err_msg = "max_depth bust be >= 0, given -1"
355-
with pytest.raises(ValueError, match=err_msg):
356-
export_text(clf, max_depth=-1)
357353
err_msg = "feature_names must contain 2 elements, got 1"
358354
with pytest.raises(ValueError, match=err_msg):
359355
export_text(clf, feature_names=["a"])
@@ -364,12 +360,6 @@ def test_export_text_errors():
364360
)
365361
with pytest.raises(ValueError, match=err_msg):
366362
export_text(clf, class_names=["a"])
367-
err_msg = "decimals must be >= 0, given -1"
368-
with pytest.raises(ValueError, match=err_msg):
369-
export_text(clf, decimals=-1)
370-
err_msg = "spacing must be > 0, given 0"
371-
with pytest.raises(ValueError, match=err_msg):
372-
export_text(clf, spacing=0)
373363

374364

375365
def test_export_text():

0 commit comments

Comments
 (0)
0