8000 MAINT/TST: public export of non_negative_factorization · johannah/scikit-learn@3697d75 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3697d75

Browse files
committed
MAINT/TST: public export of non_negative_factorization
Changed imports in test to separate testing of API and of internals. See scikit-learn gh-5509.
1 parent ee0fd9a commit 3697d75

File tree

2 files changed

+29
-29
lines changed

2 files changed

+29
-29
lines changed

sklearn/decomposition/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
this module can be regarded as dimensionality reduction techniques.
55
"""
66

7-
from .nmf import NMF, ProjectedGradientNMF
7+
from .nmf import NMF, ProjectedGradientNMF, non_negative_factorization
88
from .pca import PCA, RandomizedPCA
99
from .incremental_pca import IncrementalPCA
1010
from .kernel_pca import KernelPCA
@@ -33,6 +33,7 @@
3333
'dict_learning',
3434
'dict_learning_online',
3535
'fastica',
36+
'non_negative_factorization',
3637
'randomized_svd',
3738
'sparse_encode',
3839
'FactorAnalysis',

sklearn/decomposition/tests/test_nmf.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
from scipy import linalg
3-
from sklearn.decomposition import nmf
3+
from sklearn.decomposition import (NMF, ProjectedGradientNMF,
4+
non_negative_factorization)
5+
from sklearn.decomposition import nmf # For testing internals
46
from scipy.sparse import csc_matrix
57

68
from sklearn.utils.testing import assert_true
@@ -30,17 +32,17 @@ def test_parameter_checking():
3032
A = np.ones((2, 2))
3133
name = 'spam'
3234
msg = "Invalid solver parameter: got 'spam' instead of one of"
33-
assert_raise_message(ValueError, msg, nmf.NMF(solver=name).fit, A)
35+
assert_raise_message(ValueError, msg, NMF(solver=name).fit, A)
3436
msg = "Invalid init parameter: got 'spam' instead of one of"
35-
assert_raise_message(ValueError, msg, nmf.NMF(init=name).fit, A)
37+
assert_raise_message(ValueError, msg, NMF(init=name).fit, A)
3638
msg = "Invalid sparseness parameter: got 'spam' instead of one of"
37-
assert_raise_message(ValueError, msg, nmf.NMF(sparseness=name).fit, A)
39+
assert_raise_message(ValueError, msg, NMF(sparseness=name).fit, A)
3840

3941
msg = "Negative values in data passed to"
40-
assert_raise_message(ValueError, msg, nmf.NMF().fit, -A)
42+
assert_raise_message(ValueError, msg, NMF().fit, -A)
4143
assert_raise_message(ValueError, msg, nmf._initialize_nmf, -A,
4244
2, 'nndsvd')
43-
clf = nmf.NMF(2, tol=0.1).fit(A)
45+
clf = NMF(2, tol=0.1).fit(A)
4446
assert_raise_message(ValueError, msg, clf.transform, -A)
4547

4648

@@ -76,8 +78,8 @@ def test_nmf_fit_nn_output():
7678
5 * np.ones(5) + np.arange(1, 6)]
7779
for solver in ('pg', 'cd'):
7880
for init in (None, 'nndsvd', 'nndsvda', 'nndsvdar'):
79-
model = nmf.NMF(n_components=2, solver=solver, init=init,
80-
random_state=0)
81+
model = NMF(n_components=2, solver=solver, init=init,
82+
random_state=0)
8183
transf = model.fit_transform(A)
8284
assert_false((model.components_ < 0).any() or
8385
(transf < 0).any())
@@ -87,7 +89,7 @@ def test_nmf_fit_nn_output():
8789
def test_nmf_fit_close():
8890
# Test that the fit is not too far away
8991
for solver in ('pg', 'cd'):
90-
pnmf = nmf.NMF(5, solver=solver, init='nndsvd', random_state=0)
92+
pnmf = NMF(5, solver=solver, init='nndsvd', random_state=0)
9193
X = np.abs(random_state.randn(6, 5))
9294
assert_less(pnmf.fit(X).reconstruction_err_, 0.05)
9395

@@ -112,8 +114,7 @@ def test_nmf_transform():
112114
# Test that NMF.transform returns close values
113115
A = np.abs(random_state.randn(6, 5))
114116
for solver in ('pg', 'cd'):
115-
m = nmf.NMF(solver=solver, n_components=4, init='nndsvd',
116-
random_state=0)
117+
m = NMF(solver=solver, n_components=4, init='nndsvd', random_state=0)
117118
ft = m.fit_transform(A)
118119
t = m.transform(A)
119120
assert_array_almost_equal(ft, t, decimal=2)
@@ -123,7 +124,7 @@ def test_nmf_transform():
123124
def test_n_components_greater_n_features():
124125
# Smoke test for the case of more components than features.
125126
A = np.abs(random_state.randn(30, 10))
126-
nmf.NMF(n_components=15, random_state=0, tol=1e-2).fit(A)
127+
NMF(n_components=15, random_state=0, tol=1e-2).fit(A)
127128

128129

129130
@ignore_warnings
@@ -133,14 +134,13 @@ def test_projgrad_nmf_sparseness():
133134
# part where they are applied.
134135
tol = 1e-2
135136
A = np.abs(random_state.randn(10, 10))
136-
m = nmf.ProjectedGradientNMF(n_components=5, random_state=0,
137-
tol=tol).fit(A)
138-
data_sp = nmf.ProjectedGradientNMF(n_components=5, sparseness='data',
139-
random_state=0,
140-
tol=tol).fit(A).data_sparseness_
141-
comp_sp = nmf.ProjectedGradientNMF(n_components=5, sparseness='components',
142-
random_state=0,
143-
tol=tol).fit(A).comp_sparseness_
137+
m = ProjectedGradientNMF(n_components=5, random_state=0, tol=tol).fit(A)
138+
data_sp = ProjectedGradientNMF(n_components=5, sparseness='data',
139+
random_state=0,
140+
tol=tol).fit(A).data_sparseness_
141+
comp_sp = ProjectedGradientNMF(n_components=5, sparseness='components',
142+
random_state=0,
143+
tol=tol).fit(A).comp_sparseness_
144144
assert_greater(data_sp, m.data_sparseness_)
145145
assert_greater(comp_sp, m.comp_sparseness_)
146146

@@ -155,8 +155,8 @@ def test_sparse_input():
155155
A_sparse = csc_matrix(A)
156156

157157
for solver in ('pg', 'cd'):
158-
est1 = nmf.NMF(solver=solver, n_components=5, init='random',
159-
random_state=0, tol=1e-2)
158+
est1 = NMF(solver=solver, n_components=5, init='random',
159+
random_state=0, tol=1e-2)
160160
est2 = clone(est1)
161161

162162
W1 = est1.fit_transform(A)
@@ -177,8 +177,7 @@ def test_sparse_transform():
177177
A = csc_matrix(A)
178178

179179
for solver in ('pg', 'cd'):
180-
model = nmf.NMF(solver=solver, random_state=0, tol=1e-4,
181-
n_components=2)
180+
model = NMF(solver=solver, random_state=0, tol=1e-4, n_components=2)
182181
A_fit_tr = model.fit_transform(A)
183182
A_tr = model.transform(A)
184183
assert_array_almost_equal(A_fit_tr, A_tr, decimal=1)
@@ -192,12 +191,12 @@ def test_non_negative_factorization_consistency():
192191
A[:, 2 * np.arange(5)] = 0
193192

194193
for solver in ('pg', 'cd'):
195-
W_nmf, H, _ = nmf.non_negative_factorization(
194+
W_nmf, H, _ = non_negative_factorization(
196195
A, solver=solver, random_state=1, tol=1e-2)
197-
W_nmf_2, _, _ = nmf.non_negative_factorization(
196+
W_nmf_2, _, _ = non_negative_factorization(
198197
A, H=H, update_H=False, solver=solver, random_state=1, tol=1e-2)
199198

200-
model_class = nmf.NMF(solver=solver, random_state=1, tol=1e-2)
199+
model_class = NMF(solver=solver, random_state=1, tol=1e-2)
201200
W_cls = model_class.fit_transform(A)
202201
W_cls_2 = model_class.transform(A)
203202
assert_array_almost_equal(W_nmf, W_cls, decimal=10)
@@ -208,7 +207,7 @@ def test_non_negative_factorization_consistency():
208207
def test_non_negative_factorization_checking():
209208
A = np.ones((2, 2))
210209
# Test parameters checking is public function
211-
nnmf = nmf.non_negative_factorization
210+
nnmf = non_negative_factorization
212211
msg = "Number of components must be positive; got (n_components='2')"
213212
assert_raise_message(ValueError, msg, nnmf, A, A, A, '2')
214213
msg = "Negative values in data passed to NMF (input H)"

0 commit comments

Comments (0)