From 448a3d5fb2f2f1a4502a75bff74b7da9192cd866 Mon Sep 17 00:00:00 2001 From: Zero Date: Thu, 17 Dec 2020 10:35:34 +0800 Subject: [PATCH 01/15] FIX add fontname argument in plot_tree for non-English characters --- sklearn/tree/_export.py | 31 ++++++++++++++++++------------- sklearn/tree/tests/test_export.py | 8 ++++---- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 9ef4faeb0f56f..78ed9cd4eb717 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -197,9 +197,9 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, class _BaseTreeExporter: def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - impurity=True, node_ids=False, - proportion=False, rotate=False, rounded=False, - precision=3, fontsize=None): + impurity=True, node_ids=False, proportion=False, + rotate=False, rounded=False, precision=3, + fontsize=None, fontname='helvetica'): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names @@ -212,6 +212,7 @@ def __init__(self, max_depth=None, feature_names=None, self.rounded = rounded self.precision = precision self.fontsize = fontsize + self.fontname = fontname def get_color(self, value): # Find the appropriate color & intensity for a node @@ -371,15 +372,15 @@ def __init__(self, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - special_characters=False, precision=3): + special_characters=False, precision=3, fontname='helvetica'): super().__init__( max_depth=max_depth, feature_names=feature_names, - class_names=class_names, label=label, filled=filled, - impurity=impurity, - node_ids=node_ids, proportion=proportion, rotate=rotate, - rounded=rounded, - precision=precision) + class_names=class_names, label=label, + filled=filled, impurity=impurity, + node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, + precision=precision, fontname=fontname) self.leaves_parallel = leaves_parallel self.out_file = out_file self.special_characters = special_characters @@ -450,7 +451,7 @@ def head(self): ', style="%s", color="black"' % ", ".join(rounded_filled)) if self.rounded: - self.out_file.write(', fontname=helvetica') + self.out_file.write(', fontname="%s"' % self.fontname) self.out_file.write('] ;\n') # Specify graph & edge aesthetics @@ -458,7 +459,7 @@ def head(self): self.out_file.write( 'graph [ranksep=equally, splines=polyline] ;\n') if self.rounded: - self.out_file.write('edge [fontname=helvetica] ;\n') + self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname) if self.rotate: self.out_file.write('rankdir=LR ;\n') @@ -667,7 +668,8 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, - rounded=False, special_characters=False, precision=3): + rounded=False, special_characters=False, precision=3, + fontname='helvetica'): """Export a decision tree in DOT format. This function generates a GraphViz representation of the decision tree, @@ -745,6 +747,9 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node. + fontname : str, default='helvetica' + Name of font used to render text. + Returns ------- dot_data : string @@ -784,7 +789,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, filled=filled, leaves_parallel=leaves_parallel, impurity=impurity, node_ids=node_ids, proportion=proportion, rotate=rotate, rounded=rounded, special_characters=special_characters, - precision=precision) + precision=precision, fontname=fontname) exporter.export(decision_tree) if return_string: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index f12f1daeb57c1..bc4ab4619dc5c 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -87,8 +87,8 @@ def test_graphviz_toy(): rounded=True, out_file=None) contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname=helvetica] ;\n' \ - 'edge [fontname=helvetica] ;\n' \ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label=0 ≤ 0.0
samples = 100.0%
' \ 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \ '1 [label=value = [1.0, 0.0]>, ' \ @@ -177,9 +177,9 @@ def test_graphviz_toy(): out_file=None, rotate=True, rounded=True) contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname=helvetica] ;\n' \ + 'fontname="helvetica"] ;\n' \ 'graph [ranksep=equally, splines=polyline] ;\n' \ - 'edge [fontname=helvetica] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ 'rankdir=LR ;\n' \ '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \ 'value = 0.0", fillcolor="#f2c09c"] ;\n' \ From 2cd03cb074271f4bf6ed7854518c177c19dcb0d8 Mon Sep 17 00:00:00 2001 From: Zero Date: Thu, 17 Dec 2020 10:56:31 +0800 Subject: [PATCH 02/15] Update v0.24.rst --- doc/whats_new/v0.24.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index cd65b584010c1..ac47746e4f395 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -785,6 +785,10 @@ Changelog - |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib configuration settings. :pr:`17187` by `Andreas Müller`_. +- |Fix| :func:`tree.export_graphviz` + Add fontname argument in plot_tree for non-English characters. + :pr:`18959` by :user:`Zero ` and :user:`wstates `. + - |API| The parameter ``X_idx_sorted`` is now deprecated in :meth:`tree.DecisionTreeClassifier.fit` and :meth:`tree.DecisionTreeRegressor.fit`, and has not effect. From bb6e6030a833510a81049405dab2cbbc0bfe66f2 Mon Sep 17 00:00:00 2001 From: Zero Date: Thu, 17 Dec 2020 17:44:49 +0800 Subject: [PATCH 03/15] Update v0.24.rst --- doc/whats_new/v0.24.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index ac47746e4f395..21bf6491aa3ca 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -785,9 +785,9 @@ Changelog - |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib configuration settings. :pr:`17187` by `Andreas Müller`_. -- |Fix| :func:`tree.export_graphviz` - Add fontname argument in plot_tree for non-English characters. - :pr:`18959` by :user:`Zero ` and :user:`wstates `. +- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` for + non-English characters. :pr:`18959` by :user:`Zero ` and + :user:`wstates `. - |API| The parameter ``X_idx_sorted`` is now deprecated in :meth:`tree.DecisionTreeClassifier.fit` and From f4741ecdaaab539a517b04a71700640e0e118602 Mon Sep 17 00:00:00 2001 From: Zero Date: Fri, 18 Dec 2020 08:58:05 +0800 Subject: [PATCH 04/15] Update v1.0.rst --- doc/whats_new/v1.0.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 211f1e4049d65..d4c8e2958aacb 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -44,7 +44,12 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.tree` +................... +- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` for + non-English characters. :pr:`18959` by :user:`Zero ` and + :user:`wstates `. Code and Documentation Contributors ----------------------------------- From 8aaede39ef3a3f15ff156d1b352efe037a5eadd5 Mon Sep 17 00:00:00 2001 From: Zero Date: Fri, 18 Dec 2020 10:19:24 +0800 Subject: [PATCH 05/15] FIX add fontname argument in plot_tree for non-English characters --- sklearn/tree/_export.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 78ed9cd4eb717..a404875780bc9 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -81,9 +81,9 @@ def __repr__(self): @_deprecate_positional_args def plot_tree(decision_tree, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - impurity=True, node_ids=False, - proportion=False, rotate='deprecated', rounded=False, - precision=3, ax=None, fontsize=None): + impurity=True, node_ids=False, proportion=False, + rotate='deprecated', rounded=False, precision=3, + ax=None, fontsize=None, fontname='helvetica'): """Plot a decision tree. The sample counts that are shown are weighted with any sample_weights that @@ -158,6 +158,9 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, fontsize : int, default=None Size of text font. If None, determined automatically to fit figure. + fontname : str, default='helvetica' + Name of font used to render text. + Returns ------- annotations : list of artists @@ -188,9 +191,9 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, exporter = _MPLTreeExporter( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - impurity=impurity, node_ids=node_ids, - proportion=proportion, rotate=rotate, rounded=rounded, - precision=precision, fontsize=fontsize) + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision, + fontsize=fontsize, fontname=fontname) return exporter.export(decision_tree, ax=ax) @@ -530,13 +533,15 @@ def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - precision=3, fontsize=None): + precision=3, fontsize=None, fontname='helvetica'): super().__init__( max_depth=max_depth, feature_names=feature_names, - class_names=class_names, label=label, filled=filled, - impurity=impurity, node_ids=node_ids, proportion=proportion, - rotate=rotate, rounded=rounded, precision=precision) + class_names=class_names, label=label, + filled=filled, impurity=impurity, + node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, + precision=precision, fontname=fontname) self.fontsize = fontsize # validate From 01a2ef7545615c3df1b1c7bb6397c6e337cd7f0a Mon Sep 17 00:00:00 2001 From: Zero Date: Fri, 18 Dec 2020 10:50:41 +0800 Subject: [PATCH 06/15] Add changelog for #18959 --- doc/whats_new/v0.24.rst | 6 +++--- doc/whats_new/v1.0.rst | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 8b436e48168f9..761baac1ce67f 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -792,9 +792,9 @@ Changelog - |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib configuration settings. :pr:`17187` by `Andreas Müller`_. -- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` for - non-English characters. :pr:`18959` by :user:`Zero ` and - :user:`wstates `. +- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` and + :func:`tree.plot_tree` for non-English characters. + :pr:`18959` by :user:`Zero ` and :user:`wstates `. - |API| The parameter ``X_idx_sorted`` is now deprecated in :meth:`tree.DecisionTreeClassifier.fit` and diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d4c8e2958aacb..ed46258a44324 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -47,9 +47,9 @@ Changelog :mod:`sklearn.tree` ................... -- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` for - non-English characters. :pr:`18959` by :user:`Zero ` and - :user:`wstates `. +- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` and + :func:`tree.plot_tree` for non-English characters. + :pr:`18959` by :user:`Zero ` and :user:`wstates `. Code and Documentation Contributors ----------------------------------- From f3fa121d87870683cc0088154fedddff3f6b5d53 Mon Sep 17 00:00:00 2001 From: Zero Date: Fri, 18 Dec 2020 19:46:57 +0800 Subject: [PATCH 07/15] Update v0.24.rst --- doc/whats_new/v0.24.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 761baac1ce67f..a5b0ec36d62aa 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -792,10 +792,6 @@ Changelog - |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib configuration settings. :pr:`17187` by `Andreas Müller`_. -- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` and - :func:`tree.plot_tree` for non-English characters. - :pr:`18959` by :user:`Zero ` and :user:`wstates `. - - |API| The parameter ``X_idx_sorted`` is now deprecated in :meth:`tree.DecisionTreeClassifier.fit` and :meth:`tree.DecisionTreeRegressor.fit`, and has not effect. From 02ba737fb517f81b7899b981a6ad53851a66d9a1 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 23 Dec 2020 09:39:55 +0800 Subject: [PATCH 08/15] overwrite the default fontname to test the export_graphviz function --- sklearn/tree/tests/test_export.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index bc4ab4619dc5c..4e638ddf22f49 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -84,11 +84,11 @@ def test_graphviz_toy(): # Test plot_options contents1 = export_graphviz(clf, filled=True, impurity=False, proportion=True, special_characters=True, - rounded=True, out_file=None) + rounded=True, out_file=None, fontname="sans") contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname="helvetica"] ;\n' \ - 'edge [fontname="helvetica"] ;\n' \ + 'fontname="sans"] ;\n' \ + 'edge [fontname="sans"] ;\n' \ '0 [label=0 ≤ 0.0
samples = 100.0%
' \ 'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n' \ '1 [label=value = [1.0, 0.0]>, ' \ @@ -174,12 +174,13 @@ def test_graphviz_toy(): clf.fit(X, y) contents1 = export_graphviz(clf, filled=True, leaves_parallel=True, - out_file=None, rotate=True, rounded=True) + out_file=None, rotate=True, rounded=True, + fontname="sans") contents2 = 'digraph Tree {\n' \ 'node [shape=box, style="filled, rounded", color="black", ' \ - 'fontname="helvetica"] ;\n' \ + 'fontname="sans"] ;\n' \ 'graph [ranksep=equally, splines=polyline] ;\n' \ - 'edge [fontname="helvetica"] ;\n' \ + 'edge [fontname="sans"] ;\n' \ 'rankdir=LR ;\n' \ '0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \ 'value = 0.0", fillcolor="#f2c09c"] ;\n' \ From 384e63fe2d963a0e3433865f0716c179b1311f28 Mon Sep 17 00:00:00 2001 From: Zero Date: Tue, 5 Jan 2021 09:35:29 +0800 Subject: [PATCH 09/15] Remove fontname for tree.plot_tree --- doc/whats_new/v1.0.rst | 6 +++--- sklearn/tree/_export.py | 25 ++++++++++--------------- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index dd5c95b1662cc..71f41d59c6f0b 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -47,9 +47,9 @@ Changelog :mod:`sklearn.tree` ................... -- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` and - :func:`tree.plot_tree` for non-English characters. - :pr:`18959` by :user:`Zero ` and :user:`wstates `. +- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` + for non-English characters. :pr:`18959` by :user:`Zero ` + and :user:`wstates `. :mod:`sklearn.cluster` ...................... diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index aae8ce6e1842c..96551c46ed6be 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -81,9 +81,9 @@ def __repr__(self): @_deprecate_positional_args def plot_tree(decision_tree, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - impurity=True, node_ids=False, proportion=False, - rotate='deprecated', rounded=False, precision=3, - ax=None, fontsize=None, fontname='helvetica'): + impurity=True, node_ids=False, + proportion=False, rotate='deprecated', rounded=False, + precision=3, ax=None, fontsize=None): """Plot a decision tree. The sample counts that are shown are weighted with any sample_weights that @@ -158,9 +158,6 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, fontsize : int, default=None Size of text font. If None, determined automatically to fit figure. - fontname : str, default='helvetica' - Name of font used to render text. - Returns ------- annotations : list of artists @@ -191,9 +188,9 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, exporter = _MPLTreeExporter( max_depth=max_depth, feature_names=feature_names, class_names=class_names, label=label, filled=filled, - impurity=impurity, node_ids=node_ids, proportion=proportion, - rotate=rotate, rounded=rounded, precision=precision, - fontsize=fontsize, fontname=fontname) + impurity=impurity, node_ids=node_ids, + proportion=proportion, rotate=rotate, rounded=rounded, + precision=precision, fontsize=fontsize) return exporter.export(decision_tree, ax=ax) @@ -533,15 +530,13 @@ def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, - precision=3, fontsize=None, fontname='helvetica'): + precision=3, fontsize=None): super().__init__( max_depth=max_depth, feature_names=feature_names, - class_names=class_names, label=label, - filled=filled, impurity=impurity, - node_ids=node_ids, proportion=proportion, - rotate=rotate, rounded=rounded, - precision=precision, fontname=fontname) + class_names=class_names, label=label, filled=filled, + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision) self.fontsize = fontsize # validate From 74bd42a056e355a0da2e48372c86437d695f2f47 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 6 Jan 2021 09:11:28 +0800 Subject: [PATCH 10/15] Doc for tree.export_graphviz --- sklearn/tree/_export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 96551c46ed6be..ae7cbc269e65a 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -736,8 +736,8 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, When set to ``True``, orient tree left to right rather than top-down. rounded : bool, default=False - When set to ``True``, draw node boxes with rounded corners and use - Helvetica fonts instead of Times-Roman. + When set to ``True``, draw node boxes with rounded corners and + could use fonts via `fontname` instead of Times-Roman. special_characters : bool, default=False When set to ``False``, ignore special characters for PostScript @@ -748,7 +748,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, impurity, threshold and value attributes of each node. fontname : str, default='helvetica' - Name of font used to render text. + Name of font used to render text. Only used when `rounded=True` Returns ------- From b18c0fcc5b8ba8cf35f591f557850c940442e4cd Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 6 Jan 2021 09:24:33 +0800 Subject: [PATCH 11/15] Reduce fontname usage range, only in _DOTTreeExporter --- sklearn/tree/_export.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index ae7cbc269e65a..2bfde71320583 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -197,9 +197,9 @@ def plot_tree(decision_tree, *, max_depth=None, feature_names=None, class _BaseTreeExporter: def __init__(self, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, - impurity=True, node_ids=False, proportion=False, - rotate=False, rounded=False, precision=3, - fontsize=None, fontname='helvetica'): + impurity=True, node_ids=False, + proportion=False, rotate=False, rounded=False, + precision=3, fontsize=None): self.max_depth = max_depth self.feature_names = feature_names self.class_names = class_names @@ -212,7 +212,6 @@ def __init__(self, max_depth=None, feature_names=None, self.rounded = rounded self.precision = precision self.fontsize = fontsize - self.fontname = fontname def get_color(self, value): # Find the appropriate color & intensity for a node @@ -376,14 +375,13 @@ def __init__(self, out_file=SENTINEL, max_depth=None, super().__init__( max_depth=max_depth, feature_names=feature_names, - class_names=class_names, label=label, - filled=filled, impurity=impurity, - node_ids=node_ids, proportion=proportion, - rotate=rotate, rounded=rounded, - precision=precision, fontname=fontname) + class_names=class_names, label=label, filled=filled, + impurity=impurity, node_ids=node_ids, proportion=proportion, + rotate=rotate, rounded=rounded, precision=precision) self.leaves_parallel = leaves_parallel self.out_file = out_file self.special_characters = special_characters + self.fontname = fontname # PostScript compatibility for special characters if special_characters: From 3681b683edbaf25bf764c62ffa783e8980e2b86e Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 6 Jan 2021 09:52:25 +0800 Subject: [PATCH 12/15] Fix fontname only used when rounded=True in tree.export_graphviz --- sklearn/tree/_export.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/tree/_export.py b/sklearn/tree/_export.py index 2bfde71320583..ff29790e3699e 100644 --- a/sklearn/tree/_export.py +++ b/sklearn/tree/_export.py @@ -448,16 +448,17 @@ def head(self): self.out_file.write( ', style="%s", color="black"' % ", ".join(rounded_filled)) - if self.rounded: - self.out_file.write(', fontname="%s"' % self.fontname) + + self.out_file.write(', fontname="%s"' % self.fontname) self.out_file.write('] ;\n') # Specify graph & edge aesthetics if self.leaves_parallel: self.out_file.write( 'graph [ranksep=equally, splines=polyline] ;\n') - if self.rounded: - self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname) + + self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname) + if self.rotate: self.out_file.write('rankdir=LR ;\n') @@ -734,8 +735,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, When set to ``True``, orient tree left to right rather than top-down. rounded : bool, default=False - When set to ``True``, draw node boxes with rounded corners and - could use fonts via `fontname` instead of Times-Roman. + When set to ``True``, draw node boxes with rounded corners. special_characters : bool, default=False When set to ``False``, ignore special characters for PostScript @@ -746,7 +746,7 @@ def export_graphviz(decision_tree, out_file=None, *, max_depth=None, impurity, threshold and value attributes of each node. fontname : str, default='helvetica' - Name of font used to render text. Only used when `rounded=True` + Name of font used to render text. Returns ------- From c115dcf26fd24142c72fa74a4b4d766e361cefe0 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 6 Jan 2021 10:47:21 +0800 Subject: [PATCH 13/15] Update test_graphviz_toy for fontname --- sklearn/tree/tests/test_export.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index e669e849c38f3..6a1c3bf506495 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -33,7 +33,8 @@ def test_graphviz_toy(): # Test export code contents1 = export_graphviz(clf, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \ @@ -50,7 +51,8 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, feature_names=["feature0", "feature1"], out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n' \ @@ -66,7 +68,8 @@ def test_graphviz_toy(): # Test with class_names contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]\\nclass = yes"] ;\n' \ '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n' \ @@ -107,7 +110,8 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box] ;\n' \ + 'node [shape=box, fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n' \ 'value = [3, 3]\\nclass = y[0]"] ;\n' \ '1 [label="(...)"] ;\n' \ @@ -122,7 +126,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, max_depth=0, filled=True, out_file=None, node_ids=True) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", '\ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n' \ 'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n' \ '1 [label="(...)", fillcolor="#C0C0C0"] ;\n' \ @@ -204,7 +210,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, filled=True, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", '\ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", ' \ 'fillcolor="#ffffff"] ;\n' \ '}' From de4063f391e0ccfef1d25a482a9d133d258cd858 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 6 Jan 2021 11:09:04 +0800 Subject: [PATCH 14/15] Update test_graphviz_toy for fontname --- sklearn/tree/tests/test_export.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 6a1c3bf506495..6a7bf33b2143f 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -149,7 +149,9 @@ def test_graphviz_toy(): contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None) contents2 = 'digraph Tree {\n' \ - 'node [shape=box, style="filled", color="black"] ;\n' \ + 'node [shape=box, style="filled", color="black", ' \ + 'fontname="helvetica"] ;\n' \ + 'edge [fontname="helvetica"] ;\n' \ '0 [label="X[0] <= 0.0\\nsamples = 6\\n' \ 'value = [[3.0, 1.5, 0.0]\\n' \ '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n' \ From e6954054afdc974fb0f5e13915e12444531d16c7 Mon Sep 17 00:00:00 2001 From: Zero Date: Thu, 7 Jan 2021 09:04:53 +0800 Subject: [PATCH 15/15] ENH add fontname in export_graphviz for non-English characters --- doc/whats_new/v1.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 71f41d59c6f0b..f056f3a9be2d3 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -47,7 +47,7 @@ Changelog :mod:`sklearn.tree` ................... -- |Fix| Add `fontname` argument in :func:`tree.export_graphviz` +- |Enhancement| Add `fontname` argument in :func:`tree.export_graphviz` for non-English characters. :pr:`18959` by :user:`Zero ` and :user:`wstates `.