8000 Fix check_array dtype in MinibatchKMeans.partial_fit (#14323) · scikit-learn/scikit-learn@bf8eff3 · GitHub
[go: up one dir, main page]

Skip to content

Commit bf8eff3

Browse files
rthjeremiedbb
authored andcommitted
Fix check_array dtype in MinibatchKMeans.partial_fit (#14323)
1 parent 36b66ba commit bf8eff3

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

sklearn/cluster/k_means_.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,8 +1422,8 @@ class MiniBatchKMeans(KMeans):
14221422
>>> kmeans = kmeans.partial_fit(X[0:6,:])
14231423
>>> kmeans = kmeans.partial_fit(X[6:12,:])
14241424
>>> kmeans.cluster_centers_
1425-
array([[1, 1],
1426-
[3, 4]])
1425+
array([[2. , 1. ],
1426+
[3.5, 4.5]])
14271427
>>> kmeans.predict([[0, 0], [4, 4]])
14281428
array([0, 1], dtype=int32)
14291429
>>> # fit on the whole data
@@ -1667,7 +1667,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
16671667
16681668
"""
16691669

1670-
X = check_array(X, accept_sparse="csr", order="C")
1670+
X = check_array(X, accept_sparse="csr", order="C",
1671+
dtype=[np.float64, np.float32])
16711672
n_samples, n_features = X.shape
16721673
if hasattr(self.init, '__array__'):
16731674
self.init = np.ascontiguousarray(self.init, dtype=X.dtype)

sklearn/cluster/tests/test_k_means.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,3 +942,11 @@ def test_k_means_empty_cluster_relocated():
942942

943943
assert len(set(km.labels_)) == 2
944944
assert_allclose(km.cluster_centers_, [[-1], [1]])
945+
946+
947+
def test_minibatch_kmeans_partial_fit_int_data():
948+
# Issue GH #14314
949+
X = np.array([[-1], [1]], dtype=np.int)
950+
km = MiniBatchKMeans(n_clusters=2)
951+
km.partial_fit(X)
952+
assert km.cluster_centers_.dtype.kind == "f"

0 commit comments

Comments
 (0)
0