8000 MAINT Parameters validation for sklearn.tree.export_graphviz by Charlie-XIAO · Pull Request #26034 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Parameters validation for sklearn.tree.export_graphviz #26034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
461c798
MAINT Parameters validation for tree.export_graphviz
Charlie-XIAO Mar 31, 2023
13646dd
added to param val test
Charlie-XIAO Mar 31, 2023
a23ec8e
fixed constraints, removed unnecessary tests
Charlie-XIAO Mar 31, 2023
95322f3
resolved conversations
Charlie-XIAO Apr 1, 2023
655f289
added test cases for array-like
Charlie-XIAO Apr 1, 2023
ce6bdd3
fixed failing test cases
Charlie-XIAO Apr 1, 2023
630ca86
improved tests
Charlie-XIAO Apr 3, 2023
e1dc1cf
Merge branch 'main' into param-val-export_graphviz
glemaitre Apr 3, 2023
cd15e9b
Merge branch 'main' into param-val-export_graphviz
Charlie-XIAO Apr 3, 2023
7bc31c0
8000 resolved conversations
Charlie-XIAO Apr 3, 2023
b5622ac
Merge branch 'param-val-export_graphviz' of https://github.com/Charli…
Charlie-XIAO Apr 3, 2023
7fcecda
Merge remote-tracking branch 'upstream/main' into param-val-export_gr…
Charlie-XIAO Apr 20, 2023
0527d8b
Merge remote-tracking branch 'upstream/main' into param-val-export_gr…
Charlie-XIAO Apr 21, 2023
02eeb3f
Merge remote-tracking branch 'upstream/main' into param-val-export_gr…
Charlie-XIAO Apr 25, 2023
69a25f4
Merge remote-tracking branch 'upstream/main' into param-val-export_gr…
Charlie-XIAO Apr 26, 2023
b5d1f26
leave decision_tree as no validation
Charlie-XIAO Apr 27, 2023
db84ab9
Merge remote-tracking branch 'upstream/main' into param-val-export_gr…
Charlie-XIAO Apr 27, 2023
cd75117
reverted necessary test
Charlie-XIAO Apr 27, 2023
676462b
Merge branch 'main' into param-val-export_graphviz
Charlie-XIAO May 9, 2023
df2bf00
Merge branch 'main' into param-val-export_graphviz
Charlie-XIAO May 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def _check_function_param_validation(
"sklearn.preprocessing.scale",
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
"sklearn.svm.l1_min_c",
"sklearn.tree.export_graphviz",
"sklearn.tree.export_text",
"sklearn.tree.plot_tree",
"sklearn.utils.gen_batches",
Expand Down
40 changes: 23 additions & 17 deletions sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np

from ..utils.validation import check_is_fitted, check_array
from ..utils._param_validation import Interval, validate_params, StrOptions
from ..utils._param_validation import Interval, validate_params, StrOptions, HasMethods

from ..base import is_classifier

Expand Down Expand Up @@ -441,20 +441,6 @@ def __init__(
else:
self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']

# validate
if isinstance(precision, Integral):
if precision < 0:
raise ValueError(
"'precision' should be greater or equal to 0."
" Got {} instead.".format(precision)
)
else:
raise ValueError(
"'precision' should be an integer. Got {} instead.".format(
type(precision)
)
)

# The depth of each node for plotting with 'leaf' option
self.ranks = {"leaves": []}
# The colors to render each node with
Expand Down Expand Up @@ -739,6 +725,26 @@ def recurse(self, node, tree, ax, max_x, max_y, depth=0):
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)


@validate_params(
{
"decision_tree": "no_validation",
"out_file": [str, None, HasMethods("write")],
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
"feature_names": ["array-like", None],
"class_names": ["array-like", "boolean", None],
"label": [StrOptions({"all", "root", "none"})],
"filled": ["boolean"],
"leaves_parallel": ["boolean"],
"impurity": ["boolean"],
"node_ids": ["boolean"],
"proportion": ["boolean"],
"rotate": ["boolean"],
"rounded": ["boolean"],
"special_characters": ["boolean"],
"precision": [Interval(Integral, 0, None, closed="left"), None],
"fontname": [str],
}
)
def export_graphviz(
decision_tree,
out_file=None,
Expand Down Expand Up @@ -774,8 +780,8 @@ def export_graphviz(

Parameters
----------
decision_tree : decision tree classifier
The decision tree to be exported to GraphViz.
decision_tree : object
The decision tree estimator to be exported to GraphViz.

out_file : object or str, default=None
Handle or name of the output file. If ``None``, the result is
Expand Down
7 changes: 0 additions & 7 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,6 @@ def test_graphviz_errors():
with pytest.raises(IndexError):
export_graphviz(clf, out, class_names=[])

# Check precision error
out = StringIO()
with pytest.raises(ValueError, match="should be greater or equal"):
export_graphviz(clf, out, precision=-1)
with pytest.raises(ValueError, match="should be an integer"):
export_graphviz(clf, out, precision="1")


def test_friedman_mse_in_graphviz():
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
Expand Down
0