8000 ENH Adds HTML visualizations for estimators (#14180) · scikit-learn/scikit-learn@ee2508c · GitHub
[go: up one dir, main page]

Skip to content

Commit ee2508c

Browse files
authored
ENH Adds HTML visualizations for estimators (#14180)
1 parent b9403f6 commit ee2508c

File tree

15 files changed

+732
-5
lines changed

15 files changed

+732
-5
lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,7 @@ Plotting
15691569
utils.deprecated
15701570
utils.estimator_checks.check_estimator
15711571
utils.estimator_checks.parametrize_with_checks
1572+
utils.estimator_html_repr
15721573
utils.extmath.safe_sparse_dot
15731574
utils.extmath.randomized_range_finder
15741575
utils.extmath.randomized_svd

doc/modules/compose.rst

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,31 @@ above example would be::
528528
('countvectorizer', CountVectorizer(),
529529
'title')])
530530

531+
.. _visualizing_composite_estimators:
532+
533+
Visualizing Composite Estimators
534+
================================
535+
536+
Estimators can be displayed with a HTML representation when shown in a
537+
jupyter notebook. This can be useful to diagnose or visualize a Pipeline with
538+
many estimators. This visualization is activated by setting the
539+
`display` option in :func:`sklearn.set_config`::
540+
541+
>>> from sklearn import set_config
542+
>>> set_config(display='diagram') # doctest: +SKIP
543+
>>> # diplays HTML representation in a jupyter context
544+
>>> column_trans # doctest: +SKIP
545+
546+
An example of the HTML output can be seen in the
547+
**HTML representation of Pipeline** section of
548+
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
549+
As an alternative, the HTML can be written to a file using
550+
:func:`~sklearn.utils.estimator_html_repr`::
551+
552+
>>> from sklearn.utils import estimator_html_repr
553+
>>> with open('my_estimator.html', 'w') as f: # doctest: +SKIP
554+
... f.write(estimator_html_repr(clf))
555+
531556
.. topic:: Examples:
532557

533558
* :ref:`sphx_glr_auto_examples_compose_plot_column_transformer.py`

doc/whats_new/v0.23.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ Changelog
567567
:mod:`sklearn.utils`
568568
....................
569569

570+
- |Feature| Adds :func:`utils.estimator_html_repr` for returning a
571+
HTML representation of an estimator. :pr:`14180` by `Thomas Fan`_.
572+
570573
- |Enhancement| improve error message in :func:`utils.validation.column_or_1d`.
571574
:pr:`15926` by :user:`Loïc Estève <lesteve>`.
572575

@@ -605,6 +608,11 @@ Changelog
605608
Miscellaneous
606609
.............
607610

611+
- |MajorFeature| Adds a HTML representation of estimators to be shown in
612+
a jupyter notebook or lab. This visualization is acitivated by setting the
613+
`display` option in :func:`sklearn.set_config`. :pr:`14180` by
614+
`Thomas Fan`_.
615+
608616
- |Enhancement| ``scikit-learn`` now works with ``mypy`` without errors.
609617
:pr:`16726` by `Roman Yurchak`_.
610618

examples/compose/plot_column_transformer_mixed_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@
8787
clf.fit(X_train, y_train)
8888
print("model score: %.3f" % clf.score(X_test, y_test))
8989

90+
##############################################################################
91+
# HTML representation of ``Pipeline``
92+
###############################################################################
93+
# When the ``Pipeline`` is printed out in a jupyter notebook an HTML
94+
# representation of the estimator is displayed as follows:
95+
from sklearn import set_config
96+
set_config(display='diagram')
97+
clf
98+
9099
###############################################################################
91100
# Use ``ColumnTransformer`` by selecting column by data types
92101
###############################################################################

sklearn/_config.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)),
88
'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)),
99
'print_changed_only': True,
10+
'display': 'text',
1011
}
1112

1213

@@ -27,7 +28,7 @@ def get_config():
2728

2829

2930
def set_config(assume_finite=None, working_memory=None,
30-
print_changed_only=None):
31+
print_changed_only=None, display=None):
3132
"""Set global scikit-learn configuration
3233
3334
.. versionadded:: 0.19
@@ -59,6 +60,13 @@ def set_config(assume_finite=None, working_memory=None,
5960
6061
.. versionadded:: 0.21
6162
63+
display : {'text', 'diagram'}, optional
64+
If 'diagram', estimators will be displayed as text in a jupyter lab
65+
of notebook context. If 'text', estimators will be displayed as
66+
text. Default is 'text'.
67+
68+
.. versionadded:: 0.23
69+
6270
See Also
6371
--------
6472
config_context: Context manager for global scikit-learn configuration
@@ -70,6 +78,8 @@ def set_config(assume_finite=None, working_memory=None,
7078
_global_config['working_memory'] = working_memory
7179
if print_changed_only is not None:
7280
_global_config['print_changed_only'] = print_changed_only
81+
if display is not None:
82+
_global_config['display'] = display
7383

7484

7585
@contextmanager
@@ -100,6 +110,13 @@ def config_context(**new_config):
100110
.. versionchanged:: 0.23
101111
Default changed from False to True.
102112
113+
display : {'text', 'diagram'}, optional
114+
If 'diagram', estimators will be displayed as text in a jupyter lab
115+
of notebook context. If 'text', estimators will be displayed as
116+
text. Default is 'text'.
117+
118+
.. versionadded:: 0.23
119+
103120
Notes
104121
-----
105122
All settings, not just those presently modified, will be returned to

sklearn/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import numpy as np
1818

1919
from . import __version__
20+
from ._config import get_config
2021
from .utils import _IS_32BIT
2122
from .utils.validation import check_X_y
2223
from .utils.validation import check_array
24+
from .utils._estimator_html_repr import estimator_html_repr
2325
from .utils.validation import _deprecate_positional_args
2426

2527
_DEFAULT_TAGS = {
@@ -435,6 +437,17 @@ def _validate_data(self, X, y=None, reset=True,
435437

436438
return out
437439

440+
def _repr_html_(self):
441+
"""HTML representation of estimator"""
442+
return estimator_html_repr(self)
443+
444+
def _repr_mimebundle_(self, **kwargs):
445+
"""Mime bundle used by jupyter kernels to display estimator"""
446+
output = {"text/plain": repr(self)}
447+
if get_config()["display"] == 'diagram':
448+
output["text/html"] = estimator_html_repr(self)
449+
return output
450+
438451

439452
class ClassifierMixin:
440453
"""Mixin class for all classifiers in scikit-learn."""

sklearn/compose/_column_transformer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from joblib import Parallel, delayed
1616

1717
from ..base import clone, TransformerMixin
18+
from ..utils._estimator_html_repr import _VisualBlock
1819
from ..pipeline import _fit_transform_one, _transform_one, _name_estimators
1920
from ..preprocessing import FunctionTransformer
2021
from ..utils import Bunch
@@ -637,6 +638,11 @@ def _hstack(self, Xs):
637638
Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs]
638639
return np.hstack(Xs)
639640

641+
def _sk_visual_block_(self):
642+
names, transformers, name_details = zip(*self.transformers)
643+
return _VisualBlock('parallel', transformers,
644+
names=names, name_details=name_details)
645+
640646

641647
def _check_X(X):
642648
"""Use check_array only on lists and other non-array-likes / sparse"""

sklearn/ensemble/_stacking.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..base import clone
1414
from ..base import ClassifierMixin, RegressorMixin, TransformerMixin
1515
from ..base import is_classifier, is_regressor
16+
from ..utils._estimator_html_repr import _VisualBlock
1617

1718
from ._base import _fit_single_estimator
1819
from ._base import _BaseHeterogeneousEnsemble
@@ -233,6 +234,14 @@ def predict(self, X, **predict_params):
233234
self.transform(X), **predict_params
234235
)
235236

237+
def _sk_visual_block_(self, final_estimator):
238+
names, estimators = zip(*self.estimators)
239+
parallel = _VisualBlock('parallel', estimators, names=names,
240+
dash_wrapped=False)
241+
serial = _VisualBlock('serial', (parallel, final_estimator),
242+
dash_wrapped=False)
243+
return _VisualBlock('serial', [serial])
244+
236245

237246
class StackingClassifier(ClassifierMixin, _BaseStacking):
238247
"""Stack of estimators with a final classifier.
@@ -496,6 +505,15 @@ def transform(self, X):
496505
"""
497506
return self._transform(X)
498507

508+
def _sk_visual_block_(self):
509+
# If final_estimator's default changes then this should be
510+
# updated.
511+
if self.final_estimator is None:
512+
final_estimator = LogisticRegression()
513+
else:
514+
final_estimator = self.final_estimator
515+
return super()._sk_visual_block_(final_estimator)
516+
499517

500518
class StackingRegressor(RegressorMixin, _BaseStacking):
501519
"""Stack of estimators with a final regressor.
@@ -665,3 +683,12 @@ def transform(self, X):
665683
Prediction outputs for each estimator.
666684
"""
667685
return self._transform(X)
686+
687+
def _sk_visual_block_(self):
688+
# If final_estimator's default changes then this should be
689+
# updated.
690+
if self.final_estimator is None:
691+
final_estimator = RidgeCV()
692+
else:
693+
final_estimator = self.final_estimator
694+
return super()._sk_visual_block_(final_estimator)

sklearn/ensemble/_voting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..utils.validation import column_or_1d
3333
from ..utils.validation import _deprecate_positional_args
3434
from ..exceptions import NotFittedError
35+
from ..utils._estimator_html_repr import _VisualBlock
3536

3637

3738
class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):
@@ -104,6 +105,10 @@ def n_features_in_(self):
104105

105106
return self.estimators_[0].n_features_in_
106107

108+
def _sk_visual_block_(self):
109+
names, estimators = zip(*self.estimators)
110+
return _VisualBlock('parallel', estimators, names=names)
111+
107112

108113
class VotingClassifier(ClassifierMixin, _BaseVoting):
109114
"""Soft Voting/Majority Rule classifier for unfitted estimators.

sklearn/pipeline.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from joblib import Parallel, delayed
1919

2020
from .base import clone, TransformerMixin
21+
from .utils._estimator_html_repr import _VisualBlock
2122
from .utils.metaestimators import if_delegate_has_method
2223
from .utils import Bunch, _print_elapsed_time
2324
from .utils.validation import check_memory
@@ -623,6 +624,21 @@ def n_features_in_(self):
623624
# delegate to first step (which will call _check_is_fitted)
624625
return self.steps[0][1].n_features_in_
625626

627+
def _sk_visual_block_(self):
628+
_, estimators = zip(*self.steps)
629+
630+
def _get_name(name, est):
631+
if est is None or est == 'passthrough':
632+
return f'{name}: passthrough'
633+
# Is an estimator
634+
return f'{name}: {est.__class__.__name__}'
635+
names = [_get_name(name, est) for name, est in self.steps]
636+
name_details = [str(est) for est in estimators]
637+
return _VisualBlock('serial', estimators,
638+
names=names,
639+
name_details=name_details,
640+
dash_wrapped=False)
641+
626642

627643
def _name_estimators(estimators):
628644
"""Generate names for estimators."""
@@ -1004,6 +1020,10 @@ def n_features_in_(self):
10041020
# X is passed to all transformers so we just delegate to the first one
10051021
return self.transformer_list[0][1].n_features_in_
10061022

1023+
def _sk_visual_block_(self):
1024+
names, transformers = zip(*self.transformer_list)
1025+
return _VisualBlock('parallel', transformers, names=names)
1026+
10071027

10081028
def make_union(*transformers, **kwargs):
10091029
"""

0 commit comments

Comments
 (0)
0