8000 MAINT validate_params for plot_tree (#25882) · scikit-learn/scikit-learn@2c97116 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2c97116

Browse files
VeghitItay
authored andcommitted
MAINT validate_params for plot_tree (#25882)
Co-authored-by: Itay <itayvegh@gmail.com>
1 parent 0d64914 commit 2c97116

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def _check_function_param_validation(
194194
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
195195
"sklearn.svm.l1_min_c",
196196
"sklearn.tree.export_text",
197+
"sklearn.tree.plot_tree",
197198
"sklearn.utils.gen_batches",
198199
]
199200

sklearn/tree/_export.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted
20-
from ..utils._param_validation import Interval, validate_params
20+
from ..utils._param_validation import Interval, validate_params, StrOptions
2121

2222
from ..base import is_classifier
2323

@@ -77,6 +77,23 @@ def __repr__(self):
7777
SENTINEL = Sentinel()
7878

7979

80+
@validate_params(
81+
{
82+
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
83+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
84+
"feature_names": [list, None],
85+
"class_names": [list, None],
86+
"label": [StrOptions({"all", "root", "none"})],
87+
"filled": ["boolean"],
88+
"impurity": ["boolean"],
89+
"node_ids": ["boolean"],
90+
"proportion": ["boolean"],
91+
"rounded": ["boolean"],
92+
"precision": [Interval(Integral, 0, None, closed="left"), None],
93+
"ax": "no_validation", # delegate validation to matplotlib
94+
"fontsize": [Interval(Integral, 0, None, closed="left"), None],
95+
}
96+
)
8097
def plot_tree(
8198
decision_tree,
8299
*,
@@ -601,20 +618,6 @@ def __init__(
601618
)
602619
self.fontsize = fontsize
603620

604-
# validate
605-
if isinstance(precision, Integral):
606-
if precision < 0:
607-
raise ValueError(
608-
"'precision' should be greater or equal to 0."
609-
" Got {} instead.".format(precision)
610-
)
611-
else:
612-
raise ValueError(
613-
"'precision' should be an integer. Got {} instead.".format(
614-
type(precision)
615-
)
616-
)
617-
618621
# The depth of each node for plotting with 'leaf' option
619622
self.ranks = {"leaves": []}
620623
# The colors to render each node with

0 commit comments

Comments
 (0)
0