|
17 | 17 | import numpy as np |
18 | 18 |
|
19 | 19 | from ..utils.validation import check_is_fitted |
20 | | -from ..utils._param_validation import Interval, validate_params |
| 20 | +from ..utils._param_validation import Interval, validate_params, StrOptions |
21 | 21 |
|
22 | 22 | from ..base import is_classifier |
23 | 23 |
|
@@ -77,6 +77,23 @@ def __repr__(self): |
77 | 77 | SENTINEL = Sentinel() |
78 | 78 |
|
79 | 79 |
|
| 80 | +@validate_params( |
| 81 | + { |
| 82 | + "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor], |
| 83 | + "max_depth": [Interval(Integral, 0, None, closed="left"), None], |
| 84 | + "feature_names": [list, None], |
| 85 | + "class_names": [list, None], |
| 86 | + "label": [StrOptions({"all", "root", "none"})], |
| 87 | + "filled": ["boolean"], |
| 88 | + "impurity": ["boolean"], |
| 89 | + "node_ids": ["boolean"], |
| 90 | + "proportion": ["boolean"], |
| 91 | + "rounded": ["boolean"], |
| 92 | + "precision": [Interval(Integral, 0, None, closed="left"), None], |
| 93 | + "ax": "no_validation", # delegate validation to matplotlib |
| 94 | + "fontsize": [Interval(Integral, 0, None, closed="left"), None], |
| 95 | + } |
| 96 | +) |
80 | 97 | def plot_tree( |
81 | 98 | decision_tree, |
82 | 99 | *, |
@@ -601,20 +618,6 @@ def __init__( |
601 | 618 | ) |
602 | 619 | self.fontsize = fontsize |
603 | 620 |
|
604 | | - # validate |
605 | | - if isinstance(precision, Integral): |
606 | | - if precision < 0: |
607 | | - raise ValueError( |
608 | | - "'precision' should be greater or equal to 0." |
609 | | - " Got {} instead.".format(precision) |
610 | | - ) |
611 | | - else: |
612 | | - raise ValueError( |
613 | | - "'precision' should be an integer. Got {} instead.".format( |
614 | | - type(precision) |
615 | | - ) |
616 | | - ) |
617 | | - |
618 | 621 | # The depth of each node for plotting with 'leaf' option |
619 | 622 | self.ranks = {"leaves": []} |
620 | 623 | # The colors to render each node with |
|
0 commit comments