|
1 | 1 | import warnings
|
2 | 2 | import unittest
|
3 | 3 | import sys
|
| 4 | +import numpy as np |
4 | 5 |
|
5 | 6 | from nose.tools import assert_raises
|
6 |
| - |
7 | 7 | from sklearn.utils.testing import (
|
8 | 8 | _assert_less,
|
9 | 9 | _assert_greater,
|
|
13 | 13 | assert_no_warnings,
|
14 | 14 | assert_equal,
|
15 | 15 | set_random_state,
|
16 |
| - assert_raise_message) |
17 |
| - |
| 16 | + assert_raise_message, |
| 17 | + assert_same_model, |
| 18 | + assert_not_same_model) |
18 | 19 | from sklearn.tree import DecisionTreeClassifier
|
19 | 20 | from sklearn.lda import LDA
|
| 21 | +from sklearn.qda import QDA |
| 22 | +from sklearn.datasets import make_blobs |
| 23 | +from sklearn.svm import LinearSVC |
| 24 | +from sklearn.cluster import KMeans |
20 | 25 |
8000
|
21 | 26 | try:
|
22 | 27 | from nose.tools import assert_less
|
@@ -96,10 +101,50 @@ def _no_raise():
|
96 | 101 | "test", _no_raise)
|
97 | 102 |
|
98 | 103 |
|
| 104 | +def test_assert_same_not_same_model(): |
| 105 | + X1, y1 = make_blobs(n_samples=200, n_features=5, center_box=(-200, -150), |
| 106 | + centers=2, random_state=0) |
| 107 | + X2, y2 = make_blobs(n_samples=100, n_features=5, center_box=(-1, 1), |
| 108 | + centers=3, random_state=1) |
| 109 | + X3, y3 = make_blobs(n_samples=50, n_features=5, center_box=(-100, -50), |
| 110 | + centers=4, random_state=2) |
| 111 | + |
| 112 | + for Estimator in (LinearSVC, KMeans): |
| 113 | + assert_same_model(X3, Estimator(random_state=0).fit(X1, y1), |
| 114 | + Estimator(random_state=0).fit(X1, y1)) |
| 115 | + |
| 116 | + assert_raises(AssertionError, assert_not_same_model, X3, |
| 117 | + Estimator(random_state=0).fit(X1, y1), |
| 118 | + Estimator(random_state=0).fit(X1, y1)) |
| 119 | + |
| 120 | + assert_raises(AssertionError, assert_same_model, X3, |
| 121 | + Estimator().fit(X1, y1), Estimator().fit(X2, y2)) |
| 122 | + |
| 123 | + assert_not_same_model(X3, Estimator().fit(X1, y1), |
| 124 | + Estimator().fit(X2, y2)) |
| 125 | + |
| 126 | + |
| 127 | +def test_qda_same_model(): |
| 128 | + # NRT to make sure the rotations_ attribute is correctly compared |
| 129 | + X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) |
| 130 | + y = np.array([1, 1, 1, 2, 2, 2]) |
| 131 | + X2 = np.array([[-9, -9], [-20, -10], [-2, -2], [1, 1], [20, 10], [30, 20]]) |
| 132 | + y2 = np.array([1, 1, 2, 2, 3, 3]) |
| 133 | + |
| 134 | + clf1 = QDA().fit(X, y) |
| 135 | + clf2 = QDA().fit(X, y) |
| 136 | + assert_same_model(X2, clf1, clf2) |
| 137 | + |
| 138 | + X2 = np.array([[-9, -9], [-20, -10], [-2, -2], [1, 1], [20, 10], [30, 20]]) |
| 139 | + y2 = np.array([1, 1, 2, 2, 3, 3]) |
| 140 | + clf3 = QDA().fit(X2, y2) |
| 141 | + assert_not_same_model(X + X2, clf1, clf3) |
| 142 | + |
| 143 | + |
99 | 144 | # This class is inspired from numpy 1.7 with an alteration to check
|
100 | 145 | # the reset warning filters after calls to assert_warns.
|
101 | 146 | # This assert_warns behavior is specific to scikit-learn because
|
102 |
| -#`clean_warning_registry()` is called internally by assert_warns |
| 147 | +# `clean_warning_registry()` is called internally by assert_warns |
103 | 148 | # and clears all previous filters.
|
104 | 149 | class TestWarns(unittest.TestCase):
|
105 | 150 | def test_warn(self):
|
|
0 commit comments