|
1 | 1 | import warnings
|
2 | 2 | import unittest
|
3 | 3 | import sys
|
| 4 | +import numpy as np |
4 | 5 |
|
5 | 6 | from nose.tools import assert_raises
|
6 |
| - |
7 | 7 | from sklearn.utils.testing import (
|
8 | 8 | _assert_less,
|
9 | 9 | _assert_greater,
|
|
13 | 13 | assert_no_warnings,
|
14 | 14 | assert_equal,
|
15 | 15 | set_random_state,
|
16 |
| - assert_raise_message) |
17 |
| - |
| 16 | + assert_raise_message, |
| 17 | + assert_same_model, |
| 18 | + assert_not_same_model) |
18 | 19 | from sklearn.tree import DecisionTreeClassifier
|
19 | 20 | from sklearn.lda import LDA
|
| 21 | +from sklearn.qda import QDA |
| 22 | +from sklearn.datasets import make_blobs |
| 23 | +from sklearn.svm import LinearSVC |
| 24 | +from sklearn.cluster import KMeans |
20 | 25 |
|
21 | 26 | try:
|
22 | 27 | from nose.tools import assert_less
|
@@ -96,10 +101,53 @@ def _no_raise():
|
96 | 101 | "test", _no_raise)
|
97 | 102 |
|
98 | 103 |
|
| 104 | +def test_assert_same_not_same_model(): |
| 105 | + X1, y1 = make_blobs(n_samples=200, n_features=5, center_box=(-200, -150), |
| 106 | + centers=2, random_state=0) |
| 107 | + X2, y2 = make_blobs(n_samples=100, n_features=5, center_box=(-1, 1), |
| 108 | + centers=3, random_state=1) |
| 109 | + X3, y3 = make_blobs(n_samples=50, n_features=5, center_box=(-100, -50), |
| 110 | + centers=4, random_state=2) |
| 111 | + |
| 112 | + # Checking both non-transductive and transductive algorithms |
| 113 | + # By testing for transductive algorithms we also eventually test |
| 114 | + # the assert_fitted_attributes_equal helper. |
| 115 | + for Estimator in (LinearSVC, KMeans): |
| 116 | + assert_same_model(X3, Estimator(random_state=0).fit(X1, y1), |
| 117 | + Estimator(random_state=0).fit(X1, y1)) |
| 118 | + assert_raises(AssertionError, assert_not_same_model, X3, |
| 119 | + Estimator(random_state=0).fit(X1, y1), |
| 120 | + Estimator(random_state=0).fit(X1, y1)) |
| 121 | + assert_raises(AssertionError, assert_same_model, X3, |
| 122 | + Estimator(random_state=0).fit(X1, y1), |
| 123 | + Estimator(random_state=0).fit(X2, y2)) |
| 124 | + assert_not_same_model(X3, Estimator(random_state=0).fit(X1, y1), |
| 125 | + Estimator(random_state=0).fit(X2, y2)) |
| 126 | + |
| 127 | + |
| 128 | +def test_qda_same_model(): |
| 129 | + # NRT to make sure the rotations_ attribute is correctly compared |
| 130 | + X = np.array([[0, 0], [-2, -2], [-2, -1], [-1, -1], [-1, -2], |
| 131 | + [1, 3], [1, 2], [2, 1], [2, 2]]) |
| 132 | + y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2]) |
| 133 | + X1 = np.array([[-3, -1], [-2, 0], [-1, 0], [-11, 0], [0, 0], [1, 0], |
| 134 | + [1, 5], [2, 0], [3, 4]]) |
| 135 | + y1 = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2]) |
| 136 | + X2 = np.array([[-1, -3], [0, -2], [0, -1], [0, -5], [0, 0], [10, 1], |
| 137 | + [0, 11], [0, 22], [0, 33]]) |
| 138 | + |
| 139 | + clf1 = QDA().fit(X, y) |
| 140 | + clf2 = QDA().fit(X, y) |
| 141 | + assert_same_model(X1, clf1, clf2) |
| 142 | + |
| 143 | + clf3 = QDA().fit(X1, y1) |
| 144 | + assert_not_same_model(X2, clf1, clf3) |
| 145 | + |
| 146 | + |
99 | 147 | # This class is inspired from numpy 1.7 with an alteration to check
|
100 | 148 | # the reset warning filters after calls to assert_warns.
|
101 | 149 | # This assert_warns behavior is specific to scikit-learn because
|
102 |
| -#`clean_warning_registry()` is called internally by assert_warns |
| 150 | +# `clean_warning_registry()` is called internally by assert_warns |
103 | 151 | # and clears all previous filters.
|
104 | 152 | class TestWarns(unittest.TestCase):
|
105 | 153 | def test_warn(self):
|
|
0 commit comments