|
17 | 17 | import numpy as np
|
18 | 18 |
|
19 | 19 | from ..utils.validation import check_is_fitted, check_array
|
20 |
| -from ..utils._param_validation import Interval, validate_params, StrOptions |
| 20 | +from ..utils._param_validation import Interval, validate_params, StrOptions, HasMethods |
21 | 21 |
|
22 | 22 | from ..base import is_classifier
|
23 | 23 |
|
@@ -441,20 +441,6 @@ def __init__(
|
441 | 441 | else:
|
442 | 442 | self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']
|
443 | 443 |
|
444 |
| - # validate |
445 |
| - if isinstance(precision, Integral): |
446 |
| - if precision < 0: |
447 |
| - raise ValueError( |
448 |
| - "'precision' should be greater or equal to 0." |
449 |
| - " Got {} instead.".format(precision) |
450 |
| - ) |
451 |
| - else: |
452 |
| - raise ValueError( |
453 |
| - "'precision' should be an integer. Got {} instead.".format( |
454 |
| - type(precision) |
455 |
| - ) |
456 |
| - ) |
457 |
| - |
458 | 444 | # The depth of each node for plotting with 'leaf' option
|
459 | 445 | self.ranks = {"leaves": []}
|
460 | 446 | # The colors to render each node with
|
@@ -739,6 +725,26 @@ def recurse(self, node, tree, ax, max_x, max_y, depth=0):
|
739 | 725 | ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
|
740 | 726 |
|
741 | 727 |
|
| 728 | +@validate_params( |
| 729 | + { |
| 730 | + "decision_tree": "no_validation", |
| 731 | + "out_file": [str, None, HasMethods("write")], |
| 732 | + "max_depth": [Interval(Integral, 0, None, closed="left"), None], |
| 733 | + "feature_names": ["array-like", None], |
| 734 | + "class_names": ["array-like", "boolean", None], |
| 735 | + "label": [StrOptions({"all", "root", "none"})], |
| 736 | + "filled": ["boolean"], |
| 737 | + "leaves_parallel": ["boolean"], |
| 738 | + "impurity": ["boolean"], |
| 739 | + "node_ids": ["boolean"], |
| 740 | + "proportion": ["boolean"], |
| 741 | + "rotate": ["boolean"], |
| 742 | + "rounded": ["boolean"], |
| 743 | + "special_characters": ["boolean"], |
| 744 | + "precision": [Interval(Integral, 0, None, closed="left"), None], |
| 745 | + "fontname": [str], |
| 746 | + } |
| 747 | +) |
742 | 748 | def export_graphviz(
|
743 | 749 | decision_tree,
|
744 | 750 | out_file=None,
|
@@ -774,8 +780,8 @@ def export_graphviz(
|
774 | 780 |
|
775 | 781 | Parameters
|
776 | 782 | ----------
|
777 |
| - decision_tree : decision tree classifier |
778 |
| - The decision tree to be exported to GraphViz. |
| 783 | + decision_tree : object |
| 784 | + The decision tree estimator to be exported to GraphViz. |
779 | 785 |
|
780 | 786 | out_file : object or str, default=None
|
781 | 787 | Handle or name of the output file. If ``None``, the result is
|
|
0 commit comments