|
18 | 18 |
|
19 | 19 | import scipy as sp
|
20 | 20 | import scipy.io
|
| 21 | +import numpy as np |
| 22 | + |
21 | 23 | from functools import wraps
|
22 | 24 | try:
|
23 | 25 | # Python 2
|
|
45 | 47 | from numpy.testing import assert_array_equal
|
46 | 48 | from numpy.testing import assert_array_almost_equal
|
47 | 49 | from numpy.testing import assert_array_less
|
48 |
| -import numpy as np |
49 | 50 |
|
50 | 51 | from sklearn.base import (ClassifierMixin, RegressorMixin, TransformerMixin,
|
51 | 52 | ClusterMixin)
|
|
54 | 55 | "assert_raises_regexp", "raises", "with_setup", "assert_true",
|
55 | 56 | "assert_false", "assert_almost_equal", "assert_array_equal",
|
56 | 57 | "assert_array_almost_equal", "assert_array_less",
|
57 |
| - "assert_less", "assert_less_equal", |
58 |
| - "assert_greater", "assert_greater_equal"] |
| 58 | + "assert_less", "assert_less_equal", "assert_greater", |
| 59 | + "assert_greater_equal", |
| 60 | + "assert_same_model", "assert_fitted_attributes_equal"] |
59 | 61 |
|
60 | 62 |
|
61 | 63 | try:
|
@@ -422,6 +424,77 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs):
|
422 | 424 | (names, function.__name__))
|
423 | 425 |
|
424 | 426 |
|
| 427 | +def _assert_same_model_method(method, X, estimator1, estimator2, msg=None): |
| 428 | + if hasattr(estimator1, method): |
| 429 | + m = '%r\n\nhas %s, but\n\n%r\n\ndoes not' % (estimator1, |
| 430 | + method, |
| 431 | + estimator2) |
| 432 | + |
| 433 | + if not hasattr(estimator2, method): |
| 434 | + raise AttributeError(m) |
| 435 | + |
| 436 | + # Check if the method(X) returns the same for both models. |
| 437 | + res1 = getattr(estimator1, method)(X) |
| 438 | + res2 = getattr(estimator2, method)(X) |
| 439 | + same_model = (res1.shape == res2.shape) and np.allclose(res1, res2) |
| 440 | + |
| 441 | + if msg is None: |
| 442 | + msg = ("Models are not equal. \n\n%s method returned" |
| 443 | + " different results:\n\n%r\n\n and\n\n%r\n\n for :\n\n%r" |
| 444 | + "\n\nand :\n\n%r" % (method, res1, res2, |
| 445 | + estimator1, estimator2)) |
| 446 | + |
| 447 | + assert same_model, msg |
| 448 | + |
| 449 | + |
def assert_same_model(X, estimator1, estimator2, msg=None):
    """Assert that two estimators behave as the same fitted model on X.

    The outputs of ``predict``, ``transform``, ``decision_function`` and
    ``predict_proba`` are compared through ``_assert_same_model_method``
    (in that order), and finally the fitted attributes are compared through
    ``assert_fitted_attributes_equal``.

    ``msg`` is forwarded to each per-method comparison; it is not used for
    the fitted-attribute comparison.
    """
    compared_methods = ('predict', 'transform',
                        'decision_function', 'predict_proba')
    for method_name in compared_methods:
        _assert_same_model_method(method_name, X,
                                  estimator1, estimator2, msg)

    assert_fitted_attributes_equal(estimator1, estimator2)
| 460 | + |
| 461 | + |
def assert_not_same_model(X, estimator1, estimator2, msg=None):
    """Assert that two estimators do NOT behave as the same model on X.

    Runs ``assert_same_model`` and succeeds only if it raises an
    AssertionError. Any other exception raised by ``assert_same_model``
    propagates unchanged.

    Parameters
    ----------
    X : object
        Data forwarded to ``assert_same_model``.
    estimator1, estimator2 : object
        The two estimators expected to differ.
    msg : str or None, optional
        Message of the AssertionError raised when the models turn out to be
        the same. When None, a message showing both estimators is generated.

    Raises
    ------
    AssertionError
        If ``assert_same_model`` finds no difference between the estimators.
    """
    try:
        assert_same_model(X, estimator1, estimator2)
    except AssertionError:
        # The models differ, which is what the caller expects.
        return

    if msg is None:
        # Avoid the uninformative ``AssertionError(None)``.
        msg = ("Models are equal, but were expected to differ:"
               "\n\n%r\n\nand\n\n%r" % (estimator1, estimator2))
    raise AssertionError(msg)
| 470 | + |
| 471 | + |
def assert_fitted_attributes_equal(estimator1, estimator2):
    """Assert that the fitted attributes of two estimators are equal.

    Fitted attributes are the entries of the estimator's ``__dict__`` whose
    name ends with an underscore. Names ending in ``estimator_`` or
    ``estimators_`` and names listed in ``skip_attributes`` are ignored.
    The same filter is applied to both estimators, so a fitted attribute
    present on only one of them makes the comparison fail.

    Raises
    ------
    AssertionError
        Raised by ``assert_equal`` when the two filtered attribute dicts
        differ.
    """
    # A list of attributes which are known to be inconsistent.
    # FIXME embedding_ to be removed after #4299
    skip_attributes = ('embedding_',)
    # Sub-estimator attributes are not compared.
    skip_suffixes = ('estimator_', 'estimators_')

    def _fitted_attributes(estimator):
        # Build a new dict rather than popping keys while iterating, which
        # raises RuntimeError on Python 3. dict() over a generator is used
        # instead of a dict comprehension to stay Python 2.6 compatible.
        return dict((attr, value)
                    for attr, value in estimator.__dict__.items()
                    if attr.endswith('_') and
                    not attr.endswith(skip_suffixes) and
                    attr not in skip_attributes)

    # NOTE(review): this relies on the assert_equal in scope being able to
    # compare dicts whose values may be arrays -- confirm.
    assert_equal(_fitted_attributes(estimator1),
                 _fitted_attributes(estimator2))
| 496 | + |
| 497 | + |
425 | 498 | def fake_mldata(columns_dict, dataname, matfile, ordering=None):
|
426 | 499 | """Create a fake mldata data set.
|
427 | 500 |
|
@@ -628,8 +701,8 @@ def is_abstract(c):
|
628 | 701 | estimators = filtered_estimators
|
629 | 702 | if type_filter:
|
630 | 703 | raise ValueError("Parameter type_filter must be 'classifier', "
|
631 |
| - "'regressor', 'transformer', 'cluster' or None, got" |
632 |
| - " %s." % repr(type_filter)) |
| 704 | + "'regressor', 'transformer', 'cluster' or None," |
| 705 | + "got %s." % repr(type_filter)) |
633 | 706 |
|
634 | 707 | # drop duplicates, sort for reproducibility
|
635 | 708 | return sorted(set(estimators))
|
|
0 commit comments