-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[WIP] New assert helpers for model comparison and fit reset checks #4841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
ad18ddd
ENH/TST Add helpers assert_{same_model|fitted_attributes_equal}
raghavrv 64cefb8
TST Add test to check if estimators reset upon fit
raghavrv 61e98d3
FIX Shift the points instead of taking abs to preserve blobiness
raghavrv 35fdeaa
WIP + SCAFFOLD_REMOVE_BEFORE_MERGE
raghavrv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,8 +19,11 @@ | |
import platform | ||
import struct | ||
|
||
import scipy as sp | ||
import scipy | ||
import scipy.io | ||
import scipy.sparse as sp | ||
import numpy as np | ||
|
||
from functools import wraps | ||
from operator import itemgetter | ||
try: | ||
|
@@ -71,10 +74,11 @@ | |
__all__ = ["assert_equal", "assert_not_equal", "assert_raises", | ||
"assert_raises_regexp", "raises", "with_setup", "assert_true", | ||
"assert_false", "assert_almost_equal", "assert_array_equal", | ||
"assert_array_almost_equal", "assert_array_less", | ||
"assert_less", "assert_less_equal", | ||
"assert_greater", "assert_greater_equal", | ||
"assert_approx_equal"] | ||
"assert_allclose", "assert_array_almost_equal", "assert_array_less", | ||
"assert_less", "assert_less_equal", "assert_greater", | ||
"assert_greater_equal", "assert_same_model", | ||
"assert_not_same_model", "assert_fitted_attributes_almost_equal", | ||
"assert_approx_equal", "assert_safe_sparse_allclose"] | ||
|
||
|
||
try: | ||
|
@@ -383,20 +387,83 @@ def __exit__(self, *exc_info): | |
assert_greater = _assert_greater | ||
|
||
|
||
if hasattr(np.testing, 'assert_allclose'): | ||
assert_allclose = np.testing.assert_allclose | ||
else: | ||
assert_allclose = _assert_allclose | ||
|
||
|
||
def assert_safe_sparse_allclose(val1, val2, rtol=1e-7, atol=0, msg=None):
    """Assert that two objects are equal up to the given tolerances.

    The objects may be scalars, strings, lists, tuples, dicts, ndarrays or
    scipy sparse matrices, nested arbitrarily.  Containers are compared
    element-wise (recursively).  Sparse matrices may be in mixed formats,
    but must have the same shape.

    Parameters
    ----------
    val1, val2 : object
        The two objects to compare.

    rtol : float, optional (default=1e-7)
        Relative tolerance, as in ``numpy.allclose``.

    atol : float, optional (default=0)
        Absolute tolerance, as in ``numpy.allclose``.

    msg : str, optional
        Error message used when the assertion fails.  If None, a default
        message embedding both values is generated.

    Raises
    ------
    AssertionError
        If the two objects are not close.
    ValueError
        If the objects are neither scalar nor array-like.
    """
    if msg is None:
        msg = ("The val1,\n%s\nand val2,\n%s\nare not all close"
               % (val1, val2))

    if isinstance(val1, str) and isinstance(val2, str):
        assert_true(val1 == val2, msg=msg)

    elif np.isscalar(val1) and np.isscalar(val2):
        assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)

    # To allow mixed formats for sparse matrices alone
    elif type(val1) is not type(val2) and not (
            sp.issparse(val1) and sp.issparse(val2)):
        assert False, msg

    elif not (isinstance(val1, (list, tuple, np.ndarray, sp.spmatrix, dict))):
        raise ValueError("The objects,\n%s\nand\n%s\n, are neither scalar nor "
                         "array-like." % (val1, val2))

    # list/tuple/dict (of list/tuple/dict...) of ndarrays/spmatrices/scalars
    elif isinstance(val1, (tuple, list, dict)):
        if isinstance(val1, dict):
            # BUG FIX: ``dict.iteritems`` is Python 2 only; use ``items``.
            # Sort by key (dict keys are unique) so the comparison does not
            # depend on the insertion order of the two dicts.
            val1 = tuple(sorted(val1.items()))
            val2 = tuple(sorted(val2.items()))
        if (len(val1) == 0) and (len(val2) == 0):
            # Two empty sequences are trivially close.
            pass
        elif len(val1) != len(val2):
            assert False, msg
        # nested lists/tuples - [array([5, 6]), array([5, ])] and [[1, 3], ]
        # Or ['str',] and ['str',]
        elif isinstance(val1[0], (tuple, list, np.ndarray, sp.spmatrix, str)):
            # Compare them recursively
            for i, val1_i in enumerate(val1):
                assert_safe_sparse_allclose(val1_i, val2[i],
                                            rtol=rtol, atol=atol, msg=msg)
        # Compare the lists using np.allclose, if they are neither nested nor
        # contain strings
        else:
            assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)

    # scipy sparse matrix
    elif sp.issparse(val1) or sp.issparse(val2):
        # NOTE: ref np.allclose's note for asymmetry in this testing
        if val1.shape != val2.shape:
            assert False, msg
        # BUG FIX: the original asserted ``np.any(diff > atol).size == 0``,
        # but ``np.any`` returns a size-1 scalar so that condition is always
        # False and the check could never pass.  Densify the (small, test-
        # sized) matrices and apply the ``np.allclose`` criterion
        # |val1 - val2| - rtol * |val2| <= atol.
        diff = (np.abs((val1 - val2).toarray())
                - rtol * np.abs(val2.toarray()))
        assert_true(np.all(diff <= atol), msg=msg)

    # numpy ndarray
    elif isinstance(val1, np.ndarray):
        if val1.shape != val2.shape:
            assert False, msg
        assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)
    else:
        assert False, msg
|
||
|
||
def _assert_allclose(actual, desired, rtol=1e-7, atol=0, | ||
err_msg='', verbose=True): | ||
actual, desired = np.asanyarray(actual), np.asanyarray(desired) | ||
if np.allclose(actual, desired, rtol=rtol, atol=atol): | ||
return | ||
msg = ('Array not equal to tolerance rtol=%g, atol=%g: ' | ||
'actual %s, desired %s') % (rtol, atol, actual, desired) | ||
raise AssertionError(msg) | ||
|
||
|
||
if hasattr(np.testing, 'assert_allclose'): | ||
assert_allclose = np.testing.assert_allclose | ||
else: | ||
assert_allclose = _assert_allclose | ||
if err_msg == '': | ||
err_msg = ('Array not equal to tolerance rtol=%g, atol=%g: ' | ||
'actual %s, desired %s') % (rtol, atol, actual, desired) | ||
raise AssertionError(err_msg) | ||
|
||
|
||
def assert_raise_message(exceptions, message, function, *args, **kwargs): | ||
|
@@ -433,6 +500,162 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs): | |
(names, function.__name__)) | ||
|
||
|
||
def _assert_same_model_method(method, X, estimator1, estimator2, msg=None): | ||
method_err = '%r\n\nhas %s, but\n\n%r\n\ndoes not.' | ||
# If the method is absent in only one model consider them different | ||
if hasattr(estimator1, method) and not hasattr(estimator2, method): | ||
raise AssertionError(method_err % (estimator1, method, estimator2)) | ||
if hasattr(estimator2, method) and not hasattr(estimator1, method): | ||
raise AssertionError(method_err % estimator2, method, estimator1) | ||
|
||
if not hasattr(estimator1, method): | ||
return | ||
|
||
# Check if the method(X) returns the same for both models. | ||
res1, res2 = getattr(estimator1, method)(X), getattr(estimator2, method)(X) | ||
if msg is None: | ||
msg = ("Models are not equal. \n\n%s method returned different " | ||
"results:\n\n%s\n\n for :\n\n%s and\n\n%s\n\n for :\n\n%s." | ||
% (method, res1, estimator1, res2, estimator2)) | ||
assert_safe_sparse_allclose(res1, res2, msg=msg) | ||
|
||
|
||
def assert_same_model(X, estimator1, estimator2, msg=None):
    """Assert that two fitted models are equivalent.

    Equivalence is established by comparing, on ``X``, the outputs of every
    method among ``predict``, ``transform``, ``decision_function`` and
    ``predict_proba`` that the models expose.  A method present in only one
    of the two models makes them different.  If all the available method
    outputs agree, the fitted attributes of the models (those ending in
    ``_``) are compared as well.

    If the models are different an AssertionError with the given error
    message is raised.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Input data, for the fitted models, used for comparing them.

    estimator1 : An estimator object.
        The first fitted model to be compared.

    estimator2 : An estimator object.
        The second fitted model to be compared.

    msg : str
        The error message to be used while raising the AssertionError if the
        models are similar.

    Notes
    -----
    This check is not exhaustive since all attributes of the model are assumed
    to end with ``_``. If that is not the case, it could lead to false
    positives.
    """
    for method in ('predict', 'transform',
                   'decision_function', 'predict_proba'):
        _assert_same_model_method(method, X, estimator1, estimator2, msg)
    assert_fitted_attributes_almost_equal(estimator1, estimator2)
|
||
|
||
def assert_not_same_model(X, estimator1, estimator2, msg=None):
    """Helper function to check if the models are different.

    The check is done by comparing the outputs of the methods ``predict``,
    ``transform``, ``decision_function`` and the ``predict_proba``, provided
    they exist in both the models. If any of those methods do not exist in
    one model alone, the models are considered different.

    If the outputs from both the models for each of the available, above
    listed function(s) are similar, a comparison of the attributes of the
    models that end with ``_`` is done to ascertain the similarity of the
    model.

    If the models are similar an AssertionError with the given error message
    is raised.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Input data, for the fitted models, used for comparing them.

    estimator1 : An estimator object.
        The first fitted model to be compared.

    estimator2 : An estimator object.
        The second fitted model to be compared.

    msg : str
        The error message to be used while raising the AssertionError if the
        models are similar.

    Notes
    -----
    This check is not exhaustive since all attributes of the model are assumed
    to end with ``_``. If that is not the case, it could lead to false
    negatives.
    """
    try:
        assert_same_model(X, estimator1, estimator2)
    except AssertionError:
        # The models differ, which is what this helper asserts.
        return
    # BUG FIX: previously ``AssertionError(None)`` was raised when no message
    # was supplied, producing an unhelpful "AssertionError: None".
    if msg is None:
        msg = ("The models,\n\n%r\n\nand\n\n%r\n\nwere found to be the same, "
               "but were expected to be different." % (estimator1, estimator2))
    raise AssertionError(msg)
|
||
|
||
def assert_fitted_attributes_almost_equal(estimator1, estimator2, msg=None):
    """Helper function to check if the fitted model attributes are similar.

    This check is done by comparing the attributes from both the models that
    end in ``_``.  A few known nested or non-deterministic attributes (such
    as ``estimators_`` or ``tree_``) are excluded from the comparison.

    If the fitted models attributes are different an AssertionError with the
    given error message is raised.

    Parameters
    ----------
    estimator1 : An estimator object.
        The first fitted model whose attributes are to be compared.

    estimator2 : An estimator object.
        The second fitted model whose attributes are to be compared.

    msg : str
        The error message to be used while raising the AssertionError, if the
        fitted models attributes are different.

    Notes
    -----
    This check is not exhaustive since all attributes of the model are assumed
    to end with ``_``. If that is not the case, it could lead to false
    positives.
    """
    est1_dict, est2_dict = estimator1.__dict__, estimator2.__dict__
    # BUG FIX: compare sorted key lists so the attribute-set comparison does
    # not depend on the insertion order of the two ``__dict__``s.
    assert_array_equal(sorted(est1_dict.keys()), sorted(est2_dict.keys()),
                       "The attributes of both the estimators do not match.")

    non_attributes = ("estimators_", "estimator_", "tree_", "base_estimator_",
                      "random_state_", "root_", "label_binarizer_", "loss_")
    non_attr_suffixes = ("leaf_",)

    for attr in est1_dict:
        val1, val2 = est1_dict[attr], est2_dict[attr]

        # Consider keys that end in ``_`` only as attributes.
        if (attr.endswith('_') and attr not in non_attributes and
                not attr.endswith(non_attr_suffixes)):
            # BUG FIX: build the per-attribute message in a local variable.
            # The original assigned to ``msg`` itself, so from the second
            # attribute onward the (now non-None) message still referred to
            # the first attribute's name and values.
            if msg is None:
                attr_msg = ("Attributes do not match. \nThe attribute, %s, in "
                            "estimator1,\n\n%r\n\n is %r and in estimator2,"
                            "\n\n%r\n\n is %r.\n") % (attr, estimator1, val1,
                                                      estimator2, val2)
            else:
                attr_msg = msg
            assert_safe_sparse_allclose(val1, val2, msg=attr_msg)
|
||
|
||
def fake_mldata(columns_dict, dataname, matfile, ordering=None): | ||
"""Create a fake mldata data set. | ||
|
||
|
@@ -465,7 +688,7 @@ def fake_mldata(columns_dict, dataname, matfile, ordering=None): | |
ordering = sorted(list(datasets.keys())) | ||
# NOTE: setting up this array is tricky, because of the way Matlab | ||
# re-packages 1D arrays | ||
datasets['mldata_descr_ordering'] = sp.empty((1, len(ordering)), | ||
datasets['mldata_descr_ordering'] = np.empty((1, len(ordering)), | ||
dtype='object') | ||
for i, name in enumerate(ordering): | ||
datasets['mldata_descr_ordering'][0, i] = name | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vene @jnothman Could you look at this implementation once? (This is still WIP as 30% of the tests (fit reset tests) don't pass... but I'd like to know if I am going in the right direction)