8000 FIX Fixes plot_tree from going out of bounds (#21917) · scikit-learn/scikit-learn@77fbdd1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 77fbdd1

Browse files
authored
FIX Fixes plot_tree from going out of bounds (#21917)
1 parent 02b41de commit 77fbdd1

File tree

3 files changed

+35
-20
lines changed

3 files changed

+35
-20
lines changed

doc/whats_new/v1.0.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ Changelog
6363
- |Fix| Fixes compatibility bug with NumPy 1.22 in :class:`preprocessing.OneHotEncoder`.
6464
:pr:`21517` by `Thomas Fan`_.
6565

66+
:mod:`sklearn.tree`
67+
...................
68+
69+
- |Fix| Prevents :func:`tree.plot_tree` from drawing out of the boundary of
70+
the figure. :pr:`21917` by `Thomas Fan`_.
71+
6672
:mod:`sklearn.utils`
6773
....................
6874

examples/tree/plot_iris_dtc.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
================================================================
3-
Plot the decision surface of a decision tree on the iris dataset
4-
================================================================
2+
=======================================================================
3+
Plot the decision surface of decision trees trained on the iris dataset
4+
=======================================================================
55
66
Plot the decision surface of a decision tree trained on pairs
77
of features of the iris dataset.
@@ -14,20 +14,24 @@
1414
1515
We also show the tree structure of a model built on all of the features.
1616
"""
17+
# %%
18+
# First load the copy of the Iris dataset shipped with scikit-learn:
19+
from sklearn.datasets import load_iris
20+
21+
iris = load_iris()
22+
1723

24+
# %%
25+
# Display the decision functions of trees trained on all pairs of features.
1826
import numpy as np
1927
import matplotlib.pyplot as plt
20-
21-
from sklearn.datasets import load_iris
22-
from sklearn.tree import DecisionTreeClassifier, plot_tree
28+
from sklearn.tree import DecisionTreeClassifier
2329

2430
# Parameters
2531
n_classes = 3
2632
plot_colors = "ryb"
2733
plot_step = 0.02
2834

29-
# Load data
30-
iris = load_iris()
3135

3236
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
3337
# We only take the two corresponding features
@@ -67,11 +71,17 @@
6771
s=15,
6872
)
6973

70-
plt.suptitle("Decision surface of a decision tree using paired features")
74+
plt.suptitle("Decision surface of decision trees trained on pairs of features")
7175
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
72-
plt.axis("tight")
76+
_ = plt.axis("tight")
77+
78+
# %%
79+
# Display the structure of a single decision tree trained on all the features
80+
# together.
81+
from sklearn.tree import plot_tree
7382

7483
plt.figure()
7584
clf = DecisionTreeClassifier().fit(iris.data, iris.target)
7685
plot_tree(clf, filled=True)
86+
plt.title("Decision tree trained on all the iris features")
7787
plt.show()

sklearn/tree/_export.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -666,8 +666,7 @@ def export(self, decision_tree, ax=None):
666666

667667
scale_x = ax_width / max_x
668668
scale_y = ax_height / max_y
669-
670-
self.recurse(draw_tree, decision_tree.tree_, ax, scale_x, scale_y, ax_height)
669+
self.recurse(draw_tree, decision_tree.tree_, ax, max_x, max_y)
671670

672671
anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)]
673672

@@ -693,15 +692,15 @@ def export(self, decision_tree, ax=None):
693692

694693
return anns
695694

696-
def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
695+
def recurse(self, node, tree, ax, max_x, max_y, depth=0):
697696
import matplotlib.pyplot as plt
698697

699698
kwargs = dict(
700699
bbox=self.bbox_args.copy(),
701700
ha="center",
702701
va="center",
703702
zorder=100 - 10 * depth,
704-
xycoords="axes points",
703+
xycoords="axes fraction",
705704
arrowprops=self.arrow_args.copy(),
706705
)
707706
kwargs["arrowprops"]["edgecolor"] = plt.rcParams["text.color"]
@@ -710,7 +709,7 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
710709
kwargs["fontsize"] = self.fontsize
711710

712711
# offset things by .5 to center them in plot
713-
xy = ((node.x + 0.5) * scale_x, height - (node.y + 0.5) * scale_y)
712+
xy = ((node.x + 0.5) / max_x, (max_y - node.y - 0.5) / max_y)
714713

715714
if self.max_depth is None or depth <= self.max_depth:
716715
if self.filled:
@@ -723,17 +722,17 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
723722
ax.annotate(node.tree.label, xy, **kwargs)
724723
else:
725724
xy_parent = (
726-
(node.parent.x + 0.5) * scale_x,
727-
height - (node.parent.y + 0.5) * scale_y,
725+
(node.parent.x + 0.5) / max_x,
726+
(max_y - node.parent.y - 0.5) / max_y,
728727
)
729728
ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
730729
for child in node.children:
731-
self.recurse(child, tree, ax, scale_x, scale_y, height, depth=depth + 1)
730+
self.recurse(child, tree, ax, max_x, max_y, depth=depth + 1)
732731

733732
else:
734733
xy_parent = (
735-
(node.parent.x + 0.5) * scale_x,
736-
height - (node.parent.y + 0.5) * scale_y,
734+
(node.parent.x + 0.5) / max_x,
735+
(max_y - node.parent.y - 0.5) / max_y,
737736
)
738737
kwargs["bbox"]["fc"] = "grey"
739738
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)

0 commit comments

Comments
 (0)
0