8000 FIX `tree.export_text` and `tree.export_graphviz` accepts feature and… · REDVM/scikit-learn@c52cfdb · GitHub
[go: up one dir, main page]

Skip to content

Commit c52cfdb

Browse files
Charlie-XIAOREDVM
authored andcommitted
FIX tree.export_text and tree.export_graphviz accepts feature and class names as array-like (scikit-learn#26289)
1 parent 5cd2f77 commit c52cfdb

File tree

3 files changed

+114
-72
lines changed

3 files changed

+114
-72
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,10 @@ Changelog
528528
for each target class in ascending numerical order.
529529
:pr:`25387` by :user:`William M <Akbeeh>` and :user:`crispinlogan <crispinlogan>`.
530530

531+
- |Fix| :func:`tree.export_graphviz` and :func:`tree.export_text` now accepts
532+
`feature_names` and `class_names` as array-like rather than lists.
533+
:pr:`26289` by :user:`Yao Xiao <Charlie-XIAO>`
534+
531535
:mod:`sklearn.utils`
532536
....................
533537

sklearn/tree/_export.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
from ..utils.validation import check_is_fitted
19+
from ..utils.validation import check_is_fitted, check_array
2020
from ..utils._param_validation import Interval, validate_params, StrOptions
2121

2222
from ..base import is_classifier
@@ -788,11 +788,11 @@ def export_graphviz(
788788
The maximum depth of the representation. If None, the tree is fully
789789
generated.
790790
791-
feature_names : list of str, default=None
792-
Names of each of the features.
791+
feature_names : array-like of shape (n_features,), default=None
792+
An array containing the feature names.
793793
If None, generic names will be used ("x[0]", "x[1]", ...).
794794
795-
class_names : list of str or bool, default=None
795+
class_names : array-like of shape (n_classes,) or bool, default=None
796796
Names of each of the target classes in ascending numerical order.
797797
Only relevant for classification and not supported for multi-output.
798798
If ``True``, shows a symbolic representation of the class name.
@@ -857,6 +857,14 @@ def export_graphviz(
857857
>>> tree.export_graphviz(clf)
858858
'digraph Tree {...
859859
"""
860+
if feature_names is not None:
861+
feature_names = check_array(
862+
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
863+
)
864+
if class_names is not None and not isinstance(class_names, bool):
865+
class_names = check_array(
866+
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
867+
)
860868

861869
check_is_fitted(decision_tree)
862870
own_file = False
@@ -924,8 +932,8 @@ def compute_depth_(
924932
@validate_params(
925933
{
926934
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
927-
"feature_names": [list, None],
928-
"class_names": [list, None],
935+
"feature_names": ["array-like", None],
936+
"class_names": ["array-like", None],
929937
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
930938
"spacing": [Interval(Integral, 1, None, closed="left"), None],
931939
"decimals": [Interval(Integral, 0, None, closed="left"), None],
@@ -953,17 +961,17 @@ def export_text(
953961
It can be an instance of
954962
DecisionTreeClassifier or DecisionTreeRegressor.
955963
956-
feature_names : list of str, default=None
957-
A list of length n_features containing the feature names.
964+
feature_names : array-like of shape (n_features,), default=None
965+
An array containing the feature names.
958966
If None generic names will be used ("feature_0", "feature_1", ...).
959967
960-
class_names : list or None, default=None
968+
class_names : array-like of shape (n_classes,), default=None
961969
Names of each of the target classes in ascending numerical order.
962970
Only relevant for classification and not supported for multi-output.
963971
964972
- if `None`, the class names are delegated to `decision_tree.classes_`;
965-
- if a list, then `class_names` will be used as class names instead
966-
of `decision_tree.classes_`. The length of `class_names` must match
973+
- otherwise, `class_names` will be used as class names instead of
974+
`decision_tree.classes_`. The length of `class_names` must match
967975
the length of `decision_tree.classes_`.
968976
969977
.. versionadded:: 1.3
@@ -1008,14 +1016,23 @@ def export_text(
10081016
| |--- petal width (cm) > 1.75
10091017
| | |--- class: 2
10101018
"""
1019+
if feature_names is not None:
1020+
feature_names = check_array(
1021+
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
1022+
)
1023+
if class_names is not None:
1024+
class_names = check_array(
1025+
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
1026+
)
1027+
10111028
check_is_fitted(decision_tree)
10121029
tree_ = decision_tree.tree_
10131030
if is_classifier(decision_tree):
10141031
if class_names is None:
10151032
class_names = decision_tree.classes_
10161033
elif len(class_names) != len(decision_tree.classes_):
10171034
raise V F438 alueError(
1018-
"When `class_names` is a list, it should contain as"
1035+
"When `class_names` is an array, it should contain as"
10191036
" many items as `decision_tree.classes_`. Got"
10201037
f" {len(class_names)} while the tree was fitted with"
10211038
f" {len(decision_tree.classes_)} classes."
@@ -1037,7 +1054,7 @@ def export_text(
10371054
else:
10381055
value_fmt = "{}{} value: {}\n"
10391056

1040-
if feature_names:
1057+
if feature_names is not None:
10411058
feature_names_ = [
10421059
feature_names[i] if i != _tree.TREE_UNDEFINED else None
10431060
for i in tree_.feature

sklearn/tree/tests/test_export.py

Lines changed: 80 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from re import finditer, search
55
from textwrap import dedent
66

7+
import numpy as np
78
from numpy.random import RandomState
89
import pytest
910

@@ -48,48 +49,6 @@ def test_graphviz_toy():
4849

4950
assert contents1 == contents2
5051

51-
# Test with feature_names
52-
contents1 = export_graphviz(
53-
clf, feature_names=["feature0", "feature1"], out_file=None
54-
)
55-
contents2 = (
56-
"digraph Tree {\n"
57-
'node [shape=box, fontname="helvetica"] ;\n'
58-
'edge [fontname="helvetica"] ;\n'
59-
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
60-
'value = [3, 3]"] ;\n'
61-
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
62-
"0 -> 1 [labeldistance=2.5, labelangle=45, "
63-
'headlabel="True"] ;\n'
64-
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
65-
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
66-
'headlabel="False"] ;\n'
67-
"}"
68-
)
69-
70-
assert contents1 == contents2
71-
72-
# Test with class_names
73-
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
74-
contents2 = (
75-
"digraph Tree {\n"
76-
'node [shape=box, fontname="helvetica"] ;\n'
77-
'edge [fontname="helvetica"] ;\n'
78-
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
79-
'value = [3, 3]\\nclass = yes"] ;\n'
80-
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
81-
'class = yes"] ;\n'
82-
"0 -> 1 [labeldistance=2.5, labelangle=45, "
83-
'headlabel="True"] ;\n'
84-
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
85-
'class = no"] ;\n'
86-
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
87-
'headlabel="False"] ;\n'
88-
"}"
89-
)
90-
91-
assert contents1 == contents2
92-
9352
# Test plot_options
9453
contents1 = export_graphviz(
9554
clf,
@@ -249,6 +208,60 @@ def test_graphviz_toy():
249208
)
250209

251210

211+
@pytest.mark.parametrize("constructor", [list, np.array])
212+
def test_graphviz_feature_class_names_array_support(constructor):
213+
# Check that export_graphviz treats feature names
214+
# and class names correctly and supports arrays
215+
clf = DecisionTreeClassifier(
216+
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
217+
)
218+
clf.fit(X, y)
219+
220+
# Test with feature_names
221+
contents1 = export_graphviz(
222+
clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
223+
)
224+
contents2 = (
225+
"digraph Tree {\n"
226+
'node [shape=box, fontname="helvetica"] ;\n'
227+
'edge [fontname="helvetica"] ;\n'
228+
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
229+
'value = [3, 3]"] ;\n'
230+
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
231+
"0 -> 1 [labeldistance=2.5, labelangle=45, "
232+
'headlabel="True"] ;\n'
233+
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
234+
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
235+
'headlabel="False"] ;\n'
236+
"}"
237+
)
238+
239+
assert contents1 == contents2
240+
241+
# Test with class_names
242+
contents1 = export_graphviz(
243+
clf, class_names=constructor(["yes", "no"]), out_file=None
244+
)
245+
contents2 = (
246+
"digraph Tree {\n"
247+
'node [shape=box, fontname="helvetica"] ;\n'
248+
'edge [fontname="helvetica"] ;\n'
249+
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
250+
'value = [3, 3]\\nclass = yes"] ;\n'
251+
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
252+
'class = yes"] ;\n'
253+
"0 -> 1 [labeldistance=2.5, labelangle=45, "
254+
'headlabel="True"] ;\n'
255+
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
256+
'class = no"] ;\n'
257+
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
258+
'headlabel="False"] ;\n'
259+
"}"
260+
)
261+
262+
assert contents1 == contents2
263+
264+
252265
def test_graphviz_errors():
253266
# Check for errors of export_graphviz
254267
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
@@ -352,7 +365,7 @@ def test_export_text_errors():
352365
with pytest.raises(ValueError, match=err_msg):
353366
export_text(clf, feature_names=["a"])
354367
err_msg = (
355-
"When `class_names` is a list, it should contain as"
368+
"When `class_names` is an array, it should contain as"
356369
" many items as `decision_tree.classes_`. Got 1 while"
357370
" the tree was fitted with 2 classes."
358371
)
@@ -377,22 +390,6 @@ def test_export_text():
377390
# testing that the rest of the tree is truncated
378391
assert export_text(clf, max_depth=10) == expected_report
379392

380-
expected_report = dedent("""
381-
|--- b <= 0.00
382-
| |--- class: -1
383-
|--- b > 0.00
384-
| |--- class: 1
385-
""").lstrip()
386-
assert export_text(clf, feature_names=["a", "b"]) == expected_report
387-
388-
expected_report = dedent("""
389-
|--- feature_1 <= 0.00
390-
| |--- class: cat
391-
|--- feature_1 > 0.00
392-
| |--- class: dog
393-
""").lstrip()
394-
assert export_text(clf, class_names=["cat", "dog"]) == expected_report
395-
396393
expected_report = dedent("""
397394
|--- feature_1 <= 0.00
398395
| |--- weights: [3.00, 0.00] class: -1
@@ -453,6 +450,30 @@ def test_export_text():
453450
)
454451

455452

453+
@pytest.mark.parametrize("constructor", [list, np.array])
454+
def test_export_text_feature_class_names_array_support(constructor):
455+
# Check that export_graphviz treats feature names
456+
# and class names correctly and supports arrays
457+
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
458+
clf.fit(X, y)
459+
460+
expected_report = dedent("""
461+
|--- b <= 0.00
462+
| |--- class: -1
463+
|--- b > 0.00
464+
| |--- class: 1
465+
""").lstrip()
466+
assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report
467+
468+
expected_report = dedent("""
469+
|--- feature_1 <= 0.00
470+
| |--- class: cat
471+
|--- feature_1 > 0.00
472+
| |--- class: dog
473+
""").lstrip()
474+
assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report
475+
476+
456477
def test_plot_tree_entropy(pyplot):
457478
# mostly smoke tests
458479
# Check correctness of export_graphviz for criterion = entropy

0 commit comments

Comments
 (0)
0