8000 TST Add tests for the new assert helpers · scikit-learn/scikit-learn@3341afb · GitHub
[go: up one dir, main page]

Skip to content

Commit 3341afb

Browse files
committed
TST Add tests for the new assert helpers
1 parent ee600c9 commit 3341afb

File tree

1 file changed

+52
-4
lines changed

1 file changed

+52
-4
lines changed

sklearn/utils/tests/test_testing.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import warnings
22
import unittest
33
import sys
4+
import numpy as np
45

56
from nose.tools import assert_raises
6-
77
from sklearn.utils.testing import (
88
_assert_less,
99
_assert_greater,
@@ -13,10 +13,15 @@
1313
assert_no_warnings,
1414
assert_equal,
1515
set_random_state,
16-
assert_raise_message)
17-
16+
assert_raise_message,
17+
assert_same_model,
18+
assert_not_same_model)
1819
from sklearn.tree import DecisionTreeClassifier
1920
from sklearn.lda import LDA
21+
from sklearn.qda import QDA
22+
from sklearn.datasets import make_blobs
23+
from sklearn.svm import LinearSVC
24+
from sklearn.cluster import KMeans
2025

2126
try:
2227
from nose.tools import assert_less
@@ -96,10 +101,53 @@ def _no_raise():
96101
"test", _no_raise)
97102

98103

104+
def test_assert_same_not_same_model():
105+
X1, y1 = make_blobs(n_samples=200, n_features=5, center_box=(-200, -150),
106+
centers=2, random_state=0)
107+
X2, y2 = make_blobs(n_samples=100, n_features=5, center_box=(-1, 1),
108+
centers=3, random_state=1)
109+
X3, y3 = make_blobs(n_samples=50, n_features=5, center_box=(-100, -50),
110+
centers=4, random_state=2)
111+
112+
# Checking both non-transductive and transductive algorithms
113+
# By testing for transductive algorithms we also eventually test
114+
# the assert_fitted_attributes_equal helper.
115+
for Estimator in (LinearSVC, KMeans):
116+
assert_same_model(X3, Estimator(random_state=0).fit(X1, y1),
117+
Estimator(random_state=0).fit(X1, y1))
118+
assert_raises(AssertionError, assert_not_same_model, X3,
119+
Estimator(random_state=0).fit(X1, y1),
120+
Estimator(random_state=0).fit(X1, y1))
121+
assert_raises(AssertionError, assert_same_model, X3,
122+
Estimator(random_state=0).fit(X1, y1),
123+
Estimator(random_state=0).fit(X2, y2))
124+
assert_not_same_model(X3, Estimator(random_state=0).fit(X1, y1),
125+
Estimator(random_state=0).fit(X2, y2))
126+
127+
128+
def test_qda_same_model():
129+
# NRT to make sure the rotations_ attribute is correctly compared
130+
X = np.array([[0, 0], [-2, -2], [-2, -1], [-1, -1], [-1, -2],
131+
[1, 3], [1, 2], [2, 1], [2, 2]])
132+
y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2])
133+
X1 = np.array([[-3, -1], [-2, 0], [-1, 0], [-11, 0], [0, 0], [1, 0],
134+
[1, 5], [2, 0], [3, 4]])
135+
y1 = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2])
136+
X2 = np.array([[-1, -3], [0, -2], [0, -1], [0, -5], [0, 0], [10, 1],
137+
[0, 11], [0, 22], [0, 33]])
138+
139+
clf1 = QDA().fit(X, y)
140+
clf2 = QDA().fit(X, y)
141+
assert_same_model(X1, clf1, clf2)
142+
143+
clf3 = QDA().fit(X1, y1)
144+
assert_not_same_model(X2, clf1, clf3)
145+
146+
99147
# This class is inspired from numpy 1.7 with an alteration to check
100148
# the reset warning filters after calls to assert_warns.
101149
# This assert_warns behavior is specific to scikit-learn because
102-
#`clean_warning_registry()` is called internally by assert_warns
150+
# `clean_warning_registry()` is called internally by assert_warns
103151
# and clears all previous filters.
104152
class TestWarns(unittest.TestCase):
105153
def test_warn(self):

0 commit comments

Comments
 (0)
0