added tests for weighted kmeans/minibatch · scikit-learn/scikit-learn@820549d · GitHub
[go: up one dir, main page]

Skip to content

Commit 820549d

Browse files
author
bhsu
committed
added tests for weighted kmeans/minibatch
fixed typos, fix compatibility issues
1 parent a54fe7f commit 820549d

File tree

2 files changed

+111
-13
lines changed

2 files changed

+111
-13
lines changed

sklearn/cluster/k_means_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto',
273273
raise ValueError("Invalid number of initializations."
274274
" n_init=%d must be bigger than zero." % n_init)
275275
random_state = check_random_state(random_state)
276-
sample_weight = check_sample_weight(X, sampled_weight)
276+
sample_weight = check_sample_weight(X, sample_weight)
277277
best_inertia = np.infty
278278
X = as_float_array(X, copy=copy_x)
279279
tol = _tolerance(X, tol)

sklearn/cluster/tests/test_k_means.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from sklearn.cluster import KMeans, k_means
2323
from sklearn.cluster import MiniBatchKMeans
2424
from sklearn.cluster.k_means_ import _labels_inertia
25+
from sklearn.cluster.k_means_ import _kmeans_single
26+
from sklearn.cluster.k_means_ import _k_init
2527
from sklearn.cluster.k_means_ import _mini_batch_step
2628
from sklearn.datasets.samples_generator import make_blobs
2729
from sklearn.externals.six.moves import cStringIO as StringIO
@@ -38,6 +40,99 @@
3840
X, true_labels = make_blobs(n_samples=n_samples, centers=centers,
3941
cluster_std=1., random_state=42)
4042
X_csr = sp.csr_matrix(X)
43+
sample_unweight = np.ones((n_samples,))
44+
45+
46+
def test_weighted_kmeans():
    """Unit sample weights must not change KMeans predictions.

    Three estimators are fitted on the first 75 samples with the same
    ``random_state``: one without weights, one with all-ones weights on
    dense input, and one with all-ones weights on sparse (CSR) input.
    Each must predict the same labels for the remaining samples.
    """
    X_train, X_test = X[:75], X[75:]
    unit_weight = sample_unweight[:75]

    clf_a = KMeans(random_state=42).fit(X_train)
    unweighted_labels = clf_a.predict(X_test)

    clf_b = KMeans(random_state=42).fit(X_train, sample_weight=unit_weight)
    weighted_labels = clf_b.predict(X_test)
    assert_array_equal(weighted_labels, unweighted_labels)

    # Predict with the estimator fitted on the sparse input. Using the
    # dense-fitted clf_b here would leave the sparse weighted code path
    # completely unchecked.
    clf_c = KMeans(random_state=42).fit(X_csr[:75],
                                        sample_weight=unit_weight)
    sparse_weighted_labels = clf_c.predict(X_test)
    assert_array_equal(sparse_weighted_labels, unweighted_labels)
56+
57+
58+
def test_weighted_kmeans_weight_flag():
    """``KMeans.fit`` must reject malformed ``sample_weight`` arguments.

    Checked, in order: a bare scalar (TypeError), a 1-D array with one
    entry too many (ValueError) and a 2-D column array (ValueError).
    """
    estimator = KMeans()
    invalid_cases = [
        (TypeError, 1),
        (ValueError, np.ones(n_samples + 1)),
        (ValueError, np.ones((n_samples, 1))),
    ]
    for expected_error, bad_weight in invalid_cases:
        assert_raises(expected_error, estimator.fit, X,
                      sample_weight=bad_weight)
63+
64+
65+
def test_weighted_kmeans_single():
    """Constant sample weights must only rescale the inertia (dense input).

    Starting from identical initial conditions (``init='random'`` with the
    same seed), a run with all-ones weights and a run with a constant
    weight of ``scale`` must give the same labels and centers, while the
    total inertia is multiplied by ``scale``.
    """
    scale = 2
    x_squared_norms = row_norms(X, squared=True)
    scaled_weight = scale * sample_unweight

    def run_single(weight, **extra):
        # Every run shares the data, init strategy and seed; only the
        # weights (and optionally the distance strategy) vary.
        return _kmeans_single(X, n_clusters, x_squared_norms, init='random',
                              random_state=42, sample_weight=weight, **extra)

    labels_a, inertia_a, centers_a, _ = run_single(sample_unweight)
    labels_b, inertia_b, centers_b, _ = run_single(scaled_weight)
    assert_array_equal(labels_a, labels_b)
    assert_equal(inertia_b, scale * inertia_a)
    assert_array_almost_equal(centers_a, centers_b)

    # Tests cython implementation of assign_labels
    labels_c, inertia_c, centers_c, _ = run_single(
        scaled_weight, precompute_distances=False)
    assert_array_equal(labels_c, labels_b)
    assert_array_almost_equal(centers_c, centers_b)
    assert_almost_equal(inertia_b, inertia_c)
92+
93+
94+
def test_weighted_kmeans_single_sparse():
    """Constant sample weights must only rescale the inertia (CSR input).

    Sparse counterpart of the dense check: with identical initial
    conditions, all-ones weights and a constant weight of ``scale`` must
    give the same labels and centers, and the inertia must scale with the
    weight.
    """
    scale = 2
    x_squared_norms = row_norms(X_csr, squared=True)
    scaled_weight = scale * sample_unweight

    def run_single(weight, **extra):
        # Every run shares the data, init strategy and seed; only the
        # weights (and optionally the distance strategy) vary.
        return _kmeans_single(X_csr, n_clusters, x_squared_norms,
                              init='random', random_state=42,
                              sample_weight=weight, **extra)

    labels_a, inertia_a, centers_a, _ = run_single(sample_unweight)
    labels_b, inertia_b, centers_b, _ = run_single(scaled_weight)
    assert_array_equal(labels_a, labels_b)
    assert_almost_equal(inertia_b, scale * inertia_a)
    assert_array_almost_equal(centers_a, centers_b)

    # Tests cython implementation of assign_labels
    labels_c, inertia_c, centers_c, _ = run_single(
        scaled_weight, precompute_distances=False)
    assert_array_equal(labels_c, labels_b)
    assert_array_almost_equal(centers_c, centers_b)
    assert_almost_equal(inertia_b, inertia_c)
121+
122+
123+
def test_weighted_k_init():
    """``_k_init`` must be invariant to a constant rescaling of the weights.

    All-ones weights and weights that are uniformly doubled should select
    the same two initial centers when each call gets a fresh
    ``RandomState`` built from the same seed.
    """
    seed = 1234
    x_squared_norms = row_norms(X, squared=True)

    # One freshly seeded generator per call so both draws are identical.
    centers_a, centers_b = [
        _k_init(X, 2, x_squared_norms,
                random_state=np.random.RandomState(seed),
                sample_weight=weight)
        for weight in (sample_unweight, sample_unweight * 2)
    ]
    assert_array_almost_equal(centers_a, centers_b)
41136

42137

43138
def test_kmeans_dtype():
@@ -68,14 +163,14 @@ def test_labels_assignment_and_inertia():
68163
# perform label assignment using the dense array input
69164
x_squared_norms = (X ** 2).sum(axis=1)
70165
labels_array, inertia_array = _labels_inertia(
71-
X, x_squared_norms, noisy_centers)
166+
X, x_squared_norms, noisy_centers, sample_weight=sample_unweight)
72167
assert_array_almost_equal(inertia_array, inertia_gold)
73168
assert_array_equal(labels_array, labels_gold)
74169

75170
# perform label assignment using the sparse CSR input
76171
x_squared_norms_from_csr = row_norms(X_csr, squared=True)
77172
labels_csr, inertia_csr = _labels_inertia(
78-
X_csr, x_squared_norms_from_csr, noisy_centers)
173+
X_csr, x_squared_norms_from_csr, noisy_centers, sample_weight=sample_unweight)
79174
assert_array_almost_equal(inertia_csr, inertia_gold)
80175
assert_array_equal(labels_csr, labels_gold)
81176

@@ -88,8 +183,8 @@ def test_minibatch_update_consistency():
88183
new_centers = old_centers.copy()
89184
new_centers_csr = old_centers.copy()
90185

91-
counts = np.zeros(new_centers.shape[0], dtype=np.int32)
92-
counts_csr = np.zeros(new_centers.shape[0], dtype=np.int32)
186+
counts = np.zeros(new_centers.shape[0], dtype=np.float64)
187+
counts_csr = np.zeros(new_centers.shape[0], dtype=np.float64)
93188

94189
x_squared_norms = (X ** 2).sum(axis=1)
95190
x_squared_norms_csr = row_norms(X_csr, squared=True)
@@ -102,16 +197,17 @@ def test_minibatch_update_consistency():
102197
X_mb_csr = X_csr[:10]
103198
x_mb_squared_norms = x_squared_norms[:10]
104199
x_mb_squared_norms_csr = x_squared_norms_csr[:10]
200+
sample_unweight_mb = sample_unweight[:10]
105201

106202
# step 1: compute the dense minibatch update
107203
old_inertia, incremental_diff = _mini_batch_step(
108204
X_mb, x_mb_squared_norms, new_centers, counts,
109-
buffer, 1, None, random_reassign=False)
205+
buffer, 1, None, random_reassign=False, sample_weight=sample_unweight_mb)
110206
assert_greater(old_inertia, 0.0)
111207

112208
# compute the new inertia on the same batch to check that it decreased
113209
labels, new_inertia = _labels_inertia(
114-
X_mb, x_mb_squared_norms, new_centers)
210+
X_mb, x_mb_squared_norms, new_centers, sample_weight=sample_unweight_mb)
115211
assert_greater(new_inertia, 0.0)
116212
assert_less(new_inertia, old_inertia)
117213

@@ -123,12 +219,12 @@ def test_minibatch_update_consistency():
123219
# step 2: compute the sparse minibatch update
124220
old_inertia_csr, incremental_diff_csr = _mini_batch_step(
125221
X_mb_csr, x_mb_squared_norms_csr, new_centers_csr, counts_csr,
126-
buffer_csr, 1, None, random_reassign=False)
222+
buffer_csr, 1, None, random_reassign=False, sample_weight=sample_unweight_mb)
127223
assert_greater(old_inertia_csr, 0.0)
128224

129225
# compute the new inertia on the same batch to check that it decreased
130226
labels_csr, new_inertia_csr = _labels_inertia(
131-
X_mb_csr, x_mb_squared_norms_csr, new_centers_csr)
227+
X_mb_csr, x_mb_squared_norms_csr, new_centers_csr, sample_weight=sample_unweight_mb)
132228
assert_greater(new_inertia_csr, 0.0)
133229
assert_less(new_inertia_csr, old_inertia_csr)
134230

@@ -421,7 +517,7 @@ def test_minibatch_with_many_reassignments():
421517

422518
def test_sparse_mb_k_means_callable_init():
423519

424-
def test_init(X, k, random_state):
520+
def test_init(X, k, random_state, sample_weight=None):
425521
return centers
426522

427523
# Small test to check that giving the wrong number of centers
@@ -655,7 +751,8 @@ def test_k_means_function():
655751
sys.stdout = StringIO()
656752
try:
657753
cluster_centers, labels, inertia = k_means(X, n_clusters=n_clusters,
658-
verbose=True)
754+
verbose=True,
755+
sample_weight=sample_unweight)
659756
finally:
660757
sys.stdout = old_stdout
661758
centers = cluster_centers
@@ -670,7 +767,8 @@ def test_k_means_function():
670767

671768
# check warning when centers are passed
672769
assert_warns(RuntimeWarning, k_means, X, n_clusters=n_clusters,
673-
init=centers)
770+
init=centers, sample_weight=sample_unweight)
674771

675772
# too many clusters desired
676-
assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1)
773+
assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, sample_weight=sample_unweight)
774+

0 commit comments

Comments
 (0)
0