|
2 | 2 | import numpy as np
|
3 | 3 | import scipy.sparse as sp
|
4 | 4 | from joblib import cpu_count
|
| 5 | +import re |
5 | 6 |
|
6 | 7 | from sklearn.utils._testing import assert_almost_equal
|
7 | 8 | from sklearn.utils._testing import assert_array_equal
|
|
10 | 11 | from sklearn.base import clone
|
11 | 12 | from sklearn.datasets import make_classification
|
12 | 13 | from sklearn.datasets import load_linnerud
|
| 14 | +from sklearn.datasets import make_multilabel_classification |
| 15 | +from sklearn.datasets import make_regression |
13 | 16 | from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
|
14 | 17 | from sklearn.exceptions import NotFittedError
|
15 | 18 | from sklearn.linear_model import Lasso
|
|
18 | 21 | from sklearn.linear_model import Ridge
|
19 | 22 | from sklearn.linear_model import SGDClassifier
|
20 | 23 | from sklearn.linear_model import SGDRegressor
|
| 24 | +from sklearn.linear_model import LinearRegression |
21 | 25 | from sklearn.metrics import jaccard_score, mean_squared_error
|
22 | 26 | from sklearn.multiclass import OneVsRestClassifier
|
23 | 27 | from sklearn.multioutput import ClassifierChain, RegressorChain
|
24 | 28 | from sklearn.multioutput import MultiOutputClassifier
|
25 | 29 | from sklearn.multioutput import MultiOutputRegressor
|
26 | 30 | from sklearn.svm import LinearSVC
|
| 31 | +from sklearn.tree import DecisionTreeClassifier |
27 | 32 | from sklearn.base import ClassifierMixin
|
28 | 33 | from sklearn.utils import shuffle
|
29 |
| -from sklearn.model_selection import GridSearchCV |
| 34 | +from sklearn.model_selection import GridSearchCV, train_test_split |
30 | 35 | from sklearn.dummy import DummyRegressor, DummyClassifier
|
31 | 36 | from sklearn.pipeline import make_pipeline
|
32 | 37 | from sklearn.impute import SimpleImputer
|
@@ -702,6 +707,47 @@ def test_classifier_chain_tuple_invalid_order():
|
702 | 707 | chain.fit(X, y)
|
703 | 708 |
|
704 | 709 |
|
| 710 | +def test_classifier_chain_verbose(capsys): |
| 711 | + X, y = make_multilabel_classification( |
| 712 | + n_samples=100, n_features=5, n_classes=3, n_labels=3, random_state=0 |
| 713 | + ) |
| 714 | + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) |
| 715 | + |
| 716 | + pattern = ( |
| 717 | + r"\[Chain\].*\(1 of 3\) Processing order 0, total=.*\n" |
| 718 | + r"\[Chain\].*\(2 of 3\) Processing order 1, total=.*\n" |
| 719 | + r"\[Chain\].*\(3 of 3\) Processing order 2, total=.*\n$" |
| 720 | + ) |
| 721 | + |
| 722 | + classifier = ClassifierChain( |
| 723 | + DecisionTreeClassifier(), |
| 724 | + order=[0, 1, 2], |
| 725 | + random_state=0, |
| 726 | + verbose=True, |
| 727 | + ) |
| 728 | + classifier.fit(X_train, y_train) |
| 729 | + assert re.match(pattern, capsys.readouterr()[0]) |
| 730 | + |
| 731 | + |
| 732 | +def test_regressor_chain_verbose(capsys): |
| 733 | + X, y = make_regression(n_samples=125, n_targets=3, random_state=0) |
| 734 | + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) |
| 735 | + |
| 736 | + pattern = ( |
| 737 | + r"\[Chain\].*\(1 of 3\) Processing order 1, total=.*\n" |
| 738 | + r"\[Chain\].*\(2 of 3\) Processing order 0, total=.*\n" |
| 739 | + r"\[Chain\].*\(3 of 3\) Processing order 2, total=.*\n$" |
| 740 | + ) |
| 741 | + regressor = RegressorChain( |
| 742 | + LinearRegression(), |
| 743 | + order=[1, 0, 2], |
| 744 | + random_state=0, |
| 745 | + verbose=True, |
| 746 | + ) |
| 747 | + regressor.fit(X_train, y_train) |
| 748 | + assert re.match(pattern, capsys.readouterr()[0]) |
| 749 | + |
| 750 | + |
705 | 751 | def test_multioutputregressor_ducktypes_fitted_estimator():
|
706 | 752 | """Test that MultiOutputRegressor checks the fitted estimator for
|
707 | 753 | predict. Non-regression test for #16549."""
|
|
0 commit comments