78
78
"assert_less" , "assert_less_equal" , "assert_greater" ,
79
79
"assert_greater_equal" , "assert_same_model" ,
80
80
"assert_not_same_model" , "assert_fitted_attributes_almost_equal" ,
81
- "assert_approx_equal" ]
81
+ "assert_approx_equal" , "assert_safe_sparse_allclose" ]
82
82
83
83
84
84
try :
@@ -387,41 +387,72 @@ def __exit__(self, *exc_info):
387
387
assert_greater = _assert_greater
388
388
389
389
390
- def _sparse_dense_allclose (val1 , val2 , rtol = 1e-7 , atol = 0 ):
390
+ if hasattr (np .testing , 'assert_allclose' ):
391
+ assert_allclose = np .testing .assert_allclose
392
+ else :
393
+ assert_allclose = _assert_allclose
394
+
395
+
396
+ def assert_safe_sparse_allclose (val1 , val2 , rtol = 1e-7 , atol = 0 , msg = None ):
391
397
"""Check if two objects are close up to the preset tolerance.
392
398
393
399
The objects can be scalars, lists, tuples, ndarrays or sparse matrices.
394
400
"""
395
- if isinstance (val1 , (int , float )) and isinstance (val2 , (int , float )):
396
- return np .allclose (float (val1 ), float (val2 ), rtol , atol )
401
+ if msg is None :
402
+ msg = ("The val1,\n %s\n and val2,\n %s\n are not all close"
403
+ % (val1 , val2 ))
404
+
405
+ if isinstance (val1 , str ) and isinstance (val2 , str ):
406
+ assert_true (val1 == val2 , msg = msg )
397
407
398
- if type (val1 ) is not type (val2 ):
399
- return False
408
+ elif np . isscalar (val1 ) and np . isscalar (val2 ):
409
+ assert_allclose ( val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
400
410
401
- comparables = (float , list , tuple , np .ndarray , sp .spmatrix )
411
+ # To allow mixed formats for sparse matrices alone
412
+ elif type (val1 ) is not type (val2 ) and not (
413
+ sp .issparse (val1 ) and sp .issparse (val2 )):
414
+ assert False , msg
402
415
403
- if not (isinstance (val1 , comparables ) or isinstance ( val2 , comparables )):
404
- raise ValueError ("The objects, %s and %s , are neither scalar nor "
416
+ elif not (isinstance (val1 , ( list , tuple , np . ndarray , sp . spmatrix , dict ) )):
417
+ raise ValueError ("The objects,\n %s \n and \n %s \n , are neither scalar nor "
405
418
"array-like." % (val1 , val2 ))
406
419
407
- # list/tuple (or list/tuple of ndarrays/spmatrices)
408
- if isinstance (val1 , (tuple , list )):
420
+ # list/tuple/dict (of list/tuple/dict...) of ndarrays/spmatrices/scalars
421
+ elif isinstance (val1 , (tuple , list , dict )):
422
+ if isinstance (val1 , dict ):
423
+ val1 , val2 = tuple (val1 .iteritems ()), tuple (val2 .iteritems ())
409
424
if (len (val1 ) == 0 ) and (len (val2 ) == 0 ):
410
- return True
411
- if len (val1 ) != len (val2 ):
412
- return False
413
- while isinstance (val1 [0 ], (tuple , list , np .ndarray , sp .spmatrix )):
414
- return all (_sparse_dense_allclose (val1_i , val2 [i ], rtol , atol )
415
- for i , val1_i in enumerate (val1 ))
416
- # Compare the lists, if they are not nested or singleton
417
- return np .allclose (val1 , val2 , rtol , atol )
418
-
419
- same_shape = val1 .shape == val2 .shape
420
- if sp .issparse (val1 ) or sp .issparse (val2 ):
421
- return same_shape and np .allclose (val1 .toarray (), val2 .toarray (),
422
- rtol , atol )
425
+ assert True
426
+ elif len (val1 ) != len (val2 ):
427
+ assert False , msg
428
+ # nested lists/tuples - [array([5, 6]), array([5, ])] and [[1, 3], ]
429
+ # Or ['str',] and ['str',]
430
+ elif isinstance (val1 [0 ], (tuple , list , np .ndarray , sp .spmatrix , str )):
431
+ # Compare them recursively
432
+ for i , val1_i in enumerate (val1 ):
433
+ assert_safe_sparse_allclose (val1_i , val2 [i ],
434
+ rtol = rtol , atol = atol , msg = msg )
435
+ # Compare the lists using np.allclose, if they are neither nested nor
436
+ # contain strings
437
+ else :
438
+ assert_allclose (val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
439
+
440
+ # scipy sparse matrix
441
+ elif sp .issparse (val1 ) or sp .issparse (val2 ):
442
+ # NOTE: ref np.allclose's note for assymetricity in this testing
443
+ if val1 .shape != val2 .shape :
444
+ assert False , msg
445
+
446
+ diff = abs (val1 - val2 ) - (rtol * abs (val2 ))
447
+ assert np .any (diff > atol ).size == 0 , msg
448
+
449
+ # numpy ndarray
450
+ elif isinstance (val1 , (np .ndarray )):
451
+ if val1 .shape != val2 .shape :
452
+ assert False , msg
453
+ assert_allclose (val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
423
454
else :
424
- return same_shape and np . allclose ( val1 , val2 , rtol , atol )
455
+ assert False , msg
425
456
426
457
427
458
def _assert_allclose (actual , desired , rtol = 1e-7 , atol = 0 ,
@@ -435,12 +466,6 @@ def _assert_allclose(actual, desired, rtol=1e-7, atol=0,
435
466
raise AssertionError (err_msg )
436
467
437
468
438
- if hasattr (np .testing , 'assert_allclose' ):
439
- assert_allclose = np .testing .assert_allclose
440
- else :
441
- assert_allclose = _assert_allclose
442
-
443
-
444
469
def assert_raise_message (exceptions , message , function , * args , ** kwargs ):
445
470
"""Helper function to test error messages in exceptions.
446
471
@@ -488,12 +513,11 @@ def _assert_same_model_method(method, X, estimator1, estimator2, msg=None):
488
513
489
514
# Check if the method(X) returns the same for both models.
490
515
res1 , res2 = getattr (estimator1 , method )(X ), getattr (estimator2 , method )(X )
491
- if not _sparse_dense_allclose (res1 , res2 ):
492
- if msg is None :
493
- msg = ("Models are not equal. \n \n %s method returned different "
494
- "results:\n \n %s\n \n for :\n \n %s and\n \n %s\n \n for :\n \n %s."
495
- % (method , res1 , estimator1 , res2 , estimator2 ))
496
- raise AssertionError (msg )
516
+ if msg is None :
517
+ msg = ("Models are not equal. \n \n %s method returned different "
518
+ "results:\n \n %s\n \n for :\n \n %s and\n \n %s\n \n for :\n \n %s."
519
+ % (method , res1 , estimator1 , res2 , estimator2 ))
520
+ assert_safe_sparse_allclose (res1 , res2 , msg = msg )
497
521
498
522
499
523
def assert_same_model (X , estimator1 , estimator2 , msg = None ):
@@ -579,9 +603,8 @@ def assert_not_same_model(X, estimator1, estimator2, msg=None):
579
603
try :
580
604
assert_same_model (X , estimator1 , estimator2 )
581
605
except AssertionError :
582
- pass
583
- else :
584
- raise AssertionError (msg )
606
+ return
607
+ raise AssertionError (msg )
585
608
586
609
587
610
def assert_fitted_attributes_almost_equal (estimator1 , estimator2 , msg = None ):
@@ -616,23 +639,21 @@ def assert_fitted_attributes_almost_equal(estimator1, estimator2, msg=None):
616
639
"The attributes of both the estimators do not match." )
617
640
618
641
non_attributes = ("estimators_" , "estimator_" , "tree_" , "base_estimator_" ,
619
- "random_state_" )
642
+ "random_state_" , "root_" , "label_binarizer_" , "loss_" )
643
+ non_attr_suffixes = ("leaf_" ,)
644
+
620
645
for attr in est1_dict :
621
646
val1 , val2 = est1_dict [attr ], est2_dict [attr ]
622
647
623
648
# Consider keys that end in ``_`` only as attributes.
624
- if (attr .endswith ('_' ) and attr not in non_attributes ):
649
+ if (attr .endswith ('_' ) and attr not in non_attributes and
650
+ not attr .endswith (non_attr_suffixes )):
625
651
if msg is None :
626
652
msg = ("Attributes do not match. \n The attribute, %s, in "
627
653
"estimator1,\n \n %r\n \n is %r and in estimator2,"
628
654
"\n \n %r\n \n is %r.\n " ) % (attr , estimator1 , val1 ,
629
655
estimator2 , val2 )
630
- if isinstance (val1 , str ) and isinstance (val2 , str ):
631
- attr_similar = val1 == val2
632
- else :
633
- attr_similar = _sparse_dense_allclose (val1 , val2 )
634
- if not attr_similar :
635
- raise AssertionError (msg )
656
+ assert_safe_sparse_allclose (val1 , val2 , msg = msg )
636
657
637
658
638
659
def fake_mldata (columns_dict , dataname , matfile , ordering = None ):
0 commit comments