|
1 | | -import numpy as np |
2 | 1 | import pytest |
3 | | -import warnings |
4 | 2 |
|
5 | 3 | import pickle |
6 | 4 |
|
7 | | -from sklearn.utils.metaestimators import if_delegate_has_method |
8 | 5 | from sklearn.utils.metaestimators import available_if |
9 | 6 |
|
10 | 7 |
|
11 | | -class Prefix: |
12 | | - def func(self): |
13 | | - pass |
14 | | - |
15 | | - |
16 | | -class MockMetaEstimator: |
17 | | - """This is a mock meta estimator""" |
18 | | - |
19 | | - a_prefix = Prefix() |
20 | | - |
21 | | - @if_delegate_has_method(delegate="a_prefix") |
22 | | - def func(self): |
23 | | - """This is a mock delegated function""" |
24 | | - pass |
25 | | - |
26 | | - |
27 | | -@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated") |
28 | | -def test_delegated_docstring(): |
29 | | - assert "This is a mock delegated function" in str( |
30 | | - MockMetaEstimator.__dict__["func"].__doc__ |
31 | | - ) |
32 | | - assert "This is a mock delegated function" in str(MockMetaEstimator.func.__doc__) |
33 | | - assert "This is a mock delegated function" in str(MockMetaEstimator().func.__doc__) |
34 | | - |
35 | | - |
36 | | -class MetaEst: |
37 | | - """A mock meta estimator""" |
38 | | - |
39 | | - def __init__(self, sub_est, better_sub_est=None): |
40 | | - self.sub_est = sub_est |
41 | | - self.better_sub_est = better_sub_est |
42 | | - |
43 | | - @if_delegate_has_method(delegate="sub_est") |
44 | | - def predict(self): |
45 | | - pass |
46 | | - |
47 | | - |
48 | | -class MetaEstTestTuple(MetaEst): |
49 | | - """A mock meta estimator to test passing a tuple of delegates""" |
50 | | - |
51 | | - @if_delegate_has_method(delegate=("sub_est", "better_sub_est")) |
52 | | - def predict(self): |
53 | | - pass |
54 | | - |
55 | | - |
56 | | -class MetaEstTestList(MetaEst): |
57 | | - """A mock meta estimator to test passing a list of delegates""" |
58 | | - |
59 | | - @if_delegate_has_method(delegate=["sub_est", "better_sub_est"]) |
60 | | - def predict(self): |
61 | | - pass |
62 | | - |
63 | | - |
64 | | -class HasPredict: |
65 | | - """A mock sub-estimator with predict method""" |
66 | | - |
67 | | - def predict(self): |
68 | | - pass |
69 | | - |
70 | | - |
71 | | -class HasNoPredict: |
72 | | - """A mock sub-estimator with no predict method""" |
73 | | - |
74 | | - pass |
75 | | - |
76 | | - |
77 | | -class HasPredictAsNDArray: |
78 | | - """A mock sub-estimator where predict is a NumPy array""" |
79 | | - |
80 | | - predict = np.ones((10, 2), dtype=np.int64) |
81 | | - |
82 | | - |
83 | | -@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated") |
84 | | -def test_if_delegate_has_method(): |
85 | | - assert hasattr(MetaEst(HasPredict()), "predict") |
86 | | - assert not hasattr(MetaEst(HasNoPredict()), "predict") |
87 | | - assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), "predict") |
88 | | - assert hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), "predict") |
89 | | - assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), "predict") |
90 | | - assert not hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), "predict") |
91 | | - assert hasattr(MetaEstTestList(HasPredict(), HasPredict()), "predict") |
92 | | - |
93 | | - |
94 | 8 | class AvailableParameterEstimator: |
95 | 9 | """This estimator's `available` parameter toggles the presence of a method""" |
96 | 10 |
|
@@ -137,29 +51,6 @@ def test_available_if_unbound_method(): |
137 | 51 | AvailableParameterEstimator.available_func(est) |
138 | 52 |
|
139 | 53 |
|
140 | | -@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated") |
141 | | -def test_if_delegate_has_method_numpy_array(): |
142 | | -
10327
"""Check that we can check for an attribute that is a NumPy array. |
143 | | -
|
144 | | - This is a non-regression test for: |
145 | | - https://github.com/scikit-learn/scikit-learn/issues/21144 |
146 | | - """ |
147 | | - estimator = MetaEst(HasPredictAsNDArray()) |
148 | | - assert hasattr(estimator, "predict") |
149 | | - |
150 | | - |
151 | | -def test_if_delegate_has_method_deprecated(): |
152 | | - """Check the deprecation warning of if_delegate_has_method""" |
153 | | - # don't warn when creating the decorator |
154 | | - with warnings.catch_warnings(): |
155 | | - warnings.simplefilter("error", FutureWarning) |
156 | | - _ = if_delegate_has_method(delegate="predict") |
157 | | - |
158 | | - # Only when calling it |
159 | | - with pytest.warns(FutureWarning, match="if_delegate_has_method was deprecated"): |
160 | | - hasattr(MetaEst(HasPredict()), "predict") |
161 | | - |
162 | | - |
163 | 54 | def test_available_if_methods_can_be_pickled(): |
164 | 55 | """Check that available_if methods can be pickled. |
165 | 56 |
|
|
0 commit comments