8000 MAINT Parameters validation for sklearn.tree.export_graphviz (#26034) · REDVM/scikit-learn@b101a2e · GitHub
[go: up one dir, main page]

Skip to content

Commit b101a2e

Browse files
Charlie-XIAOglemaitre
authored andcommitted
MAINT Parameters validation for sklearn.tree.export_graphviz (scikit-learn#26034)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent bbf8736 commit b101a2e

File tree

3 files changed

+24
-24
lines changed

3 files changed

+24
-24
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def _check_function_param_validation(
271271
"sklearn.preprocessing.scale",
272272
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
273273
"sklearn.svm.l1_min_c",
274+
"sklearn.tree.export_graphviz",
274275
"sklearn.tree.export_text",
275276
"sklearn.tree.plot_tree",
276277
"sklearn.utils.gen_batches",

sklearn/tree/_export.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted, check_array
20-
from ..utils._param_validation import Interval, validate_params, StrOptions
20+
from ..utils._param_validation import Interval, validate_params, StrOptions, HasMethods
2121

2222
from ..base import is_classifier
2323

@@ -441,20 +441,6 @@ def __init__(
441441
else:
442442
self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']
443443

444-
# validate
445-
if isinstance(precision, Integral):
446-
if precision < 0:
447-
raise ValueError(
448-
"'precision' should be greater or equal to 0."
449-
" Got {} instead.".format(precision)
450-
)
451-
else:
452-
raise ValueError(
453-
"'precision' should be an integer. Got {} instead.".format(
454-
type(precision)
455-
)
456-
)
457-
458444
# The depth of each node for plotting with 'leaf' option
459445
self.ranks = {"leaves": []}
460446
# The colors to render each node with
@@ -739,6 +725,26 @@ def recurse(self, node, tree, ax, max_x, max_y, depth=0):
739725
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
740726

741727

728+
@validate_params(
729+
{
730+
"decision_tree": "no_validation",
731+
"out_file": [str, None, HasMethods("write")],
732+
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
733+
"feature_names": ["array-like", None],
734+
"class_names": ["array-like", "boolean", None],
735+
"label": [StrOptions({"all", "root", "none"})],
736+
"filled": ["boolean"],
737+
"leaves_parallel": ["boolean"],
738+
"impurity": ["boolean"],
739+
"node_ids": ["boolean"],
740+
"proportion": ["boolean"],
741+
"rotate": ["boolean"],
742+
"rounded": ["boolean"],
743+
"special_characters": ["boolean"],
744+
"precision": [Interval(Integral, 0, None, closed="left"), None],
745+
"fontname": [str],
746+
}
747+
)
742748
def export_graphviz(
743749
decision_tree,
744750
out_file=None,
@@ -774,8 +780,8 @@ def export_graphviz(
774780
775781
Parameters
776782
----------
777-
decision_tree : decision tree classifier
778-
The decision tree to be exported to GraphViz.
783+
decision_tree : object
784+
The decision tree estimator to be exported to GraphViz.
779785
780786
out_file : object or str, default=None
781787
Handle or name of the output file. If ``None``, the result is

sklearn/tree/tests/test_export.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,6 @@ def test_graphviz_errors():
293293
with pytest.raises(IndexError):
294294
export_graphviz(clf, out, class_names=[])
295295

296-
# Check precision error
297-
out = StringIO()
298-
with pytest.raises(ValueError, match="should be greater or equal"):
299-
export_graphviz(clf, out, precision=-1)
300-
with pytest.raises(ValueError, match="should be an integer"):
301-
export_graphviz(clf, out, precision="1")
302-
303296

304297
def test_friedman_mse_in_graphviz():
305298
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)

0 commit comments

Comments
 (0)
0