[WIP] New assert helpers for model comparison and fit reset checks by raghavrv · Pull Request #4841 · scikit-learn/scikit-learn · GitHub

Status: Closed · wants to merge 4 commits

Changes from all commits
51 changes: 50 additions & 1 deletion sklearn/utils/estimator_checks.py
@@ -30,7 +30,8 @@
from sklearn.utils.testing import SkipTest
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.testing import assert_warns

from sklearn.utils.testing import assert_same_model
from sklearn.utils.testing import assert_not_same_model

from sklearn.base import (clone, ClassifierMixin, RegressorMixin,
TransformerMixin, ClusterMixin, BaseEstimator)
@@ -76,6 +77,7 @@


def _yield_non_meta_checks(name, Estimator):
"""
yield check_estimators_dtypes
yield check_fit_score_takes_y
yield check_dtype_object
@@ -107,6 +109,9 @@ def _yield_non_meta_checks(name, Estimator):
# Test that estimators can be pickled, and once pickled
# give the same answer as before.
yield check_estimators_pickle
"""
if name not in ('SpectralEmbedding',):
yield check_estimator_fit_reset


def _yield_classifier_checks(name, Classifier):
Expand Down Expand Up @@ -199,6 +204,7 @@ def _yield_clustering_checks(name, Clusterer):
def _yield_all_checks(name, Estimator):
for check in _yield_non_meta_checks(name, Estimator):
yield check
"""
if issubclass(Estimator, ClassifierMixin):
for check in _yield_classifier_checks(name, Estimator):
yield check
@@ -217,6 +223,7 @@ def _yield_all_checks(name, Estimator):
yield check_fit2d_1feature
yield check_fit1d_1feature
yield check_fit1d_1sample
"""


def check_estimator(Estimator):
Expand Down Expand Up @@ -1553,3 +1560,45 @@ def check_classifiers_regression_target(name, Estimator):
e = Estimator()
msg = 'Unknown label type: '
assert_raises_regex(ValueError, msg, e.fit, X, y)


@ignore_warnings
def check_estimator_fit_reset(name, Estimator):
X1, y1 = make_blobs(n_samples=50, n_features=2, center_box=(-200, -150),
centers=2, random_state=0)
X2, y2 = make_blobs(n_samples=50, n_features=2, center_box=(200, 150),
centers=2, random_state=1)
X3, y3 = make_blobs(n_samples=50, n_features=2, center_box=(-200, 150),
centers=3, random_state=2)
X4, y4 = make_blobs(n_samples=50, n_features=5, center_box=(-200, -150),
centers=2, random_state=0)
X5, y5 = make_blobs(n_samples=50, n_features=5, center_box=(200, 150),
centers=2, random_state=1)
X6, y6 = make_blobs(n_samples=50, n_features=5, center_box=(-200, 150),
centers=3, random_state=2)

# Some estimators work only on non-negative inputs
if name in ('AdditiveChi2Sampler', 'SkewedChi2Sampler', 'NMF',
'MultinomialNB', 'ProjectedGradientNMF',):
X1, X2, X3, X4, X5, X6 = map(lambda X: X - X.min(),
(X1, X2, X3, X4, X5, X6))

y1, y2, y3, y4, y5, y6 = map(multioutput_estimator_convert_y_2d,
(name,)*6, (y1, y2, y3, y4, y5, y6))
estimator_1 = Estimator()
estimator_2 = Estimator()

set_testing_parameters(estimator_1)
set_testing_parameters(estimator_2)

set_random_state(estimator_1)
set_random_state(estimator_2)

assert_not_same_model(X3, estimator_1.fit(X1, y1), estimator_2.fit(X2, y2))
assert_same_model(X3, estimator_1.fit(X2, y2), estimator_2)
assert_same_model(X2, estimator_1.fit(X1, y1), estimator_2.fit(X1, y1))

# Fitting new data with 5 features
assert_not_same_model(X6, estimator_1.fit(X4, y4), estimator_2.fit(X5, y5))
assert_same_model(X6, estimator_1.fit(X5, y5), estimator_2)
assert_same_model(X5, estimator_1.fit(X4, y4), estimator_2.fit(X4, y4))
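
For context, the contract this check enforces is that a second call to fit must fully reset the estimator's state: an instance refitted on new data should be indistinguishable from a freshly constructed one fitted on that data alone. Below is a minimal sketch of that contract outside the common-test harness; KMeans is only an illustrative choice, and assert_same_model is the helper this branch adds to sklearn.utils.testing:

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.utils.testing import assert_same_model

X_a, _ = make_blobs(n_samples=50, centers=2, random_state=0)
X_b, _ = make_blobs(n_samples=50, centers=2, random_state=1)

# est_1 is fitted on X_a and then refitted on X_b; if fit() resets all
# state, it must be equivalent to est_2, which only ever saw X_b.
est_1 = KMeans(n_clusters=2, random_state=0).fit(X_a).fit(X_b)
est_2 = KMeans(n_clusters=2, random_state=0).fit(X_b)
assert_same_model(X_a, est_1, est_2)
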
253 changes: 238 additions & 15 deletions sklearn/utils/testing.py
@@ -19,8 +19,11 @@
import platform
import struct

-import scipy as sp
+import scipy
+import scipy.io
+import scipy.sparse as sp
+import numpy as np

from functools import wraps
from operator import itemgetter
try:
@@ -71,10 +74,11 @@
__all__ = ["assert_equal", "assert_not_equal", "assert_raises",
"assert_raises_regexp", "raises", "with_setup", "assert_true",
"assert_false", "assert_almost_equal", "assert_array_equal",
"assert_array_almost_equal", "assert_array_less",
"assert_less", "assert_less_equal",
"assert_greater", "assert_greater_equal",
"assert_approx_equal"]
"assert_allclose", "assert_array_almost_equal", "assert_array_less",
"assert_less", "assert_less_equal", "assert_greater",
"assert_greater_equal", "assert_same_model",
"assert_not_same_model", "assert_fitted_attributes_almost_equal",
"assert_approx_equal", "assert_safe_sparse_allclose"]


try:
@@ -383,20 +387,83 @@ def __exit__(self, *exc_info):
assert_greater = _assert_greater


if hasattr(np.testing, 'assert_allclose'):
assert_allclose = np.testing.assert_allclose
else:
assert_allclose = _assert_allclose


def assert_safe_sparse_allclose(val1, val2, rtol=1e-7, atol=0, msg=None):
"""Check if two objects are close up to the preset tolerance.

The objects can be scalars, lists, tuples, ndarrays or sparse matrices.
"""
if msg is None:
msg = ("The val1,\n%s\nand val2,\n%s\nare not all close"
% (val1, val2))

if isinstance(val1, str) and isinstance(val2, str):
assert_true(val1 == val2, msg=msg)

elif np.isscalar(val1) and np.isscalar(val2):
assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)

# To allow mixed formats for sparse matrices alone
elif type(val1) is not type(val2) and not (
sp.issparse(val1) and sp.issparse(val2)):
assert False, msg

elif not (isinstance(val1, (list, tuple, np.ndarray, sp.spmatrix, dict))):
raise ValueError("The objects,\n%s\nand\n%s\n, are neither scalar nor "
"array-like." % (val1, val2))

# list/tuple/dict (of list/tuple/dict...) of ndarrays/spmatrices/scalars
elif isinstance(val1, (tuple, list, dict)):
if isinstance(val1, dict):
            val1, val2 = tuple(val1.items()), tuple(val2.items())
if (len(val1) == 0) and (len(val2) == 0):
assert True
elif len(val1) != len(val2):
assert False, msg
# nested lists/tuples - [array([5, 6]), array([5, ])] and [[1, 3], ]
# Or ['str',] and ['str',]
elif isinstance(val1[0], (tuple, list, np.ndarray, sp.spmatrix, str)):
# Compare them recursively
for i, val1_i in enumerate(val1):
assert_safe_sparse_allclose(val1_i, val2[i],
rtol=rtol, atol=atol, msg=msg)
# Compare the lists using np.allclose, if they are neither nested nor
# contain strings
else:
assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)

# scipy sparse matrix
elif sp.issparse(val1) or sp.issparse(val2):
        # NOTE: see np.allclose's note on the asymmetry of this comparison
if val1.shape != val2.shape:
assert False, msg

diff = abs(val1 - val2) - (rtol * abs(val2))
        assert (diff > atol).nnz == 0, msg

# numpy ndarray
elif isinstance(val1, (np.ndarray)):
if val1.shape != val2.shape:
assert False, msg
assert_allclose(val1, val2, rtol=rtol, atol=atol, err_msg=msg)
else:
assert False, msg

Review comment from raghavrv (Member, Author):
@vene @jnothman Could you look at this implementation once? (This is still WIP as 30% of the tests (fit reset tests) don't pass... but I'd like to know if I am going in the right direction)

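As a usage sketch (assuming this branch is installed), the helper accepts scalars, strings, nested containers and sparse matrices, deferring to assert_allclose for dense numeric data:

import numpy as np
import scipy.sparse as sp
from sklearn.utils.testing import assert_safe_sparse_allclose

# Scalars defer to assert_allclose with the given tolerances.
assert_safe_sparse_allclose(1.0, 1.0 + 1e-9)

# Sparse operands may mix formats (CSR vs CSC), but both must be sparse.
assert_safe_sparse_allclose(sp.csr_matrix(np.eye(3)), sp.csc_matrix(np.eye(3)))

# Nested lists/tuples of arrays are compared recursively, element-wise.
assert_safe_sparse_allclose([np.array([1., 2.]), (3., 4.)],
                            [np.array([1., 2.]), (3., 4.)])
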
 def _assert_allclose(actual, desired, rtol=1e-7, atol=0,
                      err_msg='', verbose=True):
     actual, desired = np.asanyarray(actual), np.asanyarray(desired)
     if np.allclose(actual, desired, rtol=rtol, atol=atol):
         return
-    msg = ('Array not equal to tolerance rtol=%g, atol=%g: '
-           'actual %s, desired %s') % (rtol, atol, actual, desired)
-    raise AssertionError(msg)
-
-
-if hasattr(np.testing, 'assert_allclose'):
-    assert_allclose = np.testing.assert_allclose
-else:
-    assert_allclose = _assert_allclose
+    if err_msg == '':
+        err_msg = ('Array not equal to tolerance rtol=%g, atol=%g: '
+                   'actual %s, desired %s') % (rtol, atol, actual, desired)
+    raise AssertionError(err_msg)


def assert_raise_message(exceptions, message, function, *args, **kwargs):
Expand Down Expand Up @@ -433,6 +500,162 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs):
(names, function.__name__))


def _assert_same_model_method(method, X, estimator1, estimator2, msg=None):
method_err = '%r\n\nhas %s, but\n\n%r\n\ndoes not.'
# If the method is absent in only one model consider them different
if hasattr(estimator1, method) and not hasattr(estimator2, method):
raise AssertionError(method_err % (estimator1, method, estimator2))
if hasattr(estimator2, method) and not hasattr(estimator1, method):
        raise AssertionError(method_err % (estimator2, method, estimator1))

if not hasattr(estimator1, method):
return

# Check if the method(X) returns the same for both models.
res1, res2 = getattr(estimator1, method)(X), getattr(estimator2, method)(X)
if msg is None:
msg = ("Models are not equal. \n\n%s method returned different "
"results:\n\n%s\n\n for :\n\n%s and\n\n%s\n\n for :\n\n%s."
% (method, res1, estimator1, res2, estimator2))
assert_safe_sparse_allclose(res1, res2, msg=msg)


def assert_same_model(X, estimator1, estimator2, msg=None):
"""Helper function to check if the models are similar.

    The check is done by comparing the outputs of the methods ``predict``,
    ``transform``, ``decision_function`` and ``predict_proba``, provided
    they exist in both the models. If any of those methods exists in only
    one of the two models, the models are considered different.

    If the outputs of all the available listed methods are similar for both
    the models, the fitted attributes of both the models (those ending with
    ``_``) are compared to ascertain the similarity of the models.

If the models are different an AssertionError with the given error message
is raised.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Input data, for the fitted models, used for comparing them.

estimator1 : An estimator object.
The first fitted model to be compared.

estimator2 : An estimator object.
The second fitted model to be compared.

    msg : str
        The error message to be used while raising the AssertionError if the
        models are different.

Notes
-----
This check is not exhaustive since all attributes of the model are assumed
to end with ``_``. If that is not the case, it could lead to false
positives.
"""
_assert_same_model_method('predict', X, estimator1, estimator2, msg)
_assert_same_model_method('transform', X, estimator1, estimator2, msg)
_assert_same_model_method('decision_function',
X, estimator1, estimator2, msg)
_assert_same_model_method('predict_proba', X, estimator1, estimator2, msg)
assert_fitted_attributes_almost_equal(estimator1, estimator2)


def assert_not_same_model(X, estimator1, estimator2, msg=None):
"""Helper function to check if the models are different.

    The check is done by comparing the outputs of the methods ``predict``,
    ``transform``, ``decision_function`` and ``predict_proba``, provided
    they exist in both the models. If any of those methods exists in only
    one of the two models, the models are considered different.

    If the outputs of all the available listed methods are similar for both
    the models, the fitted attributes of both the models (those ending with
    ``_``) are compared to ascertain the similarity of the models.

If the models are similar an AssertionError with the given error message
is raised.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Input data, for the fitted models, used for comparing them.

estimator1 : An estimator object.
The first fitted model to be compared.

estimator2 : An estimator object.
The second fitted model to be compared.

msg : str
The error message to be used while raising the AssertionError if the
models are similar.

Notes
-----
This check is not exhaustive since all attributes of the model are assumed
to end with ``_``. If that is not the case, it could lead to false
negatives.
"""
try:
assert_same_model(X, estimator1, estimator2)
except AssertionError:
return
raise AssertionError(msg)
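

A usage sketch for this pair of helpers; LogisticRegression is only an illustrative, deterministic choice:

from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression
from sklearn.utils.testing import assert_not_same_model, assert_same_model

X1, y1 = make_blobs(n_samples=50, centers=2, random_state=0)
X2, y2 = make_blobs(n_samples=50, centers=2, random_state=1)

clf_a = LogisticRegression().fit(X1, y1)
clf_b = LogisticRegression().fit(X1, y1)  # same data, same settings
clf_c = LogisticRegression().fit(X2, y2)  # different data

# Identically configured models fitted on the same data should agree on
# predictions and on fitted attributes.
assert_same_model(X2, clf_a, clf_b)
# Models fitted on different data should be distinguishable.
assert_not_same_model(X1, clf_a, clf_c)
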


Review comment from a Member:
I find it hard to follow why this doesn't have an underscore but other asserters introduced here do.

Reply from raghavrv (Member, Author):
I felt this could also be useful for other tests?


def assert_fitted_attributes_almost_equal(estimator1, estimator2, msg=None):
"""Helper function to check if the fitted model attributes are similar.

This check is done by comparing the attributes from both the models that
end in ``_``.

    If the fitted models' attributes are different, an AssertionError with
    the given error message is raised.

Parameters
----------
estimator1 : An estimator object.
The first fitted model whose attributes are to be compared.

estimator2 : An estimator object.
The second fitted model whose attributes are to be compared.

msg : str
        The error message to be used while raising the AssertionError if the
        fitted models' attributes are different.

Notes
-----
This check is not exhaustive since all attributes of the model are assumed
to end with ``_``. If that is not the case, it could lead to false
positives.
"""
est1_dict, est2_dict = estimator1.__dict__, estimator2.__dict__
    assert_array_equal(sorted(est1_dict.keys()), sorted(est2_dict.keys()),
                       "The attributes of both the estimators do not match.")

non_attributes = ("estimators_", "estimator_", "tree_", "base_estimator_",
"random_state_", "root_", "label_binarizer_", "loss_")
non_attr_suffixes = ("leaf_",)

for attr in est1_dict:
val1, val2 = est1_dict[attr], est2_dict[attr]

# Consider keys that end in ``_`` only as attributes.
if (attr.endswith('_') and attr not in non_attributes and
not attr.endswith(non_attr_suffixes)):
if msg is None:
msg = ("Attributes do not match. \nThe attribute, %s, in "
"estimator1,\n\n%r\n\n is %r and in estimator2,"
"\n\n%r\n\n is %r.\n") % (attr, estimator1, val1,
estimator2, val2)
assert_safe_sparse_allclose(val1, val2, msg=msg)
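

A short sketch of using this helper on its own; LinearRegression is only an illustrative choice:

from sklearn.datasets import make_blobs
from sklearn.linear_model import LinearRegression
from sklearn.utils.testing import assert_fitted_attributes_almost_equal

X, y = make_blobs(n_samples=50, centers=2, random_state=0)

# Two identically configured models fitted on the same data should end up
# with matching fitted attributes (coef_, intercept_, singular_, ...).
reg_1 = LinearRegression().fit(X, y)
reg_2 = LinearRegression().fit(X, y)
assert_fitted_attributes_almost_equal(reg_1, reg_2)
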


def fake_mldata(columns_dict, dataname, matfile, ordering=None):
"""Create a fake mldata data set.

@@ -465,7 +688,7 @@ def fake_mldata(columns_dict, dataname, matfile, ordering=None):
ordering = sorted(list(datasets.keys()))
# NOTE: setting up this array is tricky, because of the way Matlab
# re-packages 1D arrays
-    datasets['mldata_descr_ordering'] = sp.empty((1, len(ordering)),
+    datasets['mldata_descr_ordering'] = np.empty((1, len(ordering)),
                                                  dtype='object')
for i, name in enumerate(ordering):
datasets['mldata_descr_ordering'][0, i] = name