|
24 | 24 | from sklearn.utils.extmath import weighted_mode
|
25 | 25 | from sklearn.utils.extmath import cartesian
|
26 | 26 | from sklearn.utils.extmath import logistic_sigmoid
|
27 |
| -from sklearn.utils.extmath import fast_dot |
| 27 | +from sklearn.utils.extmath import fast_dot, _fast_dot |
28 | 28 | from sklearn.utils.validation import NonBLASDotWarning
|
29 | 29 | from sklearn.datasets.samples_generator import make_low_rank_matrix
|
30 | 30 |
|
@@ -316,66 +316,63 @@ def test_fast_dot():
|
316 | 316 | has_blas = False
|
317 | 317 |
|
318 | 318 | if has_blas:
|
319 |
| - # test dispatch to np.dot |
320 |
| - with warnings.catch_warnings(record=True) as w: |
321 |
| - warnings.simplefilter('always', NonBLASDotWarning) |
322 |
| - # maltyped data |
| 319 | + # Test _fast_dot for invalid input. |
| 320 | + |
| 321 | + # Maltyped data. |
323 | 322 | for dt1, dt2 in [['f8', 'f4'], ['i4', 'i4']]:
|
324 |
| - fast_dot(A.astype(dt1), B.astype(dt2).T) |
325 |
| - assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning)) |
326 |
| - # malformed data |
327 |
| - # ndim == 0 |
| 323 | + assert_raises(ValueError, _fast_dot, A.astype(dt1), |
| 324 | + B.astype(dt2).T) |
| 325 | + |
| 326 | + # Malformed data. |
| 327 | + |
| 328 | + ## ndim == 0 |
328 | 329 | E = np.empty(0)
|
329 |
| - fast_dot(E, E) |
330 |
| - assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning)) |
| 330 | + assert_raises(ValueError, _fast_dot, E, E) |
| 331 | + |
331 | 332 | ## ndim == 1
|
332 |
| - fast_dot(A, A[0]) |
333 |
| - assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning)) |
| 333 | + assert_raises(ValueError, _fast_dot, A, A[0]) |
| 334 | + |
334 | 335 | ## ndim > 2
|
335 |
| - fast_dot(A.T, np.array([A, A])) |
336 |
| - assert_true(isinstance(w.pop(-1).message, NonBLASDotWarning)) |
| 336 | + assert_raises(ValueError, _fast_dot, A.T, np.array([A, A])) |
| 337 | + |
337 | 338 | ## min(shape) == 1
|
338 |
| - assert_raises(ValueError, fast_dot, A, A[0, :][None, :]) |
339 |
| - # test for matrix mismatch error |
340 |
| - msg = ('Invalid array shapes: A.shape[%d] should be the same as ' |
341 |
| - 'B.shape[0]. Got A.shape=%r B.shape=%r' % (A.ndim - 1, |
342 |
| - A.shape, |
343 |
| - A.shape)) |
344 |
| - assert_raise_message(ValueError, msg, fast_dot, A, A) |
345 |
| - |
346 |
| - # test cov-like use case + dtypes |
347 |
| - my_assert = assert_array_almost_equal |
| 339 | + assert_raises(ValueError, _fast_dot, A, A[0, :][None, :]) |
| 340 | + |
| 341 | + # test for matrix mismatch error |
| 342 | + assert_raises(ValueError, _fast_dot, A, A) |
| 343 | + |
| 344 | + # Test cov-like use case + dtypes. |
348 | 345 | for dtype in ['f8', 'f4']:
|
349 | 346 | A = A.astype(dtype)
|
350 | 347 | B = B.astype(dtype)
|
351 | 348 |
|
352 | 349 | # col < row
|
353 | 350 | C = np.dot(A.T, A)
|
354 | 351 | C_ = fast_dot(A.T, A)
|
355 |
| - my_assert(C, C_) |
| 352 | + assert_almost_equal(C, C_) |
356 | 353 |
|
357 | 354 | C = np.dot(A.T, B)
|
358 | 355 | C_ = fast_dot(A.T, B)
|
359 |
| - my_assert(C, C_) |
| 356 | + assert_almost_equal(C, C_) |
360 | 357 |
|
361 | 358 | C = np.dot(A, B.T)
|
362 | 359 | C_ = fast_dot(A, B.T)
|
363 |
| - my_assert(C, C_) |
| 360 | + assert_almost_equal(C, C_) |
364 | 361 |
|
365 |
| - # test square matrix * rectangular use case |
| 362 | + # Test square matrix * rectangular use case. |
366 | 363 | A = rng.random_sample([2, 2])
|
367 | 364 | for dtype in ['f8', 'f4']:
|
368 | 365 | A = A.astype(dtype)
|
369 | 366 | B = B.astype(dtype)
|
370 | 367 |
|
371 | 368 | C = np.dot(A, B)
|
372 | 369 | C_ = fast_dot(A, B)
|
373 |
| - my_assert(C, C_) |
| 370 | + assert_almost_equal(C, C_) |
374 | 371 |
|
375 | 372 | C = np.dot(A.T, B)
|
376 | 373 | C_ = fast_dot(A.T, B)
|
377 |
| - my_assert(C, C_) |
| 374 | + assert_almost_equal(C, C_) |
378 | 375 |
|
379 | 376 | if has_blas:
|
380 | 377 | for x in [np.array([[d] * 10] * 2) for d in [np.inf, np.nan]]:
|
381 |
| - assert_raises(ValueError, fast_dot, x, x.T) |
| 378 | + assert_raises(ValueError, _fast_dot, x, x.T) |
0 commit comments