|
19 | 19 | from ..exceptions import ConvergenceWarning
|
20 | 20 | from ..utils import check_array, as_float_array, check_random_state
|
21 | 21 | from ..utils.validation import check_is_fitted
|
22 |
| -from ..utils._param_validation import Hidden, Interval, StrOptions |
| 22 | +from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params |
23 | 23 |
|
24 | 24 | __all__ = ["fastica", "FastICA"]
|
25 | 25 |
|
@@ -154,6 +154,31 @@ def _cube(x, fun_args):
|
154 | 154 | return x**3, (3 * x**2).mean(axis=-1)
|
155 | 155 |
|
156 | 156 |
|
| 157 | +@validate_params( |
| 158 | + { |
| 159 | + "X": ["array-like"], |
| 160 | + "n_components": [Interval(Integral, 1, None, closed="left"), None], |
| 161 | + "algorithm": [StrOptions({"parallel", "deflation"})], |
| 162 | + "whiten": [ |
| 163 | + Hidden(StrOptions({"warn"})), |
| 164 | + StrOptions({"arbitrary-variance", "unit-variance"}), |
| 165 | + "boolean", |
| 166 | + ], |
| 167 | + "fun": [StrOptions({"logcosh", "exp", "cube"}), callable], |
| 168 | + "fun_args": [dict, None], |
| 169 | + "max_iter": [Interval(Integral, 1, None, closed="left")], |
| 170 | + "tol": [Interval(Real, 0.0, None, closed="left")], |
| 171 | + "w_init": ["array-like", None], |
| 172 | + "whiten_solver": [StrOptions({"eigh", "svd"})], |
| 173 | + "random_state": ["random_state"], |
| 174 | + "n_clusters": [Interval(Integral, 1, None, closed="left")], |
| 175 | + "x_squared_norms": ["array-like", None], |
| 176 | + "random_state": ["random_state"], |
| 177 | + "return_X_mean": ["boolean"], |
| 178 | + "compute_sources": ["boolean"], |
| 179 | + "return_n_iter": ["boolean"], |
| 180 | + } |
| 181 | +) |
157 | 182 | def fastica(
|
158 | 183 | X,
|
159 | 184 | n_components=None,
|
|
0 commit comments