FIX `export_text` and `export_graphviz` accepts feature and class names as array-like by Charlie-XIAO · Pull Request #26289 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX export_text and export_graphviz accepts feature and class names as array-like #26289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,10 @@ Changelog
for each target class in ascending numerical order.
:pr:`25387` by :user:`William M <Akbeeh>` and :user:`crispinlogan <crispinlogan>`.

- |Fix| :func:`tree.export_graphviz` and :func:`tree.export_text` now accept
`feature_names` and `class_names` as array-like rather than lists.
:pr:`26289` by :user:`Yao Xiao <Charlie-XIAO>`.

:mod:`sklearn.utils`
....................

Expand Down
43 changes: 30 additions & 13 deletions sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, check_array
from ..utils._param_validation import Interval, validate_params, StrOptions

from ..base import is_classifier
Expand Down Expand Up @@ -788,11 +788,11 @@ def export_graphviz(
The maximum depth of the representation. If None, the tree is fully
generated.

feature_names : list of str, default=None
Names of each of the features.
feature_names : array-like of shape (n_features,), default=None
An array containing the feature names.
If None, generic names will be used ("x[0]", "x[1]", ...).

class_names : list of str or bool, default=None
class_names : array-like of shape (n_classes,) or bool, default=None
Names of each of the target classes in ascending numerical order.
Only relevant for classification and not supported for multi-output.
If ``True``, shows a symbolic representation of the class name.
Expand Down Expand Up @@ -857,6 +857,14 @@ def export_graphviz(
>>> tree.export_graphviz(clf)
'digraph Tree {...
"""
if feature_names is not None:
feature_names = check_array(
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)
if class_names is not None and not isinstance(class_names, bool):
class_names = check_array(
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)

check_is_fitted(decision_tree)
own_file = False
Expand Down Expand Up @@ -924,8 +932,8 @@ def compute_depth_(
@validate_params(
{
"decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
"feature_names": [list, None],
"class_names": [list, None],
"feature_names": ["array-like", None],
"class_names": ["array-like", None],
"max_depth": [Interval(Integral, 0, None, closed="left"), None],
"spacing": [Interval(Integral, 1, None, closed="left"), None],
"decimals": [Interval(Integral, 0, None, closed="left"), None],
Expand Down Expand Up @@ -953,17 +961,17 @@ def export_text(
It can be an instance of
DecisionTreeClassifier or DecisionTreeRegressor.

feature_names : list of str, default=None
A list of length n_features containing the feature names.
feature_names : array-like of shape (n_features,), default=None
An array containing the feature names.
If None generic names will be used ("feature_0", "feature_1", ...).

class_names : list or None, default=None
class_names : array-like of shape (n_classes,), default=None
Names of each of the target classes in ascending numerical order.
Only relevant for classification and not supported for multi-output.

- if `None`, the class names are delegated to `decision_tree.classes_`;
- if a list, then `class_names` will be used as class names instead
of `decision_tree.classes_`. The length of `class_names` must match
- otherwise, `class_names` will be used as class names instead of
`decision_tree.classes_`. The length of `class_names` must match
the length of `decision_tree.classes_`.

.. versionadded:: 1.3
Expand Down Expand Up @@ -1008,14 +1016,23 @@ def export_text(
| |--- petal width (cm) > 1.75
| | |--- class: 2
"""
if feature_names is not None:
feature_names = check_array(
feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)
if class_names is not None:
class_names = check_array(
class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
)

check_is_fitted(decision_tree)
tree_ = decision_tree.tree_
if is_classifier(decision_tree):
if class_names is None:
class_names = decision_tree.classes_
elif len(class_names) != len(decision_tree.classes_):
raise ValueError(
"When `class_names` is a list, it should contain as"
"When `class_names` is an array, it should contain as"
" many items as `decision_tree.classes_`. Got"
f" {len(class_names)} while the tree was fitted with"
f" {len(decision_tree.classes_)} classes."
Expand All @@ -1037,7 +1054,7 @@ def export_text(
else:
value_fmt = "{}{} value: {}\n"

if feature_names:
if feature_names is not None:
feature_names_ = [
feature_names[i] if i != _tree.TREE_UNDEFINED else None
for i in tree_.feature
Expand Down
139 changes: 80 additions & 59 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from re import finditer, search
from textwrap import dedent

import numpy as np
from numpy.random import RandomState
import pytest

Expand Down Expand Up @@ -48,48 +49,6 @@ def test_graphviz_toy():

assert contents1 == contents2

# Test with feature_names
contents1 = export_graphviz(
clf, feature_names=["feature0", "feature1"], out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)

assert contents1 == contents2

# Test with class_names
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = yes"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
'class = yes"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
'class = no"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)

assert contents1 == contents2

# Test plot_options
contents1 = export_graphviz(
clf,
Expand Down Expand Up @@ -249,6 +208,60 @@ def test_graphviz_toy():
)


@pytest.mark.parametrize("constructor", [list, np.array])
def test_graphviz_feature_class_names_array_support(constructor):
# Check that export_graphviz renders the given feature names and class
# names correctly whether they are passed as a list or as a numpy array
# (`constructor` is parametrized over `list` and `np.array`).
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
)
clf.fit(X, y)

# Test with feature_names: the root split is labelled "feature0"
# instead of the generic "x[0]".
contents1 = export_graphviz(
clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)

assert contents1 == contents2

# Test with class_names: every node label gains a "class = ..." line
# using the supplied names.
contents1 = export_graphviz(
clf, class_names=constructor(["yes", "no"]), out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = yes"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
'class = yes"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
'class = no"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)

assert contents1 == contents2


def test_graphviz_errors():
# Check for errors of export_graphviz
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
Expand Down Expand Up @@ -352,7 +365,7 @@ def test_export_text_errors():
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=["a"])
err_msg = (
"When `class_names` is a list, it should contain as"
"When `class_names` is an array, it should contain as"
" many items as `decision_tree.classes_`. Got 1 while"
" the tree was fitted with 2 classes."
)
Expand All @@ -377,22 +390,6 @@ def test_export_text():
# testing that the rest of the tree is truncated
assert export_text(clf, max_depth=10) == expected_report

expected_report = dedent("""
|--- b <= 0.00
| |--- class: -1
|--- b > 0.00
| |--- class: 1
""").lstrip()
assert export_text(clf, feature_names=["a", "b"]) == expected_report

expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- class: cat
|--- feature_1 > 0.00
| |--- class: dog
""").lstrip()
assert export_text(clf, class_names=["cat", "dog"]) == expected_report

expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- weights: [3.00, 0.00] class: -1
Expand Down Expand Up @@ -453,6 +450,30 @@ def test_export_text():
)


@pytest.mark.parametrize("constructor", [list, np.array])
def test_export_text_feature_class_names_array_support(constructor):
# Check that export_text treats feature names and class names correctly
# whether they are passed as a list or as a numpy array
# (`constructor` is parametrized over `list` and `np.array`).
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)

# Custom feature names: the split is reported on "b" instead of the
# generic "feature_1".
expected_report = dedent("""
|--- b <= 0.00
| |--- class: -1
|--- b > 0.00
| |--- class: 1
""").lstrip()
assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report

# Custom class names: leaves are reported as "cat"/"dog" instead of the
# fitted class labels -1/1.
expected_report = dedent("""
|--- feature_1 <= 0.00
| |--- class: cat
|--- feature_1 > 0.00
| |--- class: dog
""").lstrip()
assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report


def test_plot_tree_entropy(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = entropy
Expand Down