diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 0372dcdd1fd4e..f056f3a9be2d3 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -44,6 +44,13 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.tree` +................... + +- |Enhancement| Add `fontname` argument in :func:`tree.export_graphviz` + for non-English characters. :pr:`18959` by :user:`Zero ` + and :user:`wstates `. + :mod:`sklearn.cluster` ...................... diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 1615c8eb15028..ff29790e3699e 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -371,18 +371,17 @@ def __init__(self, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3): + special_characters=False, precision=3, fontname='helvetica'): super().__init__( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - impurity=impurity, - node_ids=node_ids, proportion=proportion, rotate=rotate, - rounded=rounded, - precision=precision) + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision) self.leaves_parallel = leaves_parallel self.out_file = out_file self.special_characters = special_characters + self.fontname = fontname # PostScript compatibility for special characters if special_characters: @@ -449,16 +448,17 @@ def head(self): self.out_file.write( ', style="%s", color="black"' % ", ".join(rounded_filled)) - if self.rounded: - self.out_file.write(', fontname=helvetica') + + self.out_file.write(', fontname="%s"' % self.fontname) self.out_file.write('] ;\n') # Specify graph & edge aesthetics if self.leaves_parallel: self.out_file.write( 'graph [ranksep=equally, splines=polyline] ;\n') - if self.rounded: - self.out_file.write('edge [fontname=helvetica] ;\n') + + self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname) + if self.rotate: self.out_file.write('rankdir=LR ;\n') @@ -667,7 +667,8 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, - rounded=False, special_characters=False, precision=3): + rounded=False, special_characters=False, precision=3, + fontname='helvetica'): """Export a decision tree in DOT format. This function generates a GraphViz representation of the decision tree, @@ -734,8 +735,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, When set to ``True``, orient tree left to right rather than top-down. rounded : bool, default=False - When set to ``True``, draw node boxes with rounded corners and use - Helvetica fonts instead of Times-Roman. + When set to ``True``, draw node boxes with rounded corners. special_characters : bool, default=False When set to ``False``, ignore special characters for PostScript @@ -745,6 +745,9 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node. + fontname : str, default='helvetica' + Name of font used to render text. + Returns ------- dot_data : string @@ -784,7 +787,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, special_characters=special_characters, - precision=precision) + precision=precision, fontname=fontname) exporter.export(decision_tree) if return_string: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index a1b04e171e59a..6a7bf33b2143f 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -33,7 +33,8 @@ def test_graphviz_toy(): # Test export code contents1 = export_graphviz(clf, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \ @@ -50,7 +51,8 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"], out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \ @@ -66,7 +68,8 @@ def test_graphviz_toy(): # Test with class_names contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]\\nclass = yes"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \ @@ -84,11 +87,11 @@ def test_graphviz_toy(): # Test plot_options contents1 = export_graphviz(clf, filled=True, impurity=False, proportion=True, special_characters=True, - rounded=True, out_file=None) + rounded=True, out_file=None, fontname="sans") contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname=helvetica] ;\n' \ - 'edge [fontname=helvetica] ;\n' \ + 'fontname="sans"] ;\n' \ + 'edge [fontname="sans"] ;\n' \ '0 [label=0 ≤ 0.0
samples = 100.0%
' \ 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \ '1 [label=value = [1.0, 0.0]>, ' \ @@ -107,7 +110,8 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]\\nclass = y[0]"] ;\n' \ '1 [label="(...)"] ;\n' \ @@ -122,7 +126,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, max_depth=0, filled=True, out_file=None, node_ids=True) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", '\ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \ 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \ '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ @@ -143,7 +149,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", ' \ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\nsamples = 6\\n' \ 'value = [[3.0, 1.5, 0.0]\\n' \ '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \ @@ -174,12 +182,13 @@ def test_graphviz_toy(): clf.fit(X, y) contents1 = export_graphviz(clf, filled=True, leaves_parallel=True, - out_file=None, rotate=True, rounded=True) + out_file=None, rotate=True, rounded=True, + fontname="sans") contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname=helvetica] ;\n' \ + 'fontname="sans"] ;\n' \ 'graph [ranksep=equally, splines=polyline] ;\n' \ - 'edge [fontname=helvetica] ;\n' \ + 'edge [fontname="sans"] ;\n' \ 'rankdir=LR ;\n' \ '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \ 'value = 0.0", fillcolor="#f2c09c"] ;\n' \ @@ -203,7 +212,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, filled=True, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", '\ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \ 'fillcolor="#ffffff"] ;\n' \ '}'