|
17 | 17 | from sklearn.datasets import load_iris
|
18 | 18 | from sklearn.decomposition import PCA, TruncatedSVD
|
19 | 19 | from sklearn.dummy import DummyRegressor
|
20 |
| -from sklearn.ensemble import HistGradientBoostingClassifier |
| 20 | +from sklearn.ensemble import ( |
| 21 | + HistGradientBoostingClassifier, |
| 22 | + RandomForestClassifier, |
| 23 | + RandomTreesEmbedding, |
| 24 | +) |
21 | 25 | from sklearn.exceptions import NotFittedError
|
22 | 26 | from sklearn.feature_extraction.text import CountVectorizer
|
23 | 27 | from sklearn.feature_selection import SelectKBest, f_classif
|
|
27 | 31 | from sklearn.model_selection import train_test_split
|
28 | 32 | from sklearn.neighbors import LocalOutlierFactor
|
29 | 33 | from sklearn.pipeline import FeatureUnion, Pipeline, make_pipeline, make_union
|
30 |
| -from sklearn.preprocessing import StandardScaler |
| 34 | +from sklearn.preprocessing import FunctionTransformer, StandardScaler |
31 | 35 | from sklearn.svm import SVC
|
32 | 36 | from sklearn.utils._metadata_requests import COMPOSITE_METHODS, METHODS
|
33 | 37 | from sklearn.utils._testing import (
|
@@ -1828,5 +1832,26 @@ def test_routing_passed_metadata_not_supported(method):
|
1828 | 1832 | getattr(pipe, method)([[1]], sample_weight=[1], prop="a")
|
1829 | 1833 |
|
1830 | 1834 |
|
| 1835 | +@pytest.mark.usefixtures("enable_slep006") |
| 1836 | +def test_pipeline_with_estimator_with_len(): |
| 1837 | + """Test that pipeline works with estimators that have a `__len__` method.""" |
| 1838 | + pipe = Pipeline( |
| 1839 | + [("trs", RandomTreesEmbedding()), ("estimator", RandomForestClassifier())] |
| 1840 | + ) |
| 1841 | + pipe.fit([[1]], [1]) |
| 1842 | + pipe.predict([[1]]) |
| 1843 | + |
| 1844 | + |
| 1845 | +@pytest.mark.usefixtures("enable_slep006") |
| 1846 | +@pytest.mark.parametrize("last_step", [None, "passthrough"]) |
| 1847 | +def test_pipeline_with_no_last_step(last_step): |
| 1848 | + """Test that the pipeline works when there is not last step. |
| 1849 | +
|
| 1850 | + It should just ignore and pass through the data on transform. |
| 1851 | + """ |
| 1852 | + pipe = Pipeline([("trs", FunctionTransformer()), ("estimator", last_step)]) |
| 1853 | + assert pipe.fit([[1]], [1]).transform([[1], [2], [3]]) == [[1], [2], [3]] |
| 1854 | + |
| 1855 | + |
1831 | 1856 | # End of routing tests
|
1832 | 1857 | # ====================
|
0 commit comments