r · scikit-learn/scikit-learn@b20133c · GitHub
[go: up one dir, main page]

Skip to content

Commit b20133c

Browse files
committed
r
1 parent 52f06ee commit b20133c

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

sklearn/decomposition/nmf.py

+16 -4
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
from math import sqrt
1414
import warnings
15+
import numbers
1516

1617
import numpy as np
1718
import scipy.sparse as sp
@@ -60,10 +61,11 @@ def check_non_negative(X, whom):
6061
def _check_init(A, shape, whom):
6162
A = check_array(A)
6263
if np.shape(A) != shape:
63-
raise ValueError('TODO')
64+
raise ValueError('Array with wrong shape passed to %s. Expected %s, '
65+
'but got %s ' % (whom, shape, np.shape(A)))
6466
check_non_negative(A, whom)
6567
if np.max(A) == 0:
66-
raise ValueError('TODO')
68+
raise ValueError('Array passed to %s is full of zeros.' % whom)
6769

6870

6971
def _safe_compute_error(X, W, H):
@@ -227,7 +229,8 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
227229
else:
228230
raise ValueError(
229231
'Invalid init parameter: got %r instead of one of %r' %
230-
(init, (None, 'nndsvd', 'nndsvda', 'nndsvdar', 'random')))
232+
(init, (None, 'random', 'nndsvd', 'nndsvda', 'nndsvdar',
233+
'uniform')))
231234

232235
return W, H
233236

@@ -680,6 +683,16 @@ def non_negative_matrix_factorization(X, W=None, H=None, n_components=None,
680683
if n_components is None:
681684
n_components = n_features
682685

686+
if not isinstance(n_components, int) or n_components <= 0:
687+
raise ValueError("Number of components must be positive;"
688+
" got (n_components=%r)" % n_components)
689+
if not isinstance(max_iter, numbers.Number) or max_iter < 0:
690+
raise ValueError("Maximum number of iteration must be positive;"
691+
" got (max_iter=%r)" % max_iter)
692+
if not isinstance(tol, numbers.Number) or tol < 0:
693+
raise ValueError("Tolerance for stopping criteria must be "
694+
"positive; got (tol=%r)" % tol)
695+
683696
# check W and H, or initialize them
684697
if init == 'custom':
685698
_check_init(H, (n_components, n_features), "NMF (input H)")
@@ -920,7 +933,6 @@ def transform(self, X):
920933
Transformed data
921934
"""
922935
check_is_fitted(self, 'n_components_')
923-
X = check_array(X, accept_sparse='csc')
924936

925937
W, _, _ = non_negative_matrix_factorization(
926938
X=X, W=None, H=self.components_, n_components=self.n_components_,

sklearn/decomposition/tests/test_nmf.py

+32
Original file line number | Diff line number | Diff line change
@@ -177,3 +177,35 @@ def test_sparse_transform():
177177
A_fit_tr = model.fit_transform(A)
178178
A_tr = model.transform(A)
179179
assert_array_almost_equal(A_fit_tr, A_tr, decimal=2)
180+
181+
182+
def test_non_negative_matrix_factorization_path():
183+
# Test path consistency between the class and the public function
184+
A = np.abs(random_state.randn(10, 10))
185+
A[:, 2 * np.arange(5)] = 0
186+
187+
for solver in ('proj-grad', 'coordinate'):
188+
W_nmf, H, _ = nmf.non_negative_matrix_factorization(
189+
A, solver=solver, random_state=1, tol=1e-2)
190+
W_nmf_bis, H, _ = nmf.non_negative_matrix_factorization(
191+
A, H=H, update_H=False, solver=solver, random_state=1, tol=1e-2)
192+
193+
model_class = nmf.NMF(solver=solver, random_state=1, tol=1e-2)
194+
W_cls = model_class.fit_transform(A)
195+
W_cls_bis = model_class.transform(A)
196+
assert_array_almost_equal(W_nmf, W_cls, decimal=10)
197+
assert_array_almost_equal(W_nmf_bis, W_cls_bis, decimal=10)
198+
199+
200+
def test_non_negative_matrix_factorization_checking():
201+
A = np.ones((2, 2))
202+
# Test parameters checking is public function
203+
nnmf = nmf.non_negative_matrix_factorization
204+
msg = "Number of components must be positive; got (n_components='2')"
205+
assert_raise_message(ValueError, msg, nnmf, A, A, A, '2')
206+
msg = "Negative values in data passed to NMF (input H)"
207+
assert_raise_message(ValueError, msg, nnmf, A, A, -A, 2, 'custom')
208+
msg = "Negative values in data passed to NMF (input W)"
209+
assert_raise_message(ValueError, msg, nnmf, A, -A, A, 2, 'custom')
210+
msg = "Array passed to NMF (input H) is full of zeros"
211+
assert_raise_message(ValueError, msg, nnmf, A, A, 0 * A, 2, 'custom')

0 commit comments

Comments
 (0)
0