diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index ff1355db478ee..5f58aed546bfc 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -152,8 +152,11 @@ def get_color(value): # Classification tree color = list(colors['rgb'][np.argmax(value)]) sorted_values = sorted(value, reverse=True) - alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) / - (1 - sorted_values[1]), 0)) + if len(sorted_values) == 1: + alpha = 0 + else: + alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) / + (1 - sorted_values[1]), 0)) else: # Regression tree or multi-output color = list(colors['rgb'][0]) @@ -310,7 +313,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0): # Find max and min impurities for multi-output colors['bounds'] = (np.min(-tree.impurity), np.max(-tree.impurity)) - elif tree.n_classes[0] == 1: + elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1: # Find max and min values in leaf nodes for regression colors['bounds'] = (np.min(tree.value), np.max(tree.value)) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 8a7421923b152..8d954d3cc2526 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -18,6 +18,7 @@ y = [-1, -1, -1, 1, 1, 1] y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]] w = [1, 1, 1, .5, .5, .5] +y_degraded = [1, 1, 1, 1, 1, 1] def test_graphviz_toy(): @@ -207,6 +208,20 @@ def test_graphviz_toy(): assert_equal(contents1, contents2) + # Test classifier with degraded learning set + clf = DecisionTreeClassifier(max_depth=3) + clf.fit(X, y_degraded) + + out = StringIO() + export_graphviz(clf, out_file=out, filled=True) + contents1 = out.getvalue() + contents2 = 'digraph Tree {\n' \ + 'node [shape=box, style="filled", color="black"] ;\n' \ + '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", fillcolor="#e5813900"] ;\n' \ + '}' + + assert_equal(contents1, contents2) + def test_graphviz_errors(): # Check for errors of export_graphviz