8000 ENH/TST Add helpers assert_{same_model|fitted_attributes_equal} · scikit-learn/scikit-learn@30ac135 · GitHub
[go: up one dir, main page]

Skip to content

Commit 30ac135

Browse files
committed
ENH/TST Add helpers assert_{same_model|fitted_attributes_equal}
1 parent 32b2f8e commit 30ac135

File tree

1 file changed

+78
-5
lines changed

1 file changed

+78
-5
lines changed

sklearn/utils/testing.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import scipy as sp
2020
import scipy.io
21+
import numpy as np
22+
2123
from functools import wraps
2224
try:
2325
# Python 2
@@ -45,7 +47,6 @@
4547
from numpy.testing import assert_array_equal
4648
from numpy.testing import assert_array_almost_equal
4749
from numpy.testing import assert_array_less
48-
import numpy as np
4950

5051
from sklearn.base import (ClassifierMixin, RegressorMixin, TransformerMixin,
5152
ClusterMixin)
@@ -54,8 +55,9 @@
5455
"assert_raises_regexp", "raises", "with_setup", "assert_true",
5556
"assert_false", "assert_almost_equal", "assert_array_equal",
5657
"assert_array_almost_equal", "assert_array_less",
57-
"assert_less", "assert_less_equal",
58-
"assert_greater", "assert_greater_equal"]
58+
"assert_less", "assert_less_equal", "assert_greater",
59+
"assert_greater_equal",
60+
"assert_same_model", "assert_fitted_attributes_equal"]
5961

6062

6163
try:
@@ -422,6 +424,77 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs):
422424
(names, function.__name__))
423425

424426

427+
def _assert_same_model_method(method, X, estimator1, estimator2, msg=None):
428+
if hasattr(estimator1, method):
429+
m = '%r\n\nhas %s, but\n\n%r\n\ndoes not' % (estimator1,
430+
method,
431+
estimator2)
432+
433+
if not hasattr(estimator2, method):
434+
raise AttributeError(m)
435+
436+
# Check if the method(X) returns the same for both models.
437+
res1 = getattr(estimator1, method)(X)
438+
res2 = getattr(estimator2, method)(X)
439+
same_model = (res1.shape == res2.shape) and np.allclose(res1, res2)
440+
441+
if msg is None:
442+
msg = ("Models are not equal. \n\n%s method returned"
443+
" different results:\n\n%r\n\n and\n\n%r\n\n for :\n\n%r"
444+
"\n\nand :\n\n%r" % (method, res1, res2,
445+
estimator1, estimator2))
446+
447+
assert same_model, msg
448+
449+
450+
def assert_same_model(X, estimator1, estimator2, msg=None):
451+
"""Helper function to check if models are same"""
452+
_assert_same_model_method('predict', X, estimator1, estimator2, msg)
453+
_assert_same_model_method('transform', X, estimator1, estimator2, msg)
454+
_assert_same_model_method('decision_function',
455+
X, estimator1, estimator2, msg)
456+
_assert_same_model_method('predict_proba',
457+
X, estimator1, estimator2, msg)
458+
459+
assert_fitted_attributes_equal(estimator1, estimator2)
460+
461+
462+
def assert_not_same_model(X, estimator1, estimator2, msg=None):
463+
"""Helper function to check if models are different"""
464+
try:
465+
assert_same_model(X, estimator1, estimator2)
466+
except AssertionError:
467+
pass
468+
else:
469+
raise AssertionError(msg)
470+
471+
472+
def assert_fitted_attributes_equal(estimator1, estimator2):
473+
"""Helper function to check if fitted model attributes are equal."""
474+
# A list of attributes which are known to be inconsistent.
475+
# FIXME embedding_ to be removed after #4299
476+
skip_attributes = ('embedding_',)
477+
478+
est1_dict = estimator1.__dict__.copy()
479+
est2_dict = estimator2.__dict__.copy()
480+
481+
# Remove all keys that are non-attributes, are not comparable types
482+
# and those which are found to be inconsistent
483+
for attr, value1 in est1_dict.items():
484+
if ((not attr.endswith('_')) or attr.endswith('estimator_') or
485+
attr.endswith('estimators_') or
486+
attr in skip_attributes):
487+
est1_dict.pop(attr)
488+
try:
489+
est2_dict.pop(attr)
490+
except KeyError: # Incase the attribute is not present in est2
491+
pass
492+
493+
# assert_equal is capable of recursively checking for all the items
494+
# of the two dicts
495+
assert_equal(est1_dict, est2_dict)
496+
497+
425498
def fake_mldata(columns_dict, dataname, matfile, ordering=None):
426499
"""Create a fake mldata data set.
427500
@@ -628,8 +701,8 @@ def is_abstract(c):
628701
estimators = filtered_estimators
629702
if type_filter:
630703
raise ValueError("Parameter type_filter must be 'classifier', "
631-
"'regressor', 'transformer', 'cluster' or None, got"
632-
" %s." % repr(type_filter))
704+
"'regressor', 'transformer', 'cluster' or None,"
705+
"got %s." % repr(type_filter))
633706

634707
# drop duplicates, sort for reproducibility
635708
return sorted(set(estimators))

0 commit comments

Comments
 (0)
0