ENH Add verbose to classifier regressor chains (#23977) · kasmith11/scikit-learn@be00837 · GitHub
[go: up one dir, main page]

Skip to content

Commit be00837

Browse files
lucyleeow, efiegel, glemaitre, cmarmo, jjerphan
authored
ENH Add verbose to classifier regressor chains (scikit-learn#23977)
Co-authored-by: ericfiegel <efiegel01@gmail.com> Co-authored-by: Eric Fiegel <fiegel@usc.edu> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent ba496d8 commit be00837

File tree

3 files changed: +89 -5 lines changed

3 files changed: +89 -5 lines changed

doc/whats_new/v1.2.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,16 @@ Changelog
220220
- |Fix| Fixed error message of :class:`metrics.coverage_error` for 1D array input.
221221
:pr:`23548` by :user:`Hao Chun Chang <haochunchang>`.
222222

223+
:mod:`sklearn.multioutput`
224+
..........................
225+
226+
- |Feature| Added boolean `verbose` flag to classes:
227+
:class:`multioutput.ClassifierChain` and :class:`multioutput.RegressorChain`.
228+
:pr:`23977` by :user:`Eric Fiegel <efiegel>`,
229+
:user:`Chiara Marmo <cmarmo>`,
230+
:user:`Lucy Liu <lucyleeow>`, and
231+
:user:`Guillaume Lemaitre <glemaitre>`.
232+
223233
:mod:`sklearn.naive_bayes`
224234
..........................
225235

sklearn/multioutput.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323
from .base import BaseEstimator, clone, MetaEstimatorMixin
2424
from .base import RegressorMixin, ClassifierMixin, is_classifier
2525
from .model_selection import cross_val_predict
26+
from .utils import check_random_state, _print_elapsed_time
2627
from .utils.metaestimators import available_if
27-
from .utils import check_random_state
28-
from .utils.validation import check_is_fitted, has_fit_parameter, _check_fit_params
2928
from .utils.multiclass import check_classification_targets
29+
from .utils.validation import (
30+
check_is_fitted,
31+
has_fit_parameter,
32+
_check_fit_params,
33+
)
3034
from .utils.fixes import delayed
3135
from .utils._param_validation import HasMethods
3236

@@ -538,11 +542,19 @@ def _check(self):
538542

539543

540544
class _BaseChain(BaseEstimator, metaclass=ABCMeta):
541-
def __init__(self, base_estimator, *, order=None, cv=None, random_state=None):
545+
def __init__(
546+
self, base_estimator, *, order=None, cv=None, random_state=None, verbose=False
547+
):
542548
self.base_estimator = base_estimator
543549
self.order = order
544550
self.cv = cv
545551
self.random_state = random_state
552+
self.verbose = verbose
553+
554+
def _log_message(self, *, estimator_idx, n_estimators, processing_msg):
555+
if not self.verbose:
556+
return None
557+
return f"({estimator_idx} of {n_estimators}) {processing_msg}"
546558

547559
@abstractmethod
548560
def fit(self, X, Y, **fit_params):
@@ -602,8 +614,14 @@ def fit(self, X, Y, **fit_params):
602614
del Y_pred_chain
603615

604616
for chain_idx, estimator in enumerate(self.estimators_):
617+
message = self._log_message(
618+
estimator_idx=chain_idx + 1,
619+
n_estimators=len(self.estimators_),
620+
processing_msg=f"Processing order {self.order_[chain_idx]}",
621+
)
605622
y = Y[:, self.order_[chain_idx]]
606-
estimator.fit(X_aug[:, : (X.shape[1] + chain_idx)], y, **fit_params)
623+
with _print_elapsed_time("Chain", message):
624+
estimator.fit(X_aug[:, : (X.shape[1] + chain_idx)], y, **fit_params)
607625
if self.cv is not None and chain_idx < len(self.estimators_) - 1:
608626
col_idx = X.shape[1] + chain_idx
609627
cv_result = cross_val_predict(
@@ -702,6 +720,11 @@ class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
702720
Pass an int for reproducible output across multiple function calls.
703721
See :term:`Glossary <random_state>`.
704722
723+
verbose : bool, default=False
724+
If True, chain progress is output as each model is completed.
725+
726+
.. versionadded:: 1.2
727+
705728
Attributes
706729
----------
707730
classes_ : list
@@ -903,6 +926,11 @@ class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
903926
Pass an int for reproducible output across multiple function calls.
904927
See :term:`Glossary <random_state>`.
905928
929+
verbose : bool, default=False
930+
If True, chain progress is output as each model is completed.
931+
932+
.. versionadded:: 1.2
933+
906934
Attributes
907935
----------
908936
estimators_ : list

sklearn/tests/test_multioutput.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import scipy.sparse as sp
44
from joblib import cpu_count
5+
import re
56

67
from sklearn.utils._testing import assert_almost_equal
78
from sklearn.utils._testing import assert_array_equal
@@ -10,6 +11,8 @@
1011
from sklearn.base import clone
1112
from sklearn.datasets import make_classification
1213
from sklearn.datasets import load_linnerud
14+
from sklearn.datasets import make_multilabel_classification
15+
from sklearn.datasets import make_regression
1316
from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
1417
from sklearn.exceptions import NotFittedError
1518
from sklearn.linear_model import Lasso
@@ -18,15 +21,17 @@
1821
from sklearn.linear_model import Ridge
1922
from sklearn.linear_model import SGDClassifier
2023
from sklearn.linear_model import SGDRegressor
24+
from sklearn.linear_model import LinearRegression
2125
from sklearn.metrics import jaccard_score, mean_squared_error
2226
from sklearn.multiclass import OneVsRestClassifier
2327
from sklearn.multioutput import ClassifierChain, RegressorChain
2428
from sklearn.multioutput import MultiOutputClassifier
2529
from sklearn.multioutput import MultiOutputRegressor
2630
from sklearn.svm import LinearSVC
31+
from sklearn.tree import DecisionTreeClassifier
2732
from sklearn.base import ClassifierMixin
2833
from sklearn.utils import shuffle
29-
from sklearn.model_selection import GridSearchCV
34+
from sklearn.model_selection import GridSearchCV, train_test_split
3035
from sklearn.dummy import DummyRegressor, DummyClassifier
3136
from sklearn.pipeline import make_pipeline
3237
from sklearn.impute import SimpleImputer
@@ -702,6 +707,47 @@ def test_classifier_chain_tuple_invalid_order():
702707
chain.fit(X, y)
703708

704709

710+
def test_classifier_chain_verbose(capsys):
711+
X, y = make_multilabel_classification(
712+
n_samples=100, n_features=5, n_classes=3, n_labels=3, random_state=0
713+
)
714+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
715+
716+
pattern = (
717+
r"\[Chain\].*\(1 of 3\) Processing order 0, total=.*\n"
718+
r"\[Chain\].*\(2 of 3\) Processing order 1, total=.*\n"
719+
r"\[Chain\].*\(3 of 3\) Processing order 2, total=.*\n$"
720+
)
721+
722+
classifier = ClassifierChain(
723+
DecisionTreeClassifier(),
724+
order=[0, 1, 2],
725+
random_state=0,
726+
verbose=True,
727+
)
728+
classifier.fit(X_train, y_train)
729+
assert re.match(pattern, capsys.readouterr()[0])
730+
731+
732+
def test_regressor_chain_verbose(capsys):
733+
X, y = make_regression(n_samples=125, n_targets=3, random_state=0)
734+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
735+
736+
pattern = (
737+
r"\[Chain\].*\(1 of 3\) Processing order 1, total=.*\n"
738+
r"\[Chain\].*\(2 of 3\) Processing order 0, total=.*\n"
739+
r"\[Chain\].*\(3 of 3\) Processing order 2, total=.*\n$"
740+
)
741+
regressor = RegressorChain(
742+
LinearRegression(),
743+
order=[1, 0, 2],
744+
random_state=0,
745+
verbose=True,
746+
)
747+
regressor.fit(X_train, y_train)
748+
assert re.match(pattern, capsys.readouterr()[0])
749+
750+
705751
def test_multioutputregressor_ducktypes_fitted_estimator():
706752
"""Test that MultiOutputRegressor checks the fitted estimator for
707753
predict. Non-regression test for #16549."""

0 commit comments

Comments
 (0)
0