8000 [MRG+2] Fix for issue #6352 by tracer0tong · Pull Request #6376 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+2] Fix for issue #6352 #6376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions sklearn/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,11 @@ def get_color(value):
# Classification tree
color = list(colors['rgb'][np.argmax(value)])
sorted_values = sorted(value, reverse=True)
alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
(1 - sorted_values[1]), 0))
if len(sorted_values) == 1:
alpha = 0
else:
alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
(1 - sorted_values[1]), 0))
else:
# Regression tree or multi-output
color = list(colors['rgb'][0])
Expand Down Expand Up @@ -310,7 +313,7 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
# Find max and min impurities for multi-output
colors['bounds'] = (np.min(-tree.impurity),
np.max(-tree.impurity))
elif tree.n_classes[0] == 1:
elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1:
# Find max and min values in leaf nodes for regression
colors['bounds'] = (np.min(tree.value),
np.max(tree.value))
Expand Down
15 changes: 15 additions & 0 deletions sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, .5, .5, .5]
y_degraded = [1, 1, 1, 1, 1, 1]


def test_graphviz_toy():
Expand Down Expand Up @@ -207,6 +208,20 @@ def test_graphviz_toy():

assert_equal(contents1, contents2)

# Test classifier with degraded learning set
clf = DecisionTreeClassifier(max_depth=3)
clf.fit(X, y_degraded)

out = StringIO()
export_graphviz(clf, out_file=out, filled=True)
contents1 = out.getvalue()
contents2 = 'digraph Tree {\n' \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note from PEP8:

The preferred way of wrapping long lines is by using Python's implied line continuation inside parentheses, brackets and braces. Long lines can be broken over multiple lines by wrapping expressions in parentheses. These should be used in preference to using a backslash for line continuation.

For long strings you can do:

contents2 = ('digraph Tree {\n'
             'node ...\n'
             '0 [label ...]\n'
             '}')

Alternatively in your particular case, this would work:

contents2 = '\n'.join(['digraph Tree{', 'node ...', '0 [label ...', '...'])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've used this line separator because of consistency. All other lines in this file divided with backslash. I think in this case I should replace all backslashes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So. Keep it with backslashes? Or apply PEP8 recommendation to all strings?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep it with backslashes is fine, that's what I meant.

'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", fillcolor="#e5813900"] ;\n' \
'}'

assert_equal(contents1, contents2)


def test_graphviz_errors():
# Check for errors of export_graphviz
Expand Down
0