FIX Fixes plot_tree from going out of bounds by thomasjpfan · Pull Request #21917 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX Fixes plot_tree from going out of bounds #21917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits on
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ Changelog
- |Fix| Fixes compatibility bug with NumPy 1.22 in :class:`preprocessing.OneHotEncoder`.
:pr:`21517` by `Thomas Fan`_.

:mod:`sklearn.tree`
...................

- |Fix| Prevents :func:`tree.plot_tree` from drawing outside the boundaries of
the figure. :pr:`21917` by `Thomas Fan`_.

:mod:`sklearn.utils`
....................

Expand Down
30 changes: 20 additions & 10 deletions examples/tree/plot_iris_dtc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
================================================================
Plot the decision surface of a decision tree on the iris dataset
================================================================
=======================================================================
Plot the decision surface of decision trees trained on the iris dataset
=======================================================================

Plot the decision surface of a decision tree trained on pairs
of features of the iris dataset.
Expand All @@ -14,20 +14,24 @@

We also show the tree structure of a model built on all of the features.
"""
# %%
# First load the copy of the Iris dataset shipped with scikit-learn:
from sklearn.datasets import load_iris

iris = load_iris()


# %%
# Display the decision functions of trees trained on all pairs of features.
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.tree import DecisionTreeClassifier

# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02

# Load data
iris = load_iris()

for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
# We only take the two corresponding features
Expand Down Expand Up @@ -67,11 +71,17 @@
s=15,
)

plt.suptitle("Decision surface of a decision tree using paired features")
plt.suptitle("Decision surface of decision trees trained on pairs of features")
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
plt.axis("tight")
_ = plt.axis("tight")

# %%
# Display the structure of a single decision tree trained on all the features
# together.
from sklearn.tree import plot_tree

plt.figure()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
plot_tree(clf, filled=True)
plt.title("Decision tree trained on all the iris features")
plt.show()
19 changes: 9 additions & 10 deletions sklearn/tree/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,7 @@ def export(self, decision_tree, ax=None):

scale_x = ax_width / max_x
scale_y = ax_height / max_y

self.recurse(draw_tree, decision_tree.tree_, ax, scale_x, scale_y, ax_height)
self.recurse(draw_tree, decision_tree.tree_, ax, max_x, max_y)

anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)]

Expand All @@ -693,15 +692,15 @@ def export(self, decision_tree, ax=None):

return anns

def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
def recurse(self, node, tree, ax, max_x, max_y, depth=0):
import matplotlib.pyplot as plt

kwargs = dict(
bbox=self.bbox_args.copy(),
ha="center",
va="center",
zorder=100 - 10 * depth,
xycoords="axes points",
xycoords="axes fraction",
arrowprops=self.arrow_args.copy(),
)
kwargs["arrowprops"]["edgecolor"] = plt.rcParams["text.color"]
Expand All @@ -710,7 +709,7 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
kwargs["fontsize"] = self.fontsize

# offset things by .5 to center them in plot
xy = ((node.x + 0.5) * scale_x, height - (node.y + 0.5) * scale_y)
xy = ((node.x + 0.5) / max_x, (max_y - node.y - 0.5) / max_y)

if self.max_depth is None or depth <= self.max_depth:
if self.filled:
Expand All @@ -723,17 +722,17 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
ax.annotate(node.tree.label, xy, **kwargs)
else:
xy_parent = (
(node.parent.x + 0.5) * scale_x,
height - (node.parent.y + 0.5) * scale_y,
(node.parent.x + 0.5) / max_x,
(max_y - node.parent.y - 0.5) / max_y,
)
ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
for child in node.children:
self.recurse(child, tree, ax, scale_x, scale_y, height, depth=depth + 1)
self.recurse(child, tree, ax, max_x, max_y, depth=depth + 1)

else:
xy_parent = (
(node.parent.x + 0.5) * scale_x,
height - (node.parent.y + 0.5) * scale_y,
(node.parent.x + 0.5) / max_x,
(max_y - node.parent.y - 0.5) / max_y,
)
kwargs["bbox"]["fc"] = "grey"
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
Expand Down
0