|
1 | 1 | """Tests for input validation functions"""
|
2 | 2 |
|
3 | 3 | from tempfile import NamedTemporaryFile
|
| 4 | +from itertools import product |
| 5 | + |
4 | 6 | import numpy as np
|
5 | 7 | from numpy.testing import assert_array_equal, assert_warns
|
6 | 8 | import scipy.sparse as sp
|
7 | 9 | from nose.tools import assert_raises, assert_true, assert_false, assert_equal
|
8 |
| -from itertools import product |
9 | 10 |
|
| 11 | +from sklearn.utils.testing import assert_raises_regexp |
10 | 12 | from sklearn.utils import as_float_array, check_array, check_symmetric
|
11 | 13 | from sklearn.utils import check_X_y
|
12 |
| - |
13 | 14 | from sklearn.utils.estimator_checks import NotAnArray
|
14 |
| - |
15 | 15 | from sklearn.random_projection import sparse_random_matrix
|
16 |
| - |
17 | 16 | from sklearn.linear_model import ARDRegression
|
18 | 17 | from sklearn.neighbors import KNeighborsClassifier
|
19 | 18 | from sklearn.ensemble import RandomForestRegressor
|
20 | 19 | from sklearn.svm import SVR
|
21 |
| - |
22 | 20 | from sklearn.datasets import make_blobs
|
23 | 21 | from sklearn.utils.validation import (
|
24 | 22 | NotFittedError,
|
25 | 23 | has_fit_parameter,
|
26 |
| - check_is_fitted) |
| 24 | + check_is_fitted, |
| 25 | + check_consistent_length) |
27 | 26 |
|
28 | 27 | from sklearn.utils.testing import assert_raise_message
|
29 | 28 |
|
@@ -337,3 +336,21 @@ def test_check_is_fitted():
|
337 | 336 |
|
338 | 337 | assert_equal(None, check_is_fitted(ard, "coef_"))
|
339 | 338 | assert_equal(None, check_is_fitted(svr, "support_"))
|
| 339 | + |
| 340 | + |
def test_check_consistent_length():
    """Exercise check_consistent_length on matching and mismatching inputs.

    Inputs with the same number of samples must be accepted without error;
    inputs with different sample counts are expected to raise ValueError,
    and the scalar, plain-object, 0-d array and estimator inputs below are
    expected to raise TypeError.
    """
    # Same number of samples across lists, tuples, arrays and a sparse
    # matrix: each call must return without raising.
    check_consistent_length([1], [2], [3], [4], [5])
    check_consistent_length([[1, 2], [[1, 2]]], [1, 2], ['a', 'b'])
    check_consistent_length([1], (2,), np.array([3]), sp.csr_matrix((1, 2)))

    # Two samples versus one sample.
    assert_raises_regexp(ValueError, 'inconsistent numbers of samples',
                         check_consistent_length, [1, 2], [1])

    # The patterns are raw strings so that ``\w`` reaches the regex engine
    # as a character class instead of being an invalid string escape.
    # ``\w+`` matches either "type" or "class" in the repr of the type.
    assert_raises_regexp(TypeError, r"got <\w+ 'int'>",
                         check_consistent_length, [1, 2], 1)
    assert_raises_regexp(TypeError, r"got <\w+ 'object'>",
                         check_consistent_length, [1, 2], object())

    # A 0-d array is expected to be rejected as well.
    assert_raises(TypeError, check_consistent_length, [1, 2], np.array(1))
    # Despite ensembles having __len__ they must raise TypeError
    assert_raises_regexp(TypeError, 'estimator', check_consistent_length,
                         [1, 2], RandomForestRegressor())
    # XXX: We should have a test with a string, but what is correct behaviour?
0 commit comments