|
12 | 12 | from sklearn.utils.testing import assert_equal
|
13 | 13 | from sklearn.utils.testing import assert_array_equal
|
14 | 14 | from sklearn.utils.testing import assert_raises
|
| 15 | +from sklearn.utils.testing import assert_in |
| 16 | +from sklearn.utils.testing import assert_not_in |
15 | 17 | from sklearn.cluster.dbscan_ import DBSCAN
|
16 | 18 | from sklearn.cluster.dbscan_ import dbscan
|
17 | 19 | from sklearn.cluster.tests.common import generate_clustered_data
|
@@ -185,33 +187,44 @@ def test_pickle():
|
185 | 187 | assert_equal(type(pickle.loads(s)), obj.__class__)
|
186 | 188 |
|
187 | 189 |
|
| 190 | +def test_boundaries(): |
| 191 | + # ensure a point counts itself toward min_samples when deciding if it is core |
| 192 | + core, _ = dbscan([[0], [1]], eps=2, min_samples=2) |
| 193 | + assert_in(0, core) |
| 194 | + # ensure eps is inclusive: neighbors at exactly distance eps are counted |
| 195 | + core, _ = dbscan([[0], [1], [1]], eps=1, min_samples=2) |
| 196 | + assert_in(0, core) |
| 197 | + core, _ = dbscan([[0], [1], [1]], eps=.99, min_samples=2) |
| 198 | + assert_not_in(0, core) |
| 199 | + |
| 200 | + |
188 | 201 | def test_weighted_dbscan():
|
189 | 202 | # ensure sample_weight is validated
|
190 | 203 | assert_raises(ValueError, dbscan, [[0], [1]], sample_weight=[2])
|
191 | 204 | assert_raises(ValueError, dbscan, [[0], [1]], sample_weight=[2, 3, 4])
|
192 | 205 |
|
193 | 206 | # ensure sample_weight has an effect
|
194 | 207 | assert_array_equal([], dbscan([[0], [1]], sample_weight=None,
|
195 |
| - min_samples=5)[0]) |
| 208 | + min_samples=6)[0]) |
196 | 209 | assert_array_equal([], dbscan([[0], [1]], sample_weight=[5, 5],
|
197 |
| - min_samples=5)[0]) |
| 210 | + min_samples=6)[0]) |
198 | 211 | assert_array_equal([0], dbscan([[0], [1]], sample_weight=[6, 5],
|
199 |
| - min_samples=5)[0]) |
| 212 | + min_samples=6)[0]) |
200 | 213 | assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[6, 6],
|
201 |
| - min_samples=5)[0]) |
| 214 | + min_samples=6)[0]) |
202 | 215 |
|
203 | 216 | # points within eps of each other:
|
204 | 217 | assert_array_equal([0, 1], dbscan([[0], [1]], eps=1.5,
|
205 |
| - sample_weight=[5, 1], min_samples=5)[0]) |
| 218 | + sample_weight=[5, 1], min_samples=6)[0]) |
206 | 219 | # and effect of non-positive and non-integer sample_weight:
|
207 | 220 | assert_array_equal([], dbscan([[0], [1]], sample_weight=[5, 0],
|
208 |
| - eps=1.5, min_samples=5)[0]) |
209 |
| - assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[5, 0.1], |
210 |
| - eps=1.5, min_samples=5)[0]) |
| 221 | + eps=1.5, min_samples=6)[0]) |
| 222 | + assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[5.9, 0.1], |
| 223 | + eps=1.5, min_samples=6)[0]) |
211 | 224 | assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[6, 0],
|
212 |
| - eps=1.5, min_samples=5)[0]) |
| 225 | + eps=1.5, min_samples=6)[0]) |
213 | 226 | assert_array_equal([], dbscan([[0], [1]], sample_weight=[6, -1],
|
214 |
| - eps=1.5, min_samples=5)[0]) |
| 227 | + eps=1.5, min_samples=6)[0]) |
215 | 228 |
|
216 | 229 | # for non-negative sample_weight, cores should be identical to repetition
|
217 | 230 | rng = np.random.RandomState(42)
|
|
0 commit comments