8000 ENH add fontname argument in export_graphviz for non-English characte… · scikit-learn/scikit-learn@34de1b9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 34de1b9

Browse files
authored
ENH add fontname argument in export_graphviz for non-English characters (#18959)
1 parent 1e46db6 commit 34de1b9

File tree

3 files changed

+47
-26
lines changed

3 files changed

+47
-26
lines changed

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ Changelog
4444
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
4545
where 123456 is the *pull request* number, not the issue number.
4646
47+
:mod:`sklearn.tree`
48+
...................
49+
50+
- |Enhancement| Add `fontname` argument in :func:`tree.export_graphviz`
51+
for non-English characters. :pr:`18959` by :user:`Zero <Zeroto521>`
52+
and :user:`wstates <wstates>`.
53+
4754
:mod:`sklearn.cluster`
4855
......................
4956

sklearn/tree/_export.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -371,18 +371,17 @@ def __init__(self, out_file=SENTINEL, max_depth=None,
371371
feature_names=None, class_names=None, label='all',
372372
filled=False, leaves_parallel=False, impurity=True,
373373
node_ids=False, proportion=False, rotate=False, rounded=False,
374-
special_characters=False, precision=3):
374+
special_characters=False, precision=3, fontname='helvetica'):
375375

376376
super().__init__(
377377
max_depth=max_depth, feature_names=feature_names,
378378
class_names=class_names, label=label, filled=filled,
379-
impurity=impurity,
380-
node_ids=node_ids, proportion=proportion, rotate=rotate,
381-
rounded=rounded,
382-
precision=precision)
379+
impurity=impurity, node_ids=node_ids, proportion=proportion,
380+
rotate=rotate, rounded=rounded, precision=precision)
383381
self.leaves_parallel = leaves_parallel
384382
self.out_file = out_file
385383
self.special_characters = special_characters
384+
self.fontname = fontname
386385

387386
# PostScript compatibility for special characters
388387
if special_characters:
@@ -449,16 +448,17 @@ def head(self):
449448
self.out_file.write(
450449
', style="%s", color="black"'
451450
% ", ".join(rounded_filled))
452-
if self.rounded:
453-
self.out_file.write(', fontname=helvetica')
451+
452+
self.out_file.write(', fontname="%s"' % self.fontname)
454453
self.out_file.write('] ;\n')
455454

456455
# Specify graph & edge aesthetics
457456
if self.leaves_parallel:
458457
self.out_file.write(
459458
'graph [ranksep=equally, splines=polyline] ;\n')
460-
if self.rounded:
461-
self.out_file.write('edge [fontname=helvetica] ;\n')
459+
460+
self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname)
461+
462462
if self.rotate:
463463
self.out_file.write('rankdir=LR ;\n')
464464

@@ -667,7 +667,8 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
667667
feature_names=None, class_names=None, label='all',
668668
filled=False, leaves_parallel=False, impurity=True,
669669
node_ids=False, proportion=False, rotate=False,
670-
rounded=False, special_characters=False, precision=3):
670+
rounded=False, special_characters=False, precision=3,
671+
fontname='helvetica'):
671672
"""Export a decision tree in DOT format.
672673
673674
This function generates a GraphViz representation of the decision tree,
@@ -734,8 +735,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
734735
When set to ``True``, orient tree left to right rather than top-down.
735736
736737
rounded : bool, default=False
737-
When set to ``True``, draw node boxes with rounded corners and use
738-
Helvetica fonts instead of Times-Roman.
738+
When set to ``True``, draw node boxes with rounded corners.
739739
740740
special_characters : bool, default=False
741741
When set to ``False``, ignore special characters for PostScript
@@ -745,6 +745,9 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
745745
Number of digits of precision for floating point in the values of
746746
impurity, threshold and value attributes of each node.
747747
748+
fontname : str, default='helvetica'
749+
Name of font used to render text.
750+
748751
Returns
749752
-------
750753
dot_data : string
@@ -784,7 +787,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
784787
filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
785788
node_ids=node_ids, proportion=proportion, rotate=rotate,
786789
rounded=rounded, special_characters=special_characters,
787-
precision=precision)
790+
precision=precision, fontname=fontname)
788791
exporter.export(decision_tree)
789792

790793
if return_string:

sklearn/tree/tests/test_export.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def test_graphviz_toy():
3333
# Test export code
3434
contents1 = export_graphviz(clf, out_file=None)
3535
contents2 = 'digraph Tree {\n' \
36-
'node [shape=box] ;\n' \
36+
'node [shape=box, fontname="helvetica"] ;\n' \
37+
'edge [fontname="helvetica"] ;\n' \
3738
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
3839
'value = [3, 3]"] ;\n' \
3940
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
@@ -50,7 +51,8 @@ def test_graphviz_toy():
5051
contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"],
5152
out_file=None)
5253
contents2 = 'digraph Tree {\n' \
53-
'node [shape=box] ;\n' \
54+
'node [shape=box, fontname="helvetica"] ;\n' \
55+
'edge [fontname="helvetica"] ;\n' \
5456
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
5557
'value = [3, 3]"] ;\n' \
5658
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \
@@ -66,7 +68,8 @@ def test_graphviz_toy():
6668
# Test with class_names
6769
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
6870
contents2 = 'digraph Tree {\n' \
69-
'node [shape=box] ;\n' \
71+
'node [shape=box, fontname="helvetica"] ;\n' \
72+
'edge [fontname="helvetica"] ;\n' \
7073
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
7174
'value = [3, 3]\\nclass = yes"] ;\n' \
7275
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \
@@ -84,11 +87,11 @@ def test_graphviz_toy():
8487
# Test plot_options
8588
contents1 = export_graphviz(clf, filled=True, impurity=False,
8689
proportion=True, special_characters=True,
87-
rounded=True, out_file=None)
90+
rounded=True, out_file=None, fontname="sans")
8891
contents2 = 'digraph Tree {\n' \
8992
'node [shape=box, style="filled, rounded", color="black", ' \
90-
'fontname=helvetica] ;\n' \
91-
'edge [fontname=helvetica] ;\n' \
93+
'fontname="sans"] ;\n' \
94+
'edge [fontname="sans"] ;\n' \
9295
'0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>' \
9396
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \
9497
'1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, ' \
@@ -107,7 +110,8 @@ def test_graphviz_toy():
107110
contents1 = export_graphviz(clf, max_depth=0,
108111
class_names=True, out_file=None)
109112
contents2 = 'digraph Tree {\n' \
110-
'node [shape=box] ;\n' \
113+
'node [shape=box, fontname="helvetica"] ;\n' \
114+
'edge [fontname="helvetica"] ;\n' \
111115
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \
112116
'value = [3, 3]\\nclass = y[0]"] ;\n' \
113117
'1 [label="(...)"] ;\n' \
@@ -122,7 +126,9 @@ def test_graphviz_toy():
122126
contents1 = export_graphviz(clf, max_depth=0, filled=True,
123127
out_file=None, node_ids=True)
124128
contents2 = 'digraph Tree {\n' \
125-
'node [shape=box, style="filled", color="black"] ;\n' \
129+
'node [shape=box, style="filled", color="black", '\
130+
'fontname="helvetica"] ;\n' \
131+
'edge [fontname="helvetica"] ;\n' \
126132
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \
127133
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \
128134
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \
@@ -143,7 +149,9 @@ def test_graphviz_toy():
143149
contents1 = export_graphviz(clf, filled=True,
144150
impurity=False, out_file=None)
145151
contents2 = 'digraph Tree {\n' \
146-
'node [shape=box, style="filled", color="black"] ;\n' \
152+
'node [shape=box, style="filled", color="black", ' \
153+
'fontname="helvetica"] ;\n' \
154+
'edge [fontname="helvetica"] ;\n' \
147155
'0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
148156
'value = [[3.0, 1.5, 0.0]\\n' \
149157
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \
@@ -174,12 +182,13 @@ def test_graphviz_toy():
174182
clf.fit(X, y)
175183

176184
contents1 = export_graphviz(clf, filled=True, leaves_parallel=True,
177-
out_file=None, rotate=True, rounded=True)
185+
out_file=None, rotate=True, rounded=True,
186+
fontname="sans")
178187
contents2 = 'digraph Tree {\n' \
179188
'node [shape=box, style="filled, rounded", color="black", ' \
180-
'fontname=helvetica] ;\n' \
189+
'fontname="sans"] ;\n' \
181190
'graph [ranksep=equally, splines=polyline] ;\n' \
182-
'edge [fontname=helvetica] ;\n' \
191+
'edge [fontname="sans"] ;\n' \
183192
'rankdir=LR ;\n' \
184193
'0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
185194
'value = 0.0", fillcolor="#f2c09c"] ;\n' \
@@ -203,7 +212,9 @@ def test_graphviz_toy():
203212

204213
contents1 = export_graphviz(clf, filled=True, out_file=None)
205214
contents2 = 'digraph Tree {\n' \
206-
'node [shape=box, style="filled", color="black"] ;\n' \
215+
'node [shape=box, style="filled", color="black", '\
216+
'fontname="helvetica"] ;\n' \
217+
'edge [fontname="helvetica"] ;\n' \
207218
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \
208219
'fillcolor="#ffffff"] ;\n' \
209220
'}'

0 commit comments

Comments
 (0)
0