77
77
"assert_less" , "assert_less_equal" , "assert_greater" ,
78
78
"assert_greater_equal" , "assert_same_model" ,
79
79
"assert_not_same_model" , "assert_fitted_attributes_almost_equal" ,
80
- "assert_approx_equal" ]
80
+ "assert_approx_equal" , "assert_safe_sparse_allclose" ]
81
81
82
82
83
83
try :
@@ -394,41 +394,72 @@ def __exit__(self, *exc_info):
394
394
assert_greater = _assert_greater
395
395
396
396
397
- def _sparse_dense_allclose (val1 , val2 , rtol = 1e-7 , atol = 0 ):
397
+ if hasattr (np .testing , 'assert_allclose' ):
398
+ assert_allclose = np .testing .assert_allclose
399
+ else :
400
+ assert_allclose = _assert_allclose
401
+
402
+
403
+ def assert_safe_sparse_allclose (val1 , val2 , rtol = 1e-7 , atol = 0 , msg = None ):
398
404
"""Check if two objects are close up to the preset tolerance.
399
405
400
406
The objects can be scalars, lists, tuples, ndarrays or sparse matrices.
401
407
"""
402
- if isinstance (val1 , (int , float )) and isinstance (val2 , (int , float )):
403
- return np .allclose (float (val1 ), float (val2 ), rtol , atol )
408
+ if msg is None :
409
+ msg = ("The val1,\n %s\n and val2,\n %s\n are not all close"
410
+ % (val1 , val2 ))
411
+
412
+ if isinstance (val1 , str ) and isinstance (val2 , str ):
413
+ assert_true (val1 == val2 , msg = msg )
404
414
405
- if type (val1 ) is not type (val2 ):
406
- return False
415
+ elif np . isscalar (val1 ) and np . isscalar (val2 ):
416
+ assert_allclose ( val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
407
417
408
- comparables = (float , list , tuple , np .ndarray , sp .spmatrix )
418
+ # To allow mixed formats for sparse matrices alone
419
+ elif type (val1 ) is not type (val2 ) and not (
420
+ sp .issparse (val1 ) and sp .issparse (val2 )):
421
+ assert False , msg
409
422
410
- if not (isinstance (val1 , comparables ) or isinstance ( val2 , comparables )):
411
- raise ValueError ("The objects, %s and %s , are neither scalar nor "
423
+ elif not (isinstance (val1 , ( list , tuple , np . ndarray , sp . spmatrix , dict ) )):
424
+ raise ValueError ("The objects,\n %s \n and \n %s \n , are neither scalar nor "
412
425
"array-like." % (val1 , val2 ))
413
426
414
- # list/tuple (or list/tuple of ndarrays/spmatrices)
415
- if isinstance (val1 , (tuple , list )):
427
+ # list/tuple/dict (of list/tuple/dict...) of ndarrays/spmatrices/scalars
428
+ elif isinstance (val1 , (tuple , list , dict )):
429
+ if isinstance (val1 , dict ):
430
+ val1 , val2 = tuple (val1 .iteritems ()), tuple (val2 .iteritems ())
416
431
if (len (val1 ) == 0 ) and (len (val2 ) == 0 ):
417
- return True
418
- if len (val1 ) != len (val2 ):
419
- return False
420
- while isinstance (val1 [0 ], (tuple , list , np .ndarray , sp .spmatrix )):
421
- return all (_sparse_dense_allclose (val1_i , val2 [i ], rtol , atol )
422
- for i , val1_i in enumerate (val1 ))
423
- # Compare the lists, if they are not nested or singleton
424
- return np .allclose (val1 , val2 , rtol , atol )
425
-
426
- same_shape = val1 .shape == val2 .shape
427
- if sp .issparse (val1 ) or sp .issparse (val2 ):
428
- return same_shape and np .allclose (val1 .toarray (), val2 .toarray (),
429
- rtol , atol )
432
+ assert True
433
+ elif len (val1 ) != len (val2 ):
434
+ assert False , msg
435
+ # nested lists/tuples - [array([5, 6]), array([5, ])] and [[1, 3], ]
436
+ # Or ['str',] and ['str',]
437
+ elif isinstance (val1 [0 ], (tuple , list , np .ndarray , sp .spmatrix , str )):
438
+ # Compare them recursively
439
+ for i , val1_i in enumerate (val1 ):
440
+ assert_safe_sparse_allclose (val1_i , val2 [i ],
441
+ rtol = rtol , atol = atol , msg = msg )
442
+ # Compare the lists using np.allclose, if they are neither nested nor
443
+ # contain strings
444
+ else :
445
+ assert_allclose (val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
446
+
447
+ # scipy sparse matrix
448
+ elif sp .issparse (val1 ) or sp .issparse (val2 ):
449
+ # NOTE: ref np.allclose's note for assymetricity in this testing
450
+ if val1 .shape != val2 .shape :
451
+ assert False , msg
452
+
453
+ diff = abs (val1 - val2 ) - (rtol * abs (val2 ))
454
+ assert np .any (diff > atol ).size == 0 , msg
455
+
456
+ # numpy ndarray
457
+ elif isinstance (val1 , (np .ndarray )):
458
+ if val1 .shape != val2 .shape :
459
+ assert False , msg
460
+ assert_allclose (val1 , val2 , rtol = rtol , atol = atol , err_msg = msg )
430
461
else :
431
- return same_shape and np . allclose ( val1 , val2 , rtol , atol )
462
+ assert False , msg
432
463
433
464
434
465
def _assert_allclose (actual , desired , rtol = 1e-7 , atol = 0 ,
@@ -442,12 +473,6 @@ def _assert_allclose(actual, desired, rtol=1e-7, atol=0,
442
473
raise AssertionError (err_msg )
443
474
444
475
445
- if hasattr (np .testing , 'assert_allclose' ):
446
- assert_allclose = np .testing .assert_allclose
447
- else :
448
- assert_allclose = _assert_allclose
449
-
450
-
451
476
def assert_raise_message (exceptions , message , function , * args , ** kwargs ):
452
477
"""Helper function to test error messages in exceptions
453
478
@@ -495,12 +520,11 @@ def _assert_same_model_method(method, X, estimator1, estimator2, msg=None):
495
520
496
521
# Check if the method(X) returns the same for both models.
497
522
res1 , res2 = getattr (estimator1 , method )(X ), getattr (estimator2 , method )(X )
498
- if not _sparse_dense_allclose (res1 , res2 ):
499
- if msg is None :
500
- msg = ("Models are not equal. \n \n %s method returned different "
501
- "results:\n \n %s\n \n for :\n \n %s and\n \n %s\n \n for :\n \n %s."
502
- % (method , res1 , estimator1 , res2 , estimator2 ))
503
- raise AssertionError (msg )
523
+ if msg is None :
524
+ msg = ("Models are not equal. \n \n %s method returned different "
525
+ "results:\n \n %s\n \n for :\n \n %s and\n \n %s\n \n for :\n \n %s."
526
+ % (method , res1 , estimator1 , res2 , estimator2 ))
527
+ assert_safe_sparse_allclose (res1 , res2 , msg = msg )
504
528
505
529
506
530
def assert_same_model (X , estimator1 , estimator2 , msg = None ):
@@ -586,9 +610,8 @@ def assert_not_same_model(X, estimator1, estimator2, msg=None):
586
610
try :
587
611
assert_same_model (X , estimator1 , estimator2 )
588
612
except AssertionError :
589
- pass
590
- else :
591
- raise AssertionError (msg )
613
+ return
614
+ raise AssertionError (msg )
592
615
593
616
594
617
def assert_fitted_attributes_almost_equal (estimator1 , estimator2 , msg = None ):
@@ -623,23 +646,21 @@ def assert_fitted_attributes_almost_equal(estimator1, estimator2, msg=None):
623
646
"The attributes of both the estimators do not match." )
624
647
625
648
non_attributes = ("estimators_" , "estimator_" , "tree_" , "base_estimator_" ,
626
- "random_state_" )
649
+ "random_state_" , "root_" , "label_binarizer_" , "loss_" )
650
+ non_attr_suffixes = ("leaf_" ,)
651
+
627
652
for attr in est1_dict :
628
653
val1 , val2 = est1_dict [attr ], est2_dict [attr ]
629
654
630
655
# Consider keys that end in ``_`` only as attributes.
631
- if (attr .endswith ('_' ) and attr not in non_attributes ):
656
+ if (attr .endswith ('_' ) and attr not in non_attributes and
657
+ not attr .endswith (non_attr_suffixes )):
632
658
if msg is None :
633
659
msg = ("Attributes do not match. \n The attribute, %s, in "
634
660
"estimator1,\n \n %r\n \n is %r and in estimator2,"
635
661
"\n \n %r\n \n is %r.\n " ) % (attr , estimator1 , val1 ,
636
662
estimator2 , val2 )
637
- if isinstance (val1 , str ) and isinstance (val2 , str ):
638
- attr_similar = val1 == val2
639
- else :
640
- attr_similar = _sparse_dense_allclose (val1 , val2 )
641
- if not attr_similar :
642
- raise AssertionError (msg )
663
+ assert_safe_sparse_allclose (val1 , val2 , msg = msg )
643
664
644
665
645
666
def fake_mldata (columns_dict , dataname , matfile , ordering = None ):
0 commit comments