More robust unit tests for fast_dot. · r2k0/scikit-learn@9162dd5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9162dd5

Browse files
committed
More robust unit tests for fast_dot.
1 parent c0e686b commit 9162dd5

File tree

2 files changed

+49
-51
lines changed

2 files changed

+49
-51
lines changed

sklearn/utils/extmath.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -127,42 +127,43 @@ def _fast_dot(A, B):
127127
>> warnings.simplefilter('always', NonBLASDotWarning)
128128
"""
129129
if B.shape[0] != A.shape[A.ndim - 1]: # check adopted from '_dotblas.c'
130-
msg = ('Invalid array shapes: A.shape[%d] should be the same as '
131-
'B.shape[0]. Got A.shape=%r B.shape=%r' % (A.ndim - 1,
132-
A.shape,
133-
B.shape))
134-
raise ValueError(msg)
130+
raise ValueError
135131

136132
if A.dtype != B.dtype or any(x.dtype not in (np.float32, np.float64)
137133
for x in [A, B]):
138134
warnings.warn('Data must be of same type. Supported types '
139135
'are 32 and 64 bit float. '
140136
'Falling back to np.dot.', NonBLASDotWarning)
141-
return np.dot(A, B)
137+
raise ValueError
142138

143139
if min(A.shape) == 1 or min(B.shape) == 1 or A.ndim != 2 or B.ndim != 2:
144140
warnings.warn('Data must be 2D with more than one colum / row.'
145141
'Falling back to np.dot', NonBLASDotWarning)
146-
return np.dot(A, B)
142+
raise ValueError
147143

148144
# scipy 0.9 compliant API
149145
dot = linalg.get_blas_funcs(['gemm'], (A, B))[0]
150146
A, trans_a = _impose_f_order(A)
151147
B, trans_b = _impose_f_order(B)
152148
return dot(alpha=1.0, a=A, b=B, trans_a=trans_a, trans_b=trans_b)
153149

154-
# only try to use fast_dot for older numpy versions.
155-
# the related issue has been tackled meanwhile. Also, depending on the build
156-
# the current numpy master's dot can about 3 times faster.
157-
if LooseVersion(np.__version__) < '1.7.2': # backported
158-
try:
159-
linalg.get_blas_funcs(['gemm'])
160-
fast_dot = _fast_dot
161-
except (AttributeError, ValueError):
162-
fast_dot = np.dot
163-
warnings.warn('Could not import BLAS, falling back to np.dot')
164-
else:
165-
fast_dot = np.dot
150+
151+
def fast_dot(A, B):
152+
# only try to use fast_dot for older numpy versions.
153+
# the related issue has been tackled meanwhile. Also, depending on the build
154+
# the current numpy master's dot can be about 3 times faster.
155+
if LooseVersion(np.__version__) < '1.7.2': # backported
156+
try:
157+
linalg.get_blas_funcs(['gemm'])
158+
try:
159+
return _fast_dot(A, B)
160+
except ValueError:
161+
return np.dot(A, B)
162+
except (AttributeError, ValueError):
163+
warnings.warn('Could not import BLAS, falling back to np.dot')
164+
return np.dot(A, B)
165+
else:
166+
return np.dot(A, B)
166167

167168

168169
def density(w, **kwargs):

sklearn/utils/tests/test_extmath.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sklearn.utils.extmath import weighted_mode
2525
from sklearn.utils.extmath import cartesian
2626
from sklearn.utils.extmath import logistic_sigmoid
27-
from sklearn.utils.extmath import fast_dot
27+
from sklearn.utils.extmath import fast_dot, _fast_dot
2828
from sklearn.utils.validation import NonBLASDotWarning
2929
from sklearn.datasets.samples_generator import make_low_rank_matrix
3030

@@ -316,66 +316,63 @@ def test_fast_dot():
316316
has_blas = False
317317

318318
if has_blas:
319-
# test dispatch to np.dot
320-
with warnings.catch_warnings(record=True) as w:
321-
warnings.simplefilter('always', NonBLASDotWarning)
322-
# maltyped data
319+
# Test _fast_dot for invalid input.
320+
321+
# Maltyped data.
323322
for dt1, dt2 in [['f8', 'f4'], ['i4', 'i4']]:
324-
fast_dot(A.astype(dt1), B.astype(dt2).T)
325-
assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning))
326-
# malformed data
327-
# ndim == 0
323+
assert_raises(ValueError, _fast_dot, A.astype(dt1),
324+
B.astype(dt2).T)
325+
326+
# Malformed data.
327+
328+
## empty input (np.empty(0) is 1-D with zero elements)
328329
E = np.empty(0)
329-
fast_dot(E, E)
330-
assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning))
330+
assert_raises(ValueError, _fast_dot, E, E)
331+
331332
## ndim == 1
332-
fast_dot(A, A[0])
333-
assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning))
333+
assert_raises(ValueError, _fast_dot, A, A[0])
334+
334335
## ndim > 2
335-
fast_dot(A.T, np.array([A, A]))
336-
assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning))
336+
assert_raises(ValueError, _fast_dot, A.T, np.array([A, A]))
337+
337338
## min(shape) == 1
338-
assert_raises(ValueError, fast_dot, A, A[0, :][None, :])
339-
# test for matrix mismatch error
340-
msg = ('Invalid array shapes: A.shape[%d] should be the same as '
341-
'B.shape[0]. Got A.shape=%r B.shape=%r' % (A.ndim - 1,
342-
A.shape,
343-
A.shape))
344-
assert_raise_message(ValueError, msg, fast_dot, A, A)
345-
346-
# test cov-like use case + dtypes
347-
my_assert = assert_array_almost_equal
339+
assert_raises(ValueError, _fast_dot, A, A[0, :][None, :])
340+
341+
# test for matrix mismatch error
342+
assert_raises(ValueError, _fast_dot, A, A)
343+
344+
# Test cov-like use case + dtypes.
348345
for dtype in ['f8', 'f4']:
349346
A = A.astype(dtype)
350347
B = B.astype(dtype)
351348

352349
# col < row
353350
C = np.dot(A.T, A)
354351
C_ = fast_dot(A.T, A)
355-
my_assert(C, C_)
352+
assert_almost_equal(C, C_)
356353

357354
C = np.dot(A.T, B)
358355
C_ = fast_dot(A.T, B)
359-
my_assert(C, C_)
356+
assert_almost_equal(C, C_)
360357

361358
C = np.dot(A, B.T)
362359
C_ = fast_dot(A, B.T)
363-
my_assert(C, C_)
360+
assert_almost_equal(C, C_)
364361

365-
# test square matrix * rectangular use case
362+
# Test square matrix * rectangular use case.
366363
A = rng.random_sample([2, 2])
367364
for dtype in ['f8', 'f4']:
368365
A = A.astype(dtype)
369366
B = B.astype(dtype)
370367

371368
C = np.dot(A, B)
372369
C_ = fast_dot(A, B)
373-
my_assert(C, C_)
370+
assert_almost_equal(C, C_)
374371

375372
C = np.dot(A.T, B)
376373
C_ = fast_dot(A.T, B)
377-
my_assert(C, C_)
374+
assert_almost_equal(C, C_)
378375

379376
if has_blas:
380377
for x in [np.array([[d] * 10] * 2) for d in [np.inf, np.nan]]:
381-
assert_raises(ValueError, fast_dot, x, x.T)
378+
assert_raises(ValueError, _fast_dot, x, x.T)

0 commit comments

Comments
 (0)
0