WIP ENH Added `auto` option to `FastICA.whiten_solver` by Micky774 · Pull Request #23616 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

WIP ENH Added auto option to FastICA.whiten_solver #23616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ Changelog
:pr:`11860` by :user:`Pierre Ablin <pierreablin>`,
:pr:`22527` by :user:`Meekail Zain <micky774>` and `Thomas Fan`_.

- |Enhancement| Added `"auto"` option to the `whiten_solver` parameter of
  :class:`decomposition.FastICA`, which then uses the `eigh` solver if
  `X.shape[0] >= 50*X.shape[1]` and the `svd` solver otherwise.
:pr:`23616` by :user:`Meekail Zain <micky774>`.

- |Enhancement| :class:`decomposition.LatentDirichletAllocation` now preserves dtype
for `numpy.float32` input. :pr:`24528` by :user:`Takeshi Oura <takoika>` and
:user:`Jérémie du Boisberranger <jeremiedbb>`.
Expand Down
57 changes: 47 additions & 10 deletions sklearn/decomposition/_fastica.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fastica(
max_iter=200,
tol=1e-04,
w_init=None,
whiten_solver="svd",
whiten_solver="warn",
random_state=None,
return_X_mean=False,
compute_sources=True,
Expand Down Expand Up @@ -232,8 +232,9 @@ def my_g(x):
Initial un-mixing array. If `w_init=None`, then an array of values
drawn from a normal distribution is used.

whiten_solver : {"eigh", "svd"}, default="svd"
The solver to use for whitening.
whiten_solver : {"auto", "eigh", "svd"}, default="auto"
The solver to use for whitening. Note that different solvers may
produce different solutions. See `sign_flip` for details.

- "svd" is more stable numerically if the problem is degenerate, and
often faster when `n_samples <= n_features`.
Expand All @@ -242,8 +243,14 @@ def my_g(x):
`n_samples >= n_features`, and can be faster when
`n_samples >= 50 * n_features`.

- "auto" uses the `eigh` solver when `n_samples >= 50 * n_features`,
  `svd` otherwise.

.. versionadded:: 1.2

.. versionchanged:: 1.4
The default value for `whiten_solver` will change to "auto" in
version 1.4.

random_state : int, RandomState instance or None, default=None
Used to initialize ``w_init`` when not specified, with a
normal distribution. Pass an int, for reproducible results
Expand Down Expand Up @@ -395,8 +402,9 @@ def my_g(x):
Initial un-mixing array. If `w_init=None`, then an array of values
drawn from a normal distribution is used.

whiten_solver : {"eigh", "svd"}, default="svd"
The solver to use for whitening.
whiten_solver : {"auto", "eigh", "svd"}, default="auto"
The solver to use for whitening. Note that different solvers may
produce different solutions. See `sign_flip` for details.

- "svd" is more stable numerically if the problem is degenerate, and
often faster when `n_samples <= n_features`.
Expand All @@ -405,8 +413,15 @@ def my_g(x):
`n_samples >= n_features`, and can be faster when
`n_samples >= 50 * n_features`.

- "auto" uses the `eigh` solver when `n_samples >= 50 * n_features`,
`svd` otherwise.

.. versionadded:: 1.2

.. versionchanged:: 1.4
The default value for `whiten_solver` will change to "auto" in
version 1.4.

random_state : int, RandomState instance or None, default=None
Used to initialize ``w_init`` when not specified, with a
normal distribution. Pass an int, for reproducible results
Expand Down Expand Up @@ -469,7 +484,8 @@ def my_g(x):
>>> X, _ = load_digits(return_X_y=True)
>>> transformer = FastICA(n_components=7,
... random_state=0,
... whiten='unit-variance')
... whiten='unit-variance',
... whiten_solver='auto')
>>> X_transformed = transformer.fit_transform(X)
>>> X_transformed.shape
(1797, 7)
Expand All @@ -488,7 +504,10 @@ def my_g(x):
"max_iter": [Interval(Integral, 1, None, closed="left")],
"tol": [Interval(Real, 0.0, None, closed="left")],
"w_init": ["array-like", None],
"whiten_solver": [StrOptions({"eigh", "svd"})],
"whiten_solver": [
StrOptions({"eigh", "svd", "auto"}),
Hidden(StrOptions({"warn"})),
],
"random_state": ["random_state"],
}

Expand All @@ -503,7 +522,7 @@ def __init__(
max_iter=200,
tol=1e-4,
w_init=None,
whiten_solver="svd",
whiten_solver="warn",
random_state=None,
):
super().__init__()
Expand Down Expand Up @@ -558,6 +577,24 @@ def _fit_transform(self, X, compute_sources=False):
XT = self._validate_data(
X, copy=self._whiten, dtype=[np.float64, np.float32], ensure_min_samples=2
).T

# Benchmark validated heuristic
self._whiten_solver = self.whiten_solver
if self._whiten_solver == "warn" and self._whiten:
warnings.warn(
"From version 1.4 `whiten_solver='auto'` will be used by default."
" Manually set the value of `whiten_solver` to suppress this message."
" Note that `whiten_solver='auto'` may change the numerical output"
" of the model, but not its correctness. This can be alleviated by"
" setting `sign_flip=True`. See the `whiten_solver` and `sign_flip`"
" descriptions for details.",
FutureWarning,
)
self._whiten_solver = "svd"

if self._whiten_solver == "auto":
self._whiten_solver = "eigh" if XT.shape[1] >= 50 * XT.shape[0] else "svd"

fun_args = {} if self.fun_args is None else self.fun_args
random_state = check_random_state(self.random_state)

Expand Down Expand Up @@ -596,7 +633,7 @@ def g(x, fun_args):
XT -= X_mean[:, np.newaxis]

# Whitening and preprocessing by PCA
if self.whiten_solver == "eigh":
if self._whiten_solver == "eigh":
# Faster when num_samples >> n_features
d, u = linalg.eigh(XT.dot(X))
sort_indices = np.argsort(d)[::-1]
Expand All @@ -611,7 +648,7 @@ def g(x, fun_args):
d[degenerate_idx] = eps # For numerical issues
np.sqrt(d, out=d)
d, u = d[sort_indices], u[:, sort_indices]
elif self.whiten_solver == "svd":
elif self._whiten_solver == "svd":
u, d = linalg.svd(XT, full_matrices=False, check_finite=False)[:2]

# Give consistent eigenvectors for both svd solvers
Expand Down
121 changes: 102 additions & 19 deletions sklearn/decomposition/tests/test_fastica.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def test_fastica_attributes_dtypes(global_dtype):
rng = np.random.RandomState(0)
X = rng.random_sample((100, 10)).astype(global_dtype, copy=False)
fica = FastICA(
n_components=5, max_iter=1000, whiten="unit-variance", random_state=0
n_components=5,
max_iter=1000,
whiten="unit-variance",
random_state=0,
whiten_solver="auto",
).fit(X)
assert fica.components_.dtype == global_dtype
assert fica.mixing_.dtype == global_dtype
Expand All @@ -62,7 +66,11 @@ def test_fastica_return_dtypes(global_dtype):
rng = np.random.RandomState(0)
X = rng.random_sample((100, 10)).astype(global_dtype, copy=False)
k_, mixing_, s_ = fastica(
X, max_iter=1000, whiten="unit-variance", random_state=rng
X,
max_iter=1000,
whiten="unit-variance",
random_state=rng,
whiten_solver="auto",
)
assert k_.dtype == global_dtype
assert mixing_.dtype == global_dtype
Expand Down Expand Up @@ -107,18 +115,29 @@ def g_test(x):
for algo, nl, whiten in itertools.product(algos, nls, whitening):
if whiten:
k_, mixing_, s_ = fastica(
m.T, fun=nl, whiten=whiten, algorithm=algo, random_state=rng
m.T,
fun=nl,
whiten=whiten,
algorithm=algo,
random_state=rng,
whiten_solver="auto",
)
with pytest.raises(ValueError):
fastica(m.T, fun=np.tanh, whiten=whiten, algorithm=algo)
fastica(
m.T,
fun=np.tanh,
whiten=whiten,
algorithm=algo,
whiten_solver="auto",
)
else:
pca = PCA(n_components=2, whiten=True, random_state=rng)
X = pca.fit_transform(m.T)
k_, mixing_, s_ = fastica(
X, fun=nl, algorithm=algo, whiten=False, random_state=rng
)
with pytest.raises(ValueError):
fastica(X, fun=np.tanh, algorithm=algo)
fastica(X, fun=np.tanh, algorithm=algo, whiten_solver="auto")
s_ = s_.T
# Check that the mixing model described in the docstring holds:
if whiten:
Expand Down Expand Up @@ -149,9 +168,15 @@ def g_test(x):

# Test FastICA class
_, _, sources_fun = fastica(
m.T, fun=nl, algorithm=algo, random_state=global_random_seed
m.T,
fun=nl,
algorithm=algo,
random_state=global_random_seed,
whiten_solver="auto",
)
ica = FastICA(
fun=nl, algorithm=algo, random_state=global_random_seed, whiten_solver="auto"
)
ica = FastICA(fun=nl, algorithm=algo, random_state=global_random_seed)
sources = ica.fit_transform(m.T)
assert ica.components_.shape == (2, 2)
assert sources.shape == (1000, 2)
Expand All @@ -163,7 +188,7 @@ def g_test(x):

assert ica.mixing_.shape == (2, 2)

ica = FastICA(fun=np.tanh, algorithm=algo)
ica = FastICA(fun=np.tanh, algorithm=algo, whiten_solver="auto")
with pytest.raises(ValueError):
ica.fit(m.T)

Expand Down Expand Up @@ -204,7 +229,12 @@ def test_fastica_convergence_fail():
)
with pytest.warns(ConvergenceWarning, match=warn_msg):
ica = FastICA(
algorithm="parallel", n_components=2, random_state=rng, max_iter=2, tol=0.0
algorithm="parallel",
n_components=2,
random_state=rng,
max_iter=2,
tol=0.0,
whiten_solver="auto",
)
ica.fit(m.T)

Expand Down Expand Up @@ -233,7 +263,11 @@ def test_non_square_fastica(add_noise):
center_and_norm(m)

k_, mixing_, s_ = fastica(
m.T, n_components=2, whiten="unit-variance", random_state=rng
m.T,
n_components=2,
whiten="unit-variance",
random_state=rng,
whiten_solver="auto",
)
s_ = s_.T

Expand Down Expand Up @@ -271,7 +305,11 @@ def test_fit_transform(global_random_seed, global_dtype):
n_components_ = n_components if n_components is not None else X.shape[1]

ica = FastICA(
n_components=n_components, max_iter=max_iter, whiten=whiten, random_state=0
n_components=n_components,
max_iter=max_iter,
whiten=whiten,
random_state=0,
whiten_solver="auto",
)
with warnings.catch_warnings():
# make sure that numerical errors do not cause sqrt of negative
Expand All @@ -285,7 +323,11 @@ def test_fit_transform(global_random_seed, global_dtype):
assert Xt.shape == (X.shape[0], n_components_)

ica2 = FastICA(
n_components=n_components, max_iter=max_iter, whiten=whiten, random_state=0
n_components=n_components,
max_iter=max_iter,
whiten=whiten,
random_state=0,
whiten_solver="auto",
)
with warnings.catch_warnings():
# make sure that numerical errors do not cause sqrt of negative
Expand Down Expand Up @@ -325,7 +367,9 @@ def test_inverse_transform(
rng = np.random.RandomState(global_random_seed)
X = rng.random_sample((n_samples, 10)).astype(global_dtype)

ica = FastICA(n_components=n_components, random_state=rng, whiten=whiten)
ica = FastICA(
n_components=n_components, random_state=rng, whiten=whiten, whiten_solver="auto"
)
with warnings.catch_warnings():
# For some dataset (depending on the value of global_dtype) the model
# can fail to converge but this should not impact the definition of
Expand Down Expand Up @@ -360,11 +404,11 @@ def test_fastica_errors():
X = rng.random_sample((n_samples, n_features))
w_init = rng.randn(n_features + 1, n_features + 1)
with pytest.raises(ValueError, match=r"alpha must be in \[1,2\]"):
fastica(X, fun_args={"alpha": 0})
fastica(X, fun_args={"alpha": 0}, whiten_solver="auto")
with pytest.raises(
ValueError, match="w_init has invalid shape.+" r"should be \(3L?, 3L?\)"
):
fastica(X, w_init=w_init)
fastica(X, w_init=w_init, whiten_solver="auto")


def test_fastica_whiten_unit_variance():
Expand All @@ -375,13 +419,18 @@ def test_fastica_whiten_unit_variance():
rng = np.random.RandomState(0)
X = rng.random_sample((100, 10))
n_components = X.shape[1]
ica = FastICA(n_components=n_components, whiten="unit-variance", random_state=0)
ica = FastICA(
n_components=n_components,
whiten="unit-variance",
random_state=0,
whiten_solver="auto",
)
Xt = ica.fit_transform(X)

assert np.var(Xt) == pytest.approx(1.0)


@pytest.mark.parametrize("ica", [FastICA(), FastICA(whiten=True)])
@pytest.mark.parametrize("ica", [FastICA(), FastICA(whiten=True, whiten_solver="auto")])
def test_fastica_whiten_default_value_deprecation(ica):
"""Test FastICA whiten default value deprecation.

Expand All @@ -407,7 +456,9 @@ def test_fastica_whiten_backwards_compatibility():
with pytest.warns(FutureWarning):
Xt_on_default = default_ica.fit_transform(X)

ica = FastICA(n_components=n_components, whiten=True, random_state=0)
ica = FastICA(
n_components=n_components, whiten=True, random_state=0, whiten_solver="auto"
)
with pytest.warns(FutureWarning):
Xt = ica.fit_transform(X)

Expand Down Expand Up @@ -445,7 +496,11 @@ def test_fastica_output_shape(whiten, return_X_mean, return_n_iter):
expected_len = 3 + return_X_mean + return_n_iter

out = fastica(
X, whiten=whiten, return_n_iter=return_n_iter, return_X_mean=return_X_mean
X,
whiten=whiten,
return_n_iter=return_n_iter,
return_X_mean=return_X_mean,
whiten_solver="auto",
)

assert len(out) == expected_len
Expand Down Expand Up @@ -497,3 +552,31 @@ def test_fastica_eigh_low_rank_warning(global_random_seed):
msg = "There are some small singular values"
with pytest.warns(UserWarning, match=msg):
ica.fit(X)


# TODO(1.4): to be removed
def test_fastica_whiten_solver_future_warning():
    """Check the FutureWarning raised when `whiten_solver` is left unset."""
    random_state = np.random.RandomState(0)
    data = random_state.random_sample((10, 10))

    expected_msg = "From version 1.4 `whiten_solver='auto'` will be used by default."
    estimator = FastICA(random_state=random_state, whiten="unit-variance")
    with pytest.warns(FutureWarning, match=expected_msg):
        estimator.fit_transform(data)
    # The public attribute keeps the "warn" sentinel while the private
    # attribute records the solver that was actually used.
    assert estimator.whiten_solver == "warn"
    assert estimator._whiten_solver == "svd"

    # Test that it doesn't warn if whiten is explicitly set to False
    estimator = FastICA(random_state=random_state, whiten=False)
    estimator.fit_transform(data)


@pytest.mark.parametrize("n_samples", (199, 200))
def test_fastica_whiten_solver_auto(n_samples):
    """Check the heuristic that automatically chooses the solver."""
    random_state = np.random.RandomState(0)
    data = random_state.random_sample((n_samples, 4))
    estimator = FastICA(
        random_state=random_state, whiten="unit-variance", whiten_solver="auto"
    )
    estimator.fit_transform(data)
    # 4 features -> the eigh branch is taken exactly when n_samples >= 200.
    if data.shape[0] >= 50 * data.shape[1]:
        expected_solver = "eigh"
    else:
        expected_solver = "svd"
    assert estimator._whiten_solver == expected_solver
Loading
0