[MRG] MNT rename min_cluster_size_ratio to min_cluster_size (#11913) · scikit-learn/scikit-learn@73e034e · GitHub
[go: up one dir, main page]

Skip to content

Commit 73e034e

Browse files
committed
[MRG] MNT rename min_cluster_size_ratio to min_cluster_size (#11913)
1 parent 09edcfd commit 73e034e

File tree

2 files changed

+77
-58
lines changed

2 files changed

+77
-58
lines changed

sklearn/cluster/optics_.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def optics(X, min_samples=5, max_eps=np.inf, metric='euclidean',
2525
p=2, metric_params=None, maxima_ratio=.75,
2626
rejection_ratio=.7, similarity_threshold=0.4,
27-
significant_min=.003, min_cluster_size_ratio=.005,
27+
significant_min=.003, min_cluster_size=.005,
2828
min_maxima_ratio=0.001, algorithm='ball_tree',
2929
leaf_size=30, n_jobs=None):
3030
"""Perform OPTICS clustering from vector array
@@ -93,8 +93,10 @@ def optics(X, min_samples=5, max_eps=np.inf, metric='euclidean',
9393
significant_min : float, optional (default=.003)
9494
Sets a lower threshold on how small a significant maxima can be.
9595
96-
min_cluster_size_ratio : float, optional (default=.005)
97-
Minimum percentage of dataset expected for cluster membership.
96+
min_cluster_size : int > 1 or float between 0 and 1 (default=0.005)
97+
Minimum number of samples in an OPTICS cluster, expressed as an
98+
absolute number or a fraction of the number of samples (rounded
99+
to be at least 2).
98100
99101
min_maxima_ratio : float, optional (default=.001)
100102
Used to determine neighborhood size for minimum cluster membership.
@@ -151,7 +153,7 @@ def optics(X, min_samples=5, max_eps=np.inf, metric='euclidean',
151153
clust = OPTICS(min_samples, max_eps, metric, p, metric_params,
152154
maxima_ratio, rejection_ratio,
153155
similarity_threshold, significant_min,
154-
min_cluster_size_ratio, min_maxima_ratio,
156+
min_cluster_size, min_maxima_ratio,
155157
algorithm, leaf_size, n_jobs)
156158
clust.fit(X)
157159
return clust.core_sample_indices_, clust.labels_
@@ -221,8 +223,10 @@ class OPTICS(BaseEstimator, ClusterMixin):
221223
significant_min : float, optional (default=.003)
222224
Sets a lower threshold on how small a significant maxima can be.
223225
224-
min_cluster_size_ratio : float, optional (default=.005)
225-
Minimum percentage of dataset expected for cluster membership.
226+
min_cluster_size : int > 1 or float between 0 and 1 (default=0.005)
227+
Minimum number of samples in an OPTICS cluster, expressed as an
228+
absolute number or a fraction of the number of samples (rounded
229+
to be at least 2).
226230
227231
min_maxima_ratio : float, optional (default=.001)
228232
Used to determine neighborhood size for minimum cluster membership.
@@ -289,7 +293,7 @@ class OPTICS(BaseEstimator, ClusterMixin):
289293
def __init__(self, min_samples=5, max_eps=np.inf, metric='euclidean',
290294
p=2, metric_params=None, maxima_ratio=.75,
291295
rejection_ratio=.7, similarity_threshold=0.4,
292-
significant_min=.003, min_cluster_size_ratio=.005,
296+
significant_min=.003, min_cluster_size=.005,
293297
min_maxima_ratio=0.001, algorithm='ball_tree',
294298
leaf_size=30, n_jobs=None):
295299

@@ -299,7 +303,7 @@ def __init__(self, min_samples=5, max_eps=np.inf, metric='euclidean',
299303
self.rejection_ratio = rejection_ratio
300304
self.similarity_threshold = similarity_threshold
301305
self.significant_min = significant_min
302-
self.min_cluster_size_ratio = min_cluster_size_ratio
306+
self.min_cluster_size = min_cluster_size
303307
self.min_maxima_ratio = min_maxima_ratio
304308
self.algorithm = algorithm
305309
self.metric = metric
@@ -330,6 +334,24 @@ def fit(self, X, y=None):
330334
X = check_array(X, dtype=np.float)
331335

332336
n_samples = len(X)
337+
338+
if self.min_samples > n_samples:
339+
raise ValueError("Number of training samples (n_samples=%d) must "
340+
"be greater than min_samples (min_samples=%d) "
341+
"used for clustering." %
342+
(n_samples, self.min_samples))
343+
344+
if self.min_cluster_size <= 0 or (self.min_cluster_size !=
345+
int(self.min_cluster_size)
346+
and self.min_cluster_size > 1):
347+
raise ValueError('min_cluster_size must be a positive integer or '
348+
'a float between 0 and 1. Got %r' %
349+
self.min_cluster_size)
350+
elif self.min_cluster_size > n_samples:
351+
raise ValueError('min_cluster_size must be no greater than the '
352+
'number of samples (%d). Got %d' %
353+
(n_samples, self.min_cluster_size))
354+
333355
# Start all points as 'unprocessed' ##
334356
self.reachability_ = np.empty(n_samples)
335357
self.reachability_.fill(np.inf)
@@ -338,13 +360,6 @@ def fit(self, X, y=None):
338360
# Start all points as noise ##
339361
self.labels_ = np.full(n_samples, -1, dtype=int)
340362

341-
# Check for valid n_samples relative to min_samples
342-
if self.min_samples > n_samples:
343-
raise ValueError("Number of training samples (n_samples=%d) must "
344-
"be greater than min_samples (min_samples=%d) "
345-
"used for clustering." %
346-
(n_samples, self.min_samples))
347-
348363
nbrs = NearestNeighbors(n_neighbors=self.min_samples,
349364
algorithm=self.algorithm,
350365
leaf_size=self.leaf_size, metric=self.metric,
@@ -363,7 +378,7 @@ def fit(self, X, y=None):
363378
self.rejection_ratio,
364379
self.similarity_threshold,
365380
self.significant_min,
366-
self.min_cluster_size_ratio,
381+
self.min_cluster_size,
367382
self.min_maxima_ratio)
368383
self.core_sample_indices_ = indices_
369384
return self
@@ -492,7 +507,7 @@ def _extract_dbscan(ordering, core_distances, reachability, eps):
492507

493508
def _extract_optics(ordering, reachability, maxima_ratio=.75,
494509
rejection_ratio=.7, similarity_threshold=0.4,
495-
significant_min=.003, min_cluster_size_ratio=.005,
510+
significant_min=.003, min_cluster_size=.005,
496511
min_maxima_ratio=0.001):
497512
"""Performs automatic cluster extraction for variable density data.
498513
@@ -530,8 +545,10 @@ def _extract_optics(ordering, reachability, maxima_ratio=.75,
530545
significant_min : float, optional
531546
Sets a lower threshold on how small a significant maxima can be.
532547
533-
min_cluster_size_ratio : float, optional
534-
Minimum percentage of dataset expected for cluster membership.
548+
min_cluster_size : int > 1 or float between 0 and 1
549+
Minimum number of samples in an OPTICS cluster, expressed as an
550+
absolute number or a fraction of the number of samples (rounded
551+
to be at least 2).
535552
536553
min_maxima_ratio : float, optional
537554
Used to determine neighborhood size for minimum cluster membership.
@@ -551,7 +568,7 @@ def _extract_optics(ordering, reachability, maxima_ratio=.75,
551568
root_node = _automatic_cluster(reachability_plot, ordering,
552569
maxima_ratio, rejection_ratio,
553570
similarity_threshold, significant_min,
554-
min_cluster_size_ratio, min_maxima_ratio)
571+
min_cluster_size, min_maxima_ratio)
555572
leaves = _get_leaves(root_node, [])
556573
# Start cluster id's at 0
557574
clustid = 0
@@ -570,7 +587,7 @@ def _extract_optics(ordering, reachability, maxima_ratio=.75,
570587
def _automatic_cluster(reachability_plot, ordering,
571588
maxima_ratio, rejection_ratio,
572589
similarity_threshold, significant_min,
573-
min_cluster_size_ratio, min_maxima_ratio):
590+
min_cluster_size, min_maxima_ratio):
574591
"""Converts reachability plot to cluster tree and returns root node.
575592
576593
Parameters
@@ -582,13 +599,10 @@ def _automatic_cluster(reachability_plot, ordering,
582599
"""
583600

584601
min_neighborhood_size = 2
585-
min_cluster_size = int(min_cluster_size_ratio * len(ordering))
602+
if min_cluster_size <= 1:
603+
min_cluster_size = max(2, min_cluster_size * len(ordering))
586604
neighborhood_size = int(min_maxima_ratio * len(ordering))
587605

588-
# Should this check for < min_samples? Should this be public?
589-
if min_cluster_size < 5:
590-
min_cluster_size = 5
591-
592606
# Again, should this check < min_samples, should the parameter be public?
593607
if neighborhood_size < min_neighborhood_size:
594608
neighborhood_size = min_neighborhood_size

sklearn/cluster/tests/test_optics.py

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Amy X. Zhang <axz@mit.edu>
33
# License: BSD 3 clause
44

5+
from __future__ import print_function, division
56
import numpy as np
67
import pytest
78

@@ -20,6 +21,17 @@
2021
from sklearn.cluster.tests.common import generate_clustered_data
2122

2223

24+
rng = np.random.RandomState(0)
25+
n_points_per_cluster = 250
26+
C1 = [-5, -2] + .8 * rng.randn(n_points_per_cluster, 2)
27+
C2 = [4, -1] + .1 * rng.randn(n_points_per_cluster, 2)
28+
C3 = [1, -2] + .2 * rng.randn(n_points_per_cluster, 2)
29+
C4 = [-2, 3] + .3 * rng.randn(n_points_per_cluster, 2)
30+
C5 = [3, -2] + 1.6 * rng.randn(n_points_per_cluster, 2)
31+
C6 = [5, 6] + 2 * rng.randn(n_points_per_cluster, 2)
32+
X = np.vstack((C1, C2, C3, C4, C5, C6))
33+
34+
2335
def test_correct_number_of_clusters():
2436
# in 'auto' mode
2537

@@ -135,27 +147,36 @@ def test_dbscan_optics_parity(eps, min_samples):
135147

136148
def test_auto_extract_hier():
137149
# Tests auto extraction gets correct # of clusters with varying density
150+
clust = OPTICS(min_samples=9).fit(X)
151+
assert_equal(len(set(clust.labels_)), 6)
138152

139-
# Generate sample data
140-
rng = np.random.RandomState(0)
141-
n_points_per_cluster = 250
142153

143-
C1 = [-5, -2] + .8 * rng.randn(n_points_per_cluster, 2)
144-
C2 = [4, -1] + .1 * rng.randn(n_points_per_cluster, 2)
145-
C3 = [1, -2] + .2 * rng.randn(n_points_per_cluster, 2)
146-
C4 = [-2, 3] + .3 * rng.randn(n_points_per_cluster, 2)
147-
C5 = [3, -2] + 1.6 * rng.randn(n_points_per_cluster, 2)
148-
C6 = [5, 6] + 2 * rng.randn(n_points_per_cluster, 2)
149-
X = np.vstack((C1, C2, C3, C4, C5, C6))
154+
# try arbitrary minimum sizes
155+
@pytest.mark.parametrize('min_cluster_size', range(2, X.shape[0] // 10, 23))
156+
def test_min_cluster_size(min_cluster_size):
157+
redX = X[::10] # reduce for speed
158+
clust = OPTICS(min_samples=9, min_cluster_size=min_cluster_size).fit(redX)
159+
cluster_sizes = np.bincount(clust.labels_[clust.labels_ != -1])
160+
if cluster_sizes.size:
161+
assert min(cluster_sizes) >= min_cluster_size
162+
# check behaviour is the same when min_cluster_size is a fraction
163+
clust_frac = OPTICS(min_samples=9,
164+
min_cluster_size=min_cluster_size / redX.shape[0])
165+
clust_frac.fit(redX)
166+
assert_array_equal(clust.labels_, clust_frac.labels_)
150167

151-
# Compute OPTICS
152168

153-
clust = OPTICS(min_samples=9)
169+
@pytest.mark.parametrize('min_cluster_size', [0, -1, 1.1, 2.2])
170+
def test_min_cluster_size_invalid(min_cluster_size):
171+
clust = OPTICS(min_cluster_size=min_cluster_size)
172+
with pytest.raises(ValueError, match="must be a positive integer or a "):
173+
clust.fit(X)
154174

155-
# Run the fit
156-
clust.fit(X)
157175

158-
assert_equal(len(set(clust.labels_)), 6)
176+
def test_min_cluster_size_invalid2():
177+
clust = OPTICS(min_cluster_size=len(X) + 1)
178+
with pytest.raises(ValueError, match="must be no greater than the "):
179+
clust.fit(X)
159180

160181

161182
@pytest.mark.parametrize("reach, n_child, members", [
@@ -187,23 +208,7 @@ def test_cluster_sigmin_pruning(reach, n_child, members):
187208
def test_reach_dists():
188209
# Tests against known extraction array
189210

190-
rng = np.random.RandomState(0)
191-
n_points_per_cluster = 250
192-
193-
C1 = [-5, -2] + .8 * rng.randn(n_points_per_cluster, 2)
194-
C2 = [4, -1] + .1 * rng.randn(n_points_per_cluster, 2)
195-
C3 = [1, -2] + .2 * rng.randn(n_points_per_cluster, 2)
196-
C4 = [-2, 3] + .3 * rng.randn(n_points_per_cluster, 2)
197-
C5 = [3, -2] + 1.6 * rng.randn(n_points_per_cluster, 2)
198-
C6 = [5, 6] + 2 * rng.randn(n_points_per_cluster, 2)
199-
X = np.vstack((C1, C2, C3, C4, C5, C6))
200-
201-
# Compute OPTICS
202-
203-
clust = OPTICS(min_samples=10, metric='minkowski')
204-
205-
# Run the fit
206-
clust.fit(X)
211+
clust = OPTICS(min_samples=10, metric='minkowski').fit(X)
207212

208213
# Expected values, matches 'RD' results from:
209214
# http://chemometria.us.edu.pl/download/optics.py

0 commit comments

Comments
 (0)
0