|
6 | 6 | # Alexandre Gramfort
|
7 | 7 | # Nicolas Tresegnie
|
8 | 8 | # License: BSD 3 clause
|
9 |
| - |
10 | 9 | import warnings
|
11 | 10 | import numbers
|
12 | 11 |
|
@@ -119,6 +118,40 @@ def _num_samples(x):
|
119 | 118 | return len(x)
|
120 | 119 |
|
121 | 120 |
|
| 121 | +def _shape_repr(shape): |
| 122 | + """Return a platform independent reprensentation of an array shape |
| 123 | +
|
| 124 | + Under Python 2, the `long` type introduces an 'L' suffix when using the |
| 125 | + default %r format for tuples of integers (typically used to store the shape |
| 126 | + of an array). |
| 127 | +
|
| 128 | + Under Windows 64 bit (and Python 2), the `long` type is used by default |
| 129 | + in numpy shapes even when the integer dimensions are well below 32 bit. |
| 130 | + The platform specific type causes string messages or doctests to change |
| 131 | + from one platform to another which is not desirable. |
| 132 | +
|
| 133 | + Under Python 3, there is no more `long` type so the `L` suffix is never |
| 134 | + introduced in string representation. |
| 135 | +
|
| 136 | + >>> _shape_repr((1, 2)) |
| 137 | + '(1, 2)' |
| 138 | + >>> one = 2 ** 64 / 2 ** 64 # force an upcast to `long` under Python 2 |
| 139 | + >>> _shape_repr((one, 2 * one)) |
| 140 | + '(1, 2)' |
| 141 | + >>> _shape_repr((1,)) |
| 142 | + '(1,)' |
| 143 | + >>> _shape_repr(()) |
| 144 | + '()' |
| 145 | + """ |
| 146 | + if len(shape) == 0: |
| 147 | + return "()" |
| 148 | + joined = ", ".join("%d" % e for e in shape) |
| 149 | + if len(shape) == 1: |
| 150 | + # special notation for singleton tuples |
| 151 | + joined += ',' |
| 152 | + return "(%s)" % joined |
| 153 | + |
| 154 | + |
122 | 155 | def check_consistent_length(*arrays):
|
123 | 156 | """Check that all arrays have consistent first dimensions.
|
124 | 157 |
|
@@ -295,19 +328,20 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
|
295 | 328 | if force_all_finite:
|
296 | 329 | _assert_all_finite(array)
|
297 | 330 |
|
| 331 | + shape_repr = _shape_repr(array.shape) |
298 | 332 | if ensure_min_samples > 0:
|
299 | 333 | n_samples = _num_samples(array)
|
300 | 334 | if n_samples < ensure_min_samples:
|
301 |
| - raise ValueError("Found array with %d sample(s) (shape=%r) while a" |
| 335 | + raise ValueError("Found array with %d sample(s) (shape=%s) while a" |
302 | 336 | " minimum of %d is required."
|
303 |
| - % (n_samples, array.shape, ensure_min_samples)) |
| 337 | + % (n_samples, shape_repr, ensure_min_samples)) |
304 | 338 |
|
305 | 339 | if ensure_min_features > 0 and ensure_2d and not allow_nd:
|
306 | 340 | n_features = array.shape[1]
|
307 | 341 | if n_features < ensure_min_features:
|
308 |
| - raise ValueError("Found array with %d feature(s) (shape=%r) while" |
| 342 | + raise ValueError("Found array with %d feature(s) (shape=%s) while" |
309 | 343 | " a minimum of %d is required."
|
310 |
| - % (n_features, array.shape, ensure_min_features)) |
| 344 | + % (n_features, shape_repr, ensure_min_features)) |
311 | 345 | return array
|
312 | 346 |
|
313 | 347 |
|
@@ -520,7 +554,7 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
|
520 | 554 | def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
|
521 | 555 | """Perform is_fitted validation for estimator.
|
522 | 556 |
|
523 |
| - Checks if the estimator is fitted by verifying the presence of |
| 557 | + Checks if the estimator is fitted by verifying the presence of |
524 | 558 | "all_or_any" of the passed attributes and raises a NotFittedError with the
|
525 | 559 | given message.
|
526 | 560 |
|
|
0 commit comments