ENH Warn future change of default init in NMF (#18525) · jayzed82/scikit-learn@9d3a7f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d3a7f5

Browse files
cmarmo authored and jayzed82 committed
ENH Warn future change of default init in NMF (scikit-learn#18525)
1 parent 17b005a commit 9d3a7f5

File tree

6 files changed

+79
-21
lines changed

6 files changed

+79
-21
lines changed

doc/whats_new/v0.24.rst

+6
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ Changelog
110110
predictions for `est.transform(Y)` when the training data is single-target.
111111
:pr:`17095` by `Nicolas Hug`_.
112112

113+
- |API| For :class:`decomposition.NMF`,
114+
the `init` value, when 'init=None' and
115+
n_components <= min(n_samples, n_features) will be changed from
116+
`'nndsvd'` to `'nndsvda'` in 0.26.
117+
:pr:`18525` by `Chiara Marmo <cmarmo>`.
118+
113119
- |API| The bounds of the `n_components` parameter is now restricted:
114120

115121
- into `[1, min(n_samples, n_features, n_targets)]`, for

sklearn/decomposition/_nmf.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _beta_loss_to_float(beta_loss):
247247
return beta_loss
248248

249249

250-
def _initialize_nmf(X, n_components, init=None, eps=1e-6,
250+
def _initialize_nmf(X, n_components, init='warn', eps=1e-6,
251251
random_state=None):
252252
"""Algorithms for NMF initialization.
253253
@@ -307,6 +307,13 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
307307
nonnegative matrix factorization - Pattern Recognition, 2008
308308
http://tinyurl.com/nndsvd
309309
"""
310+
if init == 'warn':
311+
warnings.warn(("The 'init' value, when 'init=None' and "
312+
"n_components is less than n_samples and "
313+
"n_features, will be changed from 'nndsvd' to "
314+
"'nndsvda' in 0.26."), FutureWarning)
315+
init = None
316+
310317
check_non_negative(X, "NMF initialization")
311318
n_samples, n_features = X.shape
312319

@@ -844,7 +851,7 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
844851

845852
@_deprecate_positional_args
846853
def non_negative_factorization(X, W=None, H=None, n_components=None, *,
847-
init=None, update_H=True, solver='cd',
854+
init='warn', update_H=True, solver='cd',
848855
beta_loss='frobenius', tol=1e-4,
849856
max_iter=200, alpha=0., l1_ratio=0.,
850857
regularization=None, random_state=None,
@@ -1253,7 +1260,7 @@ class NMF(TransformerMixin, BaseEstimator):
12531260
factorization with the beta-divergence. Neural Computation, 23(9).
12541261
"""
12551262
@_deprecate_positional_args
1256-
def __init__(self, n_components=None, *, init=None, solver='cd',
1263+
def __init__(self, n_components=None, *, init='warn', solver='cd',
12571264
beta_loss='frobenius', tol=1e-4, max_iter=200,
12581265
random_state=None, alpha=0., l1_ratio=0., verbose=0,
12591266
shuffle=False, regularization='both'):

sklearn/decomposition/tests/test_nmf.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,28 @@ def test_initialize_nn_output():
4242
def test_parameter_checking():
4343
A = np.ones((2, 2))
4444
name = 'spam'
45+
# FIXME : should be removed in 0.26
46+
init = 'nndsvda'
4547
msg = "Invalid solver parameter: got 'spam' instead of one of"
46-
assert_raise_message(ValueError, msg, NMF(solver=name).fit, A)
48+
assert_raise_message(ValueError, msg, NMF(solver=name, init=init).fit, A)
4749
msg = "Invalid init parameter: got 'spam' instead of one of"
4850
assert_raise_message(ValueError, msg, NMF(init=name).fit, A)
4951
msg = "Invalid regularization parameter: got 'spam' instead of one of"
50-
assert_raise_message(ValueError, msg, NMF(regularization=name).fit, A)
52+
assert_raise_message(ValueError, msg, NMF(regularization=name,
53+
init=init).fit, A)
5154
msg = "Invalid beta_loss parameter: got 'spam' instead of one"
52-
assert_raise_message(ValueError, msg, NMF(solver='mu',
55+
assert_raise_message(ValueError, msg, NMF(solver='mu', init=init,
5356
beta_loss=name).fit, A)
5457
msg = "Invalid beta_loss parameter: solver 'cd' does not handle "
5558
msg += "beta_loss = 1.0"
56-
assert_raise_message(ValueError, msg, NMF(solver='cd',
59+
assert_raise_message(ValueError, msg, NMF(solver='cd', init=init,
5760
beta_loss=1.0).fit, A)
5861

5962
msg = "Negative values in data passed to"
60-
assert_raise_message(ValueError, msg, NMF().fit, -A)
63+
assert_raise_message(ValueError, msg, NMF(init=init).fit, -A)
6164
assert_raise_message(ValueError, msg, nmf._initialize_nmf, -A,
6265
2, 'nndsvd')
63-
clf = NMF(2, tol=0.1).fit(A)
66+
clf = NMF(2, tol=0.1, init=init).fit(A)
6467
assert_raise_message(ValueError, msg, clf.transform, -A)
6568

6669
for init in ['nndsvd', 'nndsvda', 'nndsvdar']:
@@ -176,7 +179,9 @@ def test_n_components_greater_n_features():
176179
# Smoke test for the case of more components than features.
177180
rng = np.random.mtrand.RandomState(42)
178181
A = np.abs(rng.randn(30, 10))
179-
NMF(n_components=15, random_state=0, tol=1e-2).fit(A)
182+
# FIXME : should be removed in 0.26
183+
init = 'random'
184+
NMF(n_components=15, random_state=0, tol=1e-2, init=init).fit(A)
180185

181186

182187
@pytest.mark.parametrize('solver', ['cd', 'mu'])
@@ -214,7 +219,7 @@ def test_nmf_sparse_transform():
214219

215220
for solver in ('cd', 'mu'):
216221
model = NMF(solver=solver, random_state=0, n_components=2,
217-
max_iter=400)
222+
max_iter=400, init='nndsvd')
218223
A_fit_tr = model.fit_transform(A)
219224
A_tr = model.transform(A)
220225
assert_array_almost_equal(A_fit_tr, A_tr, decimal=1)
@@ -436,13 +441,17 @@ def test_nmf_regularization():
436441
rng = np.random.mtrand.RandomState(42)
437442
X = np.abs(rng.randn(n_samples, n_features))
438443

444+
# FIXME : should be removed in 0.26
445+
init = 'nndsvda'
439446
# L1 regularization should increase the number of zeros
440447
l1_ratio = 1.
441448
for solver in ['cd', 'mu']:
442449
regul = nmf.NMF(n_components=n_components, solver=solver,
443-
alpha=0.5, l1_ratio=l1_ratio, random_state=42)
450+
alpha=0.5, l1_ratio=l1_ratio, random_state=42,
451+
init=init)
444452
model = nmf.NMF(n_components=n_components, solver=solver,
445-
alpha=0., l1_ratio=l1_ratio, random_state=42)
453+
alpha=0., l1_ratio=l1_ratio, random_state=42,
454+
init=init)
446455

447456
W_regul = regul.fit_transform(X)
448457
W_model = model.fit_transform(X)
@@ -462,18 +471,20 @@ def test_nmf_regularization():
462471
l1_ratio = 0.
463472
for solver in ['cd', 'mu']:
464473
regul = nmf.NMF(n_components=n_components, solver=solver,
465-
alpha=0.5, l1_ratio=l1_ratio, random_state=42)
474+
alpha=0.5, l1_ratio=l1_ratio, random_state=42,
475+
init=init)
466476
model = nmf.NMF(n_components=n_components, solver=solver,
467-
alpha=0., l1_ratio=l1_ratio, random_state=42)
477+
alpha=0., l1_ratio=l1_ratio, random_state=42,
478+
init=init)
468479

469480
W_regul = regul.fit_transform(X)
470481
W_model = model.fit_transform(X)
471482

472483
H_regul = regul.components_
473484
H_model = model.components_
474485

475-
assert W_model.mean() > W_regul.mean()
476-
assert H_model.mean() > H_regul.mean()
486+
assert (linalg.norm(W_model))**2. + (linalg.norm(H_model))**2. > \
487+
(linalg.norm(W_regul))**2. + (linalg.norm(H_regul))**2.
477488

478489

479490
@ignore_warnings(category=ConvergenceWarning)
@@ -541,7 +552,9 @@ def test_nmf_dtype_match(dtype_in, dtype_out, solver, regularization):
541552
# Check that NMF preserves dtype (float32 and float64)
542553
X = np.random.RandomState(0).randn(20, 15).astype(dtype_in, copy=False)
543554
np.abs(X, out=X)
544-
nmf = NMF(solver=solver, regularization=regularization)
555+
# FIXME : should be removed in 0.26
556+
init = 'nndsvda'
557+
nmf = NMF(solver=solver, regularization=regularization, init=init)
545558

546559
assert nmf.fit(X).transform(X).dtype == dtype_out
547560
assert nmf.fit_transform(X).dtype == dtype_out
@@ -555,9 +568,13 @@ def test_nmf_float32_float64_consistency(solver, regularization):
555568
# Check that the result of NMF is the same between float32 and float64
556569
X = np.random.RandomState(0).randn(50, 7)
557570
np.abs(X, out=X)
558-
nmf32 = NMF(solver=solver, regularization=regularization, random_state=0)
571+
# FIXME : should be removed in 0.26
572+
init = 'nndsvda'
573+
nmf32 = NMF(solver=solver, regularization=regularization, random_state=0,
574+
init=init)
559575
W32 = nmf32.fit_transform(X.astype(np.float32))
560-
nmf64 = NMF(solver=solver, regularization=regularization, random_state=0)
576+
nmf64 = NMF(solver=solver, regularization=regularization, random_state=0,
577+
init=init)
561578
W64 = nmf64.fit_transform(X)
562579

563580
assert_allclose(W32, W64, rtol=1e-6, atol=1e-5)
@@ -576,3 +593,20 @@ def test_nmf_custom_init_dtype_error():
576593

577594
with pytest.raises(TypeError, match="should have the same dtype as X"):
578595
non_negative_factorization(X, H=H, update_H=False)
596+
597+
598+
# FIXME : should be removed in 0.26
599+
def test_init_default_deprecation():
600+
# Test FutureWarning on init default
601+
msg = ("The 'init' value, when 'init=None' and "
602+
"n_components is less than n_samples and "
603+
"n_features, will be changed from 'nndsvd' to "
604+
"'nndsvda' in 0.26.")
605+
rng = np.random.mtrand.RandomState(42)
606+
A = np.abs(rng.randn(6, 5))
607+
with pytest.warns(FutureWarning, match=msg):
608+
nmf._initialize_nmf(A, 3)
609+
with pytest.warns(FutureWarning, match=msg):
610+
NMF().fit(A)
611+
with pytest.warns(FutureWarning, match=msg):
612+
non_negative_factorization(A)

sklearn/tests/test_common.py

+6
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ class MyNMFWithBadErrorMessage(NMF):
221221
# Same as NMF but raises an uninformative error message if X has negative
222222
# value. This estimator would fail the check suite in strict mode,
223223
# specifically it would fail check_fit_non_negative
224+
# FIXME : should be removed in 0.26
225+
def __init__(self):
226+
super().__init__()
227+
self.init = 'nndsvda'
228+
self.max_iter = 500
229+
224230
def fit(self, X, y=None, **params):
225231
X = check_array(X, accept_sparse=('csr', 'csc'),
226232
dtype=[np.float64, np.float32])

sklearn/tests/test_docstring_parameters.py

+4
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ def test_fit_docstring_attributes(name, Estimator):
217217
if Estimator.__name__ == 'AffinityPropagation':
218218
est.random_state = 63
219219

220+
# TO BE REMOVED for v0.26 (avoid FutureWarning)
221+
if Estimator.__name__ == 'NMF':
222+
est.init = 'nndsvda'
223+
220224
X, y = make_classification(n_samples=20, n_features=3,
221225
n_redundant=0, n_classes=2,
222226
random_state=2)

sklearn/utils/estimator_checks.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ def _set_checking_parameters(estimator):
608608
estimator.set_params(max_iter=20)
609609
# NMF
610610
if estimator.__class__.__name__ == 'NMF':
611-
estimator.set_params(max_iter=100)
611+
# FIXME : init should be removed in 0.26
612+
estimator.set_params(max_iter=500, init='nndsvda')
612613
# MLP
613614
if estimator.__class__.__name__ in ['MLPClassifier', 'MLPRegressor']:
614615
estimator.set_params(max_iter=100)

0 commit comments

Comments
 (0)
0