8000 [MRG+1] Parameter check in NMF and IsolationForest (#6814) · scikit-learn/scikit-learn@4854454 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4854454

Browse files
nadya-pogrisel
authored andcommitted
[MRG+1] Parameter check in NMF and IsolationForest (#6814)
* FIX #6793 parameter check in NMF and IsolationForest * replaced six.integer_types with numbers.Integral * fix test_nmf * added an int64 test for non_negative_factorization * cosmetic change * dummy commit for triggering tests in github * added a unit test for iforest
1 parent df80c4c commit 4854454

File tree

4 files changed

+15
-12
lines changed

4 files changed

+15
-12
lines changed

sklearn/decomposition/nmf.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import numpy as np
2020
import scipy.sparse as sp
2121

22-
from ..externals import six
2322
from ..base import BaseEstimator, TransformerMixin
2423
from ..utils import check_random_state, check_array
2524
from ..utils.extmath import randomized_svd, safe_sparse_dot, squared_norm
@@ -747,11 +746,11 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
747746
if n_components is None:
748747
n_components = n_features
749748

750-
if not isinstance(n_components, six.integer_types) or n_components <= 0:
751-
raise ValueError("Number of components must be positive;"
749+
if not isinstance(n_components, numbers.Integral) or n_components <= 0:
750+
raise ValueError("Number of components must be a positive integer;"
752751
" got (n_components=%r)" % n_components)
753-
if not isinstance(max_iter, numbers.Number) or max_iter < 0:
754-
raise ValueError("Maximum number of iteration must be positive;"
752+
if not isinstance(max_iter, numbers.Integral) or max_iter < 0:
753+
raise ValueError("Maximum number of iterations must be a positive integer;"
755754
" got (max_iter=%r)" % max_iter)
756755
if not isinstance(tol, numbers.Number) or tol < 0:
757756
raise ValueError("Tolerance for stopping criteria must be "

sklearn/decomposition/tests/test_nmf.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from sklearn.utils.testing import assert_true
99
from sklearn.utils.testing import assert_false
10-
from sklearn.utils.testing import assert_raise_message
10+
from sklearn.utils.testing import assert_raise_message, assert_no_warnings
1111
from sklearn.utils.testing import assert_array_almost_equal
1212
from sklearn.utils.testing import assert_almost_equal
1313
from sklearn.utils.testing import assert_greater
@@ -133,7 +133,6 @@ def test_nmf_transform_custom_init():
133133
t = m.transform(A)
134134

135135

136-
137136
@ignore_warnings
138137
def test_nmf_inverse_transform():
139138
# Test that NMF.inverse_transform returns close values
@@ -235,7 +234,10 @@ def test_non_negative_factorization_checking():
235234
A = np.ones((2, 2))
236235
# Test parameters checking is public function
237236
nnmf = non_negative_factorization
238-
msg = "Number of components must be positive; got (n_components='2')"
237+
assert_no_warnings(nnmf, A, A, A, np.int64(1))
238+
msg = "Number of components must be a positive integer; got (n_components=1.5)"
239+
assert_raise_message(ValueError, msg, nnmf, A, A, A, 1.5)
240+
msg = "Number of components must be a positive integer; got (n_components='2')"
239241
assert_raise_message(ValueError, msg, nnmf, A, A, A, '2')
240242
msg = "Negative values in data passed to NMF (input H)"
241243
assert_raise_message(ValueError, msg, nnmf, A, A, -A, 2, 'custom')

sklearn/ensemble/iforest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from scipy.sparse import issparse
1212

13+
import numbers
1314
from ..externals import six
1415
from ..tree import ExtraTreeRegressor
1516
from ..utils import check_random_state, check_array
@@ -167,7 +168,7 @@ def fit(self, X, y=None, sample_weight=None):
167168
'Valid choices are: "auto", int or'
168169
'float' % self.max_samples)
169170

170-
elif isinstance(self.max_samples, six.integer_types):
171+
elif isinstance(self.max_samples, numbers.Integral):
171172
if self.max_samples > n_samples:
172173
warn("max_samples (%s) is greater than the "
173174
"total number of samples (%s). max_samples "
@@ -277,7 +278,7 @@ def _average_path_length(n_samples_leaf):
277278
average_path_length : array, same shape as n_samples_leaf
278279
279280
"""
280-
if isinstance(n_samples_leaf, six.integer_types):
281+
if isinstance(n_samples_leaf, numbers.Integral):
281282
if n_samples_leaf <= 1:
282283
return 1.
283284
else:

sklearn/ensemble/tests/test_iforest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ def test_iforest_error():
104104
"max_samples will be set to n_samples for estimation",
105105
IsolationForest(max_samples=1000).fit, X)
106106
assert_no_warnings(IsolationForest(max_samples='auto').fit, X)
107-
assert_raises(ValueError,
108-
IsolationForest(max_samples='foobar').fit, X)
107+
assert_no_warnings(IsolationForest(max_samples=np.int64(2)).fit, X)
108+
assert_raises(ValueError, IsolationForest(max_samples='foobar').fit, X)
109+
assert_raises(ValueError, IsolationForest(max_samples=1.5).fit, X)
109110

110111

111112
def test_recalculate_max_depth():

0 commit comments

Comments
 (0)
0