8000 Merge pull request #6376 from tracer0tong/issue_6352 · scikit-learn/scikit-learn@afc058f · GitHub
[go: up one dir, main page]

Skip to content

Commit afc058f

Browse files
committed
Merge pull request #6376 from tracer0tong/issue_6352
[MRG+2] Fix for issue #6352
2 parents eed5fc5 + 65a2b8f commit afc058f

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

sklearn/tree/export.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,11 @@ def get_color(value):
152152
# Classification tree
153153
color = list(colors['rgb'][np.argmax(value)])
154154
sorted_values = sorted(value, reverse=True)
155-
alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
156-
(1 - sorted_values[1]), 0))
155+
if len(sorted_values) == 1:
156+
alpha = 0
157+
else:
158+
alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
159+
(1 - sorted_values[1]), 0))
157160
else:
158161
# Regression tree or multi-output
159162
color = list(colors['rgb'][0])
@@ -310,7 +313,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
310313
# Find max and min impurities for multi-output
311314
colors['bounds'] = (np.min(-tree.impurity),
312315
np.max(-tree.impurity))
313-
elif tree.n_classes[0] == 1:
316+
elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1:
314317
# Find max and min values in leaf nodes for regression
315318
colors['bounds'] = (np.min(tree.value),
316319
np.max(tree.value))

sklearn/tree/tests/test_export.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
y = [-1, -1, -1, 1, 1, 1]
1919
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
2020
w = [1, 1, 1, .5, .5, .5]
21+
y_degraded = [1, 1, 1, 1, 1, 1]
2122

2223

2324
def test_graphviz_toy():
@@ -207,6 +208,20 @@ def test_graphviz_toy():
207208

208209
assert_equal(contents1, contents2)
209210

211+
# Test classifier with degraded learning set
212+
clf = DecisionTreeClassifier(max_depth=3)
213+
clf.fit(X, y_degraded)
214+
215+
out = StringIO()
216+
export_graphviz(clf, out_file=out, filled=True)
217+
contents1 = out.getvalue()
218+
contents2 = 'digraph Tree {\n' \
219+
'node [shape=box, style="filled", color="black"] ;\n' \
220+
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", fillcolor="#e5813900"] ;\n' \
221+
'}'
222+
223+
assert_equal(contents1, contents2)
224+
210225

211226
def test_graphviz_errors():
212227
# Check for errors of export_graphviz

0 commit comments

Comments
 (0)
0