8000 MAINT Parameters validation for `fastica` by kianelbo · Pull Request #24924 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Parameters validation for fastica #24924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 28, 2022
11 changes: 10 additions & 1 deletion sklearn/decomposition/_fastica.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..exceptions import ConvergenceWarning
from ..utils import check_array, as_float_array, check_random_state
from ..utils.validation import check_is_fitted
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params

__all__ = ["fastica", "FastICA"]

Expand Down Expand Up @@ -154,6 +154,14 @@ def _cube(x, fun_args):
return x**3, (3 * x**2).mean(axis=-1)


@validate_params(
{
"X": ["array-like"],
"return_X_mean": ["boolean"],
"compute_sources": ["boolean"],
"return_n_iter": ["boolean"],
}
)
def fastica(
X,
n_components=None,
Expand Down Expand Up @@ -319,6 +327,7 @@ def my_g(x):
whiten_solver=whiten_solver,
random_state=random_state,
)
est._validate_params()
S = est._fit_transform(X, compute_sources=compute_sources)

if est._whiten in ["unit-variance", "arbitrary-variance"]:
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_function_param_validation(func_module):
("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"),
("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
("sklearn.covariance.oas", "sklearn.covariance.OAS"),
("sklearn.decomposition.fastica", "sklearn.decomposition.FastICA"),
]


Expand Down
0