10000 Merge pull request #4251 from ogrisel/fix-shape-format · scikit-learn/scikit-learn@4f3613f · GitHub
[go: up one dir, main page]

Skip to content

Commit 4f3613f

Browse files
committed
Merge pull request #4251 from ogrisel/fix-shape-format
[MRG] FIX broken test under 64 bit Python 2 / Windows
2 parents d90a8ae + 0373cc0 commit 4f3613f

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

sklearn/utils/validation.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# Alexandre Gramfort
77
# Nicolas Tresegnie
88
# License: BSD 3 clause
9-
109
import warnings
1110
import numbers
1211

@@ -119,6 +118,40 @@ def _num_samples(x):
119118
return len(x)
120119

121120

121+
def _shape_repr(shape):
122+
"""Return a platform independent reprensentation of an array shape
123+
124+
Under Python 2, the `long` type introduces an 'L' suffix when using the
125+
default %r format for tuples of integers (typically used to store the shape
126+
of an array).
127+
128+
Under Windows 64 bit (and Python 2), the `long` type is used by default
129+
in numpy shapes even when the integer dimensions are well below 32 bit.
130+
The platform specific type causes string messages or doctests to change
131+
from one platform to another which is not desirable.
132+
133+
Under Python 3, there is no more `long` type so the `L` suffix is never
134+
introduced in string representation.
135+
136+
>>> _shape_repr((1, 2))
137+
'(1, 2)'
138+
>>> one = 2 ** 64 / 2 ** 64 # force an upcast to `long` under Python 2
139+
>>> _shape_repr((one, 2 * one))
140+
'(1, 2)'
141+
>>> _shape_repr((1,))
142+
'(1,)'
143+
>>> _shape_repr(())
144+
'()'
145+
"""
146+
if len(shape) == 0:
147+
return "()"
148+
joined = ", ".join("%d" % e for e in shape)
149+
if len(shape) == 1:
150+
# special notation for singleton tuples
151+
joined += ','
152+
return "(%s)" % joined
153+
154+
122155
def check_consistent_length(*arrays):
123156
"""Check that all arrays have consistent first dimensions.
124157
@@ -295,19 +328,20 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
295328
if force_all_finite:
296329
_assert_all_finite(array)
297330

331+
shape_repr = _shape_repr(array.shape)
298332
if ensure_min_samples > 0:
299333
n_samples = _num_samples(array)
300334
if n_samples < ensure_min_samples:
301-
raise ValueError("Found array with %d sample(s) (shape=%r) while a"
335+
raise ValueError("Found array with %d sample(s) (shape=%s) while a"
302336
" minimum of %d is required."
303-
% (n_samples, array.shape, ensure_min_samples))
337+
% (n_samples, shape_repr, ensure_min_samples))
304338

305339
if ensure_min_features > 0 and ensure_2d and not allow_nd:
306340
n_features = array.shape[1]
307341
if n_features < ensure_min_features:
308-
raise ValueError("Found array with %d feature(s) (shape=%r) while"
342+
raise ValueError("Found array with %d feature(s) (shape=%s) while"
309343
" a minimum of %d is required."
310-
% (n_features, array.shape, ensure_min_features))
344+
% (n_features, shape_repr, ensure_min_features))
311345
return array
312346

313347

@@ -520,7 +554,7 @@ def check_symmetric(array, tol=1E-10, raise_warning=True,
520554
def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
521555
"""Perform is_fitted validation for estimator.
522556
523-
Checks if the estimator is fitted by verifying the presence of
557+
Checks if the estimator is fitted by verifying the presence of
524558
"all_or_any" of the passed attributes and raises a NotFittedError with the
525559
given message.
526560

0 commit comments

Comments
 (0)
0