8000 MAINT Parameters validation for 'fastica' · scikit-learn/scikit-learn@f5097d5 · GitHub
[go: up one dir, main page]

Skip to content

Commit f5097d5

Browse files
committed
MAINT Parameters validation for 'fastica'
1 parent dbde1da commit f5097d5

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

sklearn/decomposition/_fastica.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ..exceptions import ConvergenceWarning
2020
from ..utils import check_array, as_float_array, check_random_state
2121
from ..utils.validation import check_is_fitted
22-
from ..utils._param_validation import Hidden, Interval, StrOptions
22+
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
2323

2424
__all__ = ["fastica", "FastICA"]
2525

@@ -154,6 +154,31 @@ def _cube(x, fun_args):
154154
return x**3, (3 * x**2).mean(axis=-1)
155155

156156

157+
@validate_params(
158+
{
159+
"X": ["array-like"],
160+
"n_components": [Interval(Integral, 1, None, closed="left"), None],
161+
"algorithm": [StrOptions({"parallel", "deflation"})],
162+
"whiten": [
163+
Hidden(StrOptions({"warn"})),
164+
StrOptions({"arbitrary-variance", "unit-variance"}),
165+
"boolean",
166+
],
167+
"fun": [StrOptions({"logcosh", "exp", "cube"}), callable],
168+
"fun_args": [dict, None],
169+
"max_iter": [Interval(Integral, 1, None, closed="left")],
170+
"tol": [Interval(Real, 0.0, None, closed="left")],
171+
"w_init": ["array-like", None],
172+
"whiten_solver": [StrOptions({"eigh", "svd"})],
173+
"random_state": ["random_state"],
174+
"n_clusters": [Interval(Integral, 1, None, closed="left")],
175+
"x_squared_norms": ["array-like", None],
176+
"random_state": ["random_state"],
177+
"return_X_mean": ["boolean"],
178+
"compute_sources": ["boolean"],
179+
"return_n_iter": ["boolean"],
180+
}
181+
)
157182
def fastica(
158183
X,
159184
n_components=None,

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
PARAM_VALIDATION_FUNCTION_LIST = [
1212
"sklearn.cluster.kmeans_plusplus",
13+
"sklearn.decomposition.fastica",
1314
]
1415

1516

0 commit comments

Comments
 (0)
0