Bisecting Kmeans fails to bisect a certain cluster · Issue #25505 · scikit-learn/scikit-learn · GitHub

Bisecting Kmeans fails to bisect a certain cluster #25505
Closed
@Xunius

Describe the bug

Hi all,

I'm using sklearn.cluster.BisectingKMeans for clustering. It worked for a range of k values but failed at k=9 (I don't think the specific k value matters, though). The issue seems to be that it failed to split a cluster in two and produced two identical centers instead. More details are given below.

Steps/Code to Reproduce

I'm afraid I can't provide a minimal reproducible example without relying on external data, so I've uploaded the training data to an otherwise empty repo: https://github.com/Xunius/dummy_repo

The repo contains a single short Python script, which assumes that the npz data file is in the same folder as the script. The code to reproduce the bug is:

import numpy as np
from sklearn.cluster import BisectingKMeans

npz = np.load('./bisectkmeans.npz')

km = BisectingKMeans(n_clusters=9, init='random', max_iter=400,
                     n_init=10, random_state=10,
                     bisecting_strategy='largest_cluster')
X = npz['data']
weights = npz['weights']
km.fit(X, None, sample_weight=weights)

Expected Results

No error is thrown

Actual Results

Below is the Traceback info:

Traceback (most recent call last):
  File "bisectkmeans.py", line 12, in <module>
    km.fit(X, None, sample_weight=weights)
  File "/home/guangzhi/.local/mambaforge/envs/cdat/lib/python3.8/site-packages/sklearn/cluster/_bisect_k_means.py", line 411, in fit
    self._bisect(X, x_squared_norms, sample_weight, cluster_to_bisect)
  File "/home/guangzhi/.local/mambaforge/envs/cdat/lib/python3.8/site-packages/sklearn/cluster/_bisect_k_means.py", line 342, in _bisect
    cluster_to_bisect.split(best_labels, best_centers, scores)
  File "/home/guangzhi/.local/mambaforge/envs/cdat/lib/python3.8/site-packages/sklearn/cluster/_bisect_k_means.py", line 45, in split
    indices=self.indices[labels == 1], center=centers[1], score=scores[1]
IndexError: index 1 is out of bounds for axis 0 with size 1

Printing np.unique(labels) gives:

array([0], dtype=int32)

so the labels array is a length-6843 1D array of all 0s.

The centers array is a 2x7 array with two identical rows:

(Pdb) centers
array([[-1.04880691,  1.47209268,  0.92988155,  1.16810444,  0.98679651,
        -1.4344821 , -0.70694216],
       [-1.04880691,  1.47209268,  0.92988155,  1.16810444,  0.98679651,
        -1.4344821 , -0.70694216]])
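
The traceback and the debugger output together hint at which array has size 1: centers has two rows and labels has 6843 entries, so the failing index is most likely scores[1]. Below is a minimal sketch of how that could happen, assuming the per-cluster scores are obtained by counting labels with np.bincount (an assumption about the library internals, not something confirmed in this report). The minlength=2 guard is likewise only an illustration, not a statement of how the library behaves.

```python
import numpy as np

# Labels as observed in the report: every sample assigned to cluster 0.
labels = np.zeros(6843, dtype=np.int32)

# Assumption: scores are per-cluster sample counts derived from the labels.
# With no label equal to 1, np.bincount returns a single-element array.
scores = np.bincount(labels)
print(scores.shape)  # (1,)

# Indexing the missing second entry raises the same kind of IndexError.
message = None
try:
    scores[1]
except IndexError as exc:
    message = str(exc)
    print(message)

# Hypothetical guard: force two bins so an empty second cluster counts as 0.
padded = np.bincount(labels, minlength=2)
print(padded.shape)  # (2,)
```

With the guard, the second cluster would simply be reported as empty rather than triggering an out-of-bounds access.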

Going up one level in the stack, at line 342 (cluster_to_bisect.split(best_labels, best_centers, scores)), I noticed that the best_inertia variable has a tiny value:

(Pdb) best_inertia
5.7262640016675594e-30

and the array X is a (6843, 7) 2D array with identical rows:

(Pdb) np.ptp(X,axis=0)
array([0., 0., 0., 0., 0., 0., 0.])
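
A quick way to detect this kind of degenerate input before clustering is to count the distinct rows. The sketch below uses a hypothetical helper, count_distinct_rows, and a synthetic stand-in array built by repeating the center printed above (an assumption, since the actual rows of X are not shown in this report).

```python
import numpy as np


def count_distinct_rows(X):
    """Return the number of distinct rows in a 2D array (hypothetical helper)."""
    return np.unique(X, axis=0).shape[0]


# Synthetic stand-in for the degenerate cluster: one row repeated 6843 times.
# The values are copied from the printed center, purely for illustration.
row = np.array([-1.04880691, 1.47209268, 0.92988155, 1.16810444,
                0.98679651, -1.4344821, -0.70694216])
X_degenerate = np.tile(row, (6843, 1))

n_distinct = count_distinct_rows(X_degenerate)
ranges = np.ptp(X_degenerate, axis=0)  # per-feature max minus min

print(n_distinct)          # 1
print(np.all(ranges == 0)) # True
```

A set of points with a single distinct row cannot be split into two non-empty clusters with different centers, which is consistent with the two identical centers seen above.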

I'm not sure what's causing this failure. In case it helps: I'm using sample weights, and at the point of failure some of them are quite small:

(Pdb) sample_weight.shape
(6843,)
(Pdb) sample_weight.min()
1.1269358725791912e-09
(Pdb) sample_weight.max()
1.418544396650835e-05
(Pdb) sample_weight.sum()
0.03364984267143782
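
To get a feel for whether the small weights matter, here is a rough sketch of the weighted inertia (sum of weight times squared distance to the center) for a set of identical rows. The weights are random values drawn from roughly the range reported above and the row is a stand-in, so both are assumptions rather than the real data. The helper name weighted_inertia is hypothetical.

```python
import numpy as np


def weighted_inertia(X, center, sample_weight):
    """Sum over samples of weight * squared Euclidean distance to center."""
    sq_dist = np.sum((X - center) ** 2, axis=1)
    return float(np.sum(sample_weight * sq_dist))


rng = np.random.default_rng(0)
n_samples = 6843

# Stand-in data: one row repeated, as in the degenerate cluster.
row = np.array([-1.04880691, 1.47209268, 0.92988155, 1.16810444,
                0.98679651, -1.4344821, -0.70694216])
X = np.tile(row, (n_samples, 1))

# Assumed weights, spanning roughly the reported min and max.
weights = rng.uniform(1e-9, 1e-5, size=n_samples)

# The weighted mean of identical rows is that row, up to rounding.
center = np.average(X, axis=0, weights=weights)
inertia = weighted_inertia(X, center, weights)

# Same data with unit weights, for comparison.
inertia_unit = weighted_inertia(X, X.mean(axis=0), np.ones(n_samples))
print(inertia, inertia_unit)
```

In both cases the inertia is zero up to floating-point rounding, which suggests that the identical rows alone may be enough to explain the tiny best_inertia, whatever the magnitude of the weights.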

Versions

System:
    python: 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:46:39)  [GCC 10.4.0]
executable: /home/guangzhi/.local/mambaforge/envs/cdat/bin/python
   machine: Linux-5.15.85-1-MANJARO-x86_64-with-glibc2.10

Python dependencies:
      sklearn: 1.2.0
          pip: 22.3.1
   setuptools: 66.0.0
        numpy: 1.21.6
        scipy: 1.10.0
       Cython: 0.29.33
       pandas: 1.5.3
   matplotlib: 3.2.2
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/guangzhi/.local/mambaforge/envs/cdat/lib/libopenblasp-r0.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: Haswell
    num_threads: 8

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /home/guangzhi/.local/mambaforge/envs/cdat/lib/libgomp.so.1.0.0
        version: None
    num_threads: 8
