44from re import finditer , search
55from textwrap import dedent
66
7+ import numpy as np
78from numpy .random import RandomState
89import pytest
910
@@ -48,48 +49,6 @@ def test_graphviz_toy():
4849
4950 assert contents1 == contents2
5051
51- # Test with feature_names
52- contents1 = export_graphviz (
53- clf , feature_names = ["feature0" , "feature1" ], out_file = None
54- )
55- contents2 = (
56- "digraph Tree {\n "
57- 'node [shape=box, fontname="helvetica"] ;\n '
58- 'edge [fontname="helvetica"] ;\n '
59- '0 [label="feature0 <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
60- 'value = [3, 3]"] ;\n '
61- '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]"] ;\n '
62- "0 -> 1 [labeldistance=2.5, labelangle=45, "
63- 'headlabel="True"] ;\n '
64- '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]"] ;\n '
65- "0 -> 2 [labeldistance=2.5, labelangle=-45, "
66- 'headlabel="False"] ;\n '
67- "}"
68- )
69-
70- assert contents1 == contents2
71-
72- # Test with class_names
73- contents1 = export_graphviz (clf , class_names = ["yes" , "no" ], out_file = None )
74- contents2 = (
75- "digraph Tree {\n "
76- 'node [shape=box, fontname="helvetica"] ;\n '
77- 'edge [fontname="helvetica"] ;\n '
78- '0 [label="x[0] <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
79- 'value = [3, 3]\\ nclass = yes"] ;\n '
80- '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]\\ n'
81- 'class = yes"] ;\n '
82- "0 -> 1 [labeldistance=2.5, labelangle=45, "
83- 'headlabel="True"] ;\n '
84- '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]\\ n'
85- 'class = no"] ;\n '
86- "0 -> 2 [labeldistance=2.5, labelangle=-45, "
87- 'headlabel="False"] ;\n '
88- "}"
89- )
90-
91- assert contents1 == contents2
92-
9352 # Test plot_options
9453 contents1 = export_graphviz (
9554 clf ,
@@ -249,6 +208,60 @@ def test_graphviz_toy():
249208 )
250209
251210
211+ @pytest .mark .parametrize ("constructor" , [list , np .array ])
212+ def test_graphviz_feature_class_names_array_support (constructor ):
213+ # Check that export_graphviz treats feature names
214+ # and class names correctly and supports arrays
215+ clf = DecisionTreeClassifier (
216+ max_depth = 3 , min_samples_split = 2 , criterion = "gini" , random_state = 2
217+ )
218+ clf .fit (X , y )
219+
220+ # Test with feature_names
221+ contents1 = export_graphviz (
222+ clf , feature_names = constructor (["feature0" , "feature1" ]), out_file = None
223+ )
224+ contents2 = (
225+ "digraph Tree {\n "
226+ 'node [shape=box, fontname="helvetica"] ;\n '
227+ 'edge [fontname="helvetica"] ;\n '
228+ '0 [label="feature0 <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
229+ 'value = [3, 3]"] ;\n '
230+ '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]"] ;\n '
231+ "0 -> 1 [labeldistance=2.5, labelangle=45, "
232+ 'headlabel="True"] ;\n '
233+ '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]"] ;\n '
234+ "0 -> 2 [labeldistance=2.5, labelangle=-45, "
235+ 'headlabel="False"] ;\n '
236+ "}"
237+ )
238+
239+ assert contents1 == contents2
240+
241+ # Test with class_names
242+ contents1 = export_graphviz (
243+ clf , class_names = constructor (["yes" , "no" ]), out_file = None
244+ )
245+ contents2 = (
246+ "digraph Tree {\n "
247+ 'node [shape=box, fontname="helvetica"] ;\n '
248+ 'edge [fontname="helvetica"] ;\n '
249+ '0 [label="x[0] <= 0.0\\ ngini = 0.5\\ nsamples = 6\\ n'
250+ 'value = [3, 3]\\ nclass = yes"] ;\n '
251+ '1 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [3, 0]\\ n'
252+ 'class = yes"] ;\n '
253+ "0 -> 1 [labeldistance=2.5, labelangle=45, "
254+ 'headlabel="True"] ;\n '
255+ '2 [label="gini = 0.0\\ nsamples = 3\\ nvalue = [0, 3]\\ n'
256+ 'class = no"] ;\n '
257+ "0 -> 2 [labeldistance=2.5, labelangle=-45, "
258+ 'headlabel="False"] ;\n '
259+ "}"
260+ )
261+
262+ assert contents1 == contents2
263+
264+
252265def test_graphviz_errors ():
253266 # Check for errors of export_graphviz
254267 clf = DecisionTreeClassifier (max_depth = 3 , min_samples_split = 2 )
@@ -352,7 +365,7 @@ def test_export_text_errors():
352365 with pytest .raises (ValueError , match = err_msg ):
353366 export_text (clf , feature_names = ["a" ])
354367 err_msg = (
355- "When `class_names` is a list , it should contain as"
368+ "When `class_names` is an array , it should contain as"
356369 " many items as `decision_tree.classes_`. Got 1 while"
357370 " the tree was fitted with 2 classes."
358371 )
@@ -377,22 +390,6 @@ def test_export_text():
377390 # testing that the rest of the tree is truncated
378391 assert export_text (clf , max_depth = 10 ) == expected_report
379392
380- expected_report = dedent ("""
381- |--- b <= 0.00
382- | |--- class: -1
383- |--- b > 0.00
384- | |--- class: 1
385- """ ).lstrip ()
386- assert export_text (clf , feature_names = ["a" , "b" ]) == expected_report
387-
388- expected_report = dedent ("""
389- |--- feature_1 <= 0.00
390- | |--- class: cat
391- |--- feature_1 > 0.00
392- | |--- class: dog
393- """ ).lstrip ()
394- assert export_text (clf , class_names = ["cat" , "dog" ]) == expected_report
395-
396393 expected_report = dedent ("""
397394 |--- feature_1 <= 0.00
398395 | |--- weights: [3.00, 0.00] class: -1
@@ -453,6 +450,30 @@ def test_export_text():
453450 )
454451
455452
453+ @pytest .mark .parametrize ("constructor" , [list , np .array ])
454+ def test_export_text_feature_class_names_array_support (constructor ):
455+ # Check that export_text treats feature names
456+ # and class names correctly and supports arrays
457+ clf = DecisionTreeClassifier (max_depth = 2 , random_state = 0 )
458+ clf .fit (X , y )
459+
460+ expected_report = dedent ("""
461+ |--- b <= 0.00
462+ | |--- class: -1
463+ |--- b > 0.00
464+ | |--- class: 1
465+ """ ).lstrip ()
466+ assert export_text (clf , feature_names = constructor (["a" , "b" ])) == expected_report
467+
468+ expected_report = dedent ("""
469+ |--- feature_1 <= 0.00
470+ | |--- class: cat
471+ |--- feature_1 > 0.00
472+ | |--- class: dog
473+ """ ).lstrip ()
474+ assert export_text (clf , class_names = constructor (["cat" , "dog" ])) == expected_report
475+
476+
456477def test_plot_tree_entropy (pyplot ):
457478 # mostly smoke tests
458479 # Check correctness of export_graphviz for criterion = entropy
0 commit comments