8000 TST Add tests for the new assert helpers · scikit-learn/scikit-learn@3b8e8f7 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3b8e8f7

Browse files
committed
TST Add tests for the new assert helpers
1 parent 30ac135 commit 3b8e8f7

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

sklearn/utils/tests/test_testing.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import warnings
22
import unittest
33
import sys
4+
import numpy as np
45

56
from nose.tools import assert_raises
6-
77
from sklearn.utils.testing import (
88
_assert_less,
99
_assert_greater,
@@ -13,10 +13,15 @@
1313
assert_no_warnings,
1414
assert_equal,
1515
set_random_state,
16-
assert_raise_message)
17-
16+
assert_raise_message,
17+
assert_same_model,
18+
assert_not_same_model)
1819
from sklearn.tree import DecisionTreeClassifier
1920
from sklearn.lda import LDA
21+
from sklearn.qda import QDA
22+
from sklearn.datasets import make_blobs
23+
from sklearn.svm import LinearSVC
24+
from sklearn.cluster import KMeans
2025 8000

2126
try:
2227
from nose.tools import assert_less
@@ -96,10 +101,50 @@ def _no_raise():
96101
"test", _no_raise)
97102

98103

104+
def test_assert_same_not_same_model():
105+
X1, y1 = make_blobs(n_samples=200, n_features=5, center_box=(-200, -150),
106+
centers=2, random_state=0)
107+
X2, y2 = make_blobs(n_samples=100, n_features=5, center_box=(-1, 1),
108+
centers=3, random_state=1)
109+
X3, y3 = make_blobs(n_samples=50, n_features=5, center_box=(-100, -50),
110+
centers=4, random_state=2)
111+
112+
for Estimator in (LinearSVC, KMeans):
113+
assert_same_model(X3, Estimator(random_state=0).fit(X1, y1),
114+
Estimator(random_state=0).fit(X1, y1))
115+
116+
assert_raises(AssertionError, assert_not_same_model, X3,
117+
Estimator(random_state=0).fit(X1, y1),
118+
Estimator(random_state=0).fit(X1, y1))
119+
120+
assert_raises(AssertionError, assert_same_model, X3,
121+
Estimator().fit(X1, y1), Estimator().fit(X2, y2))
122+
123+
assert_not_same_model(X3, Estimator().fit(X1, y1),
124+
Estimator().fit(X2, y2))
125+
126+
127+
def test_qda_same_model():
128+
# NRT to make sure the rotations_ attribute is correctly compared
129+
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
130+
y = np.array([1, 1, 1, 2, 2, 2])
131+
X2 = np.array([[-9, -9], [-20, -10], [-2, -2], [1, 1], [20, 10], [30, 20]])
132+
y2 = np.array([1, 1, 2, 2, 3, 3])
133+
134+
clf1 = QDA().fit(X, y)
135+
clf2 = QDA().fit(X, y)
136+
assert_same_model(X2, clf1, clf2)
137+
138+
X2 = np.array([[-9, -9], [-20, -10], [-2, -2], [1, 1], [20, 10], [30, 20]])
139+
y2 = np.array([1, 1, 2, 2, 3, 3])
140+
clf3 = QDA().fit(X2, y2)
141+
assert_not_same_model(X + X2, clf1, clf3)
142+
143+
99144
# This class is inspired from numpy 1.7 with an alteration to check
100145
# the reset warning filters after calls to assert_warns.
101146
# This assert_warns behavior is specific to scikit-learn because
102-
#`clean_warning_registry()` is called internally by assert_warns
147+
# `clean_warning_registry()` is called internally by assert_warns
103148
# and clears all previous filters.
104149
class TestWarns(unittest.TestCase):
105150
def test_warn(self):

0 commit comments

Comments
 (0)
0