K-Means clustering performance improvements · Issue #10744 · scikit-learn/scikit-learn · GitHub

K-Means clustering performance improvements #10744


Closed
FrancoisFayard opened this issue Mar 2, 2018 · 15 comments · Fixed by #11950

Comments

@FrancoisFayard
FrancoisFayard commented Mar 2, 2018

Hi,

I am new to the Scikit-Learn world, but I've talked a lot with Alexandre Gramfort about performance enhancements that could be made to Scikit-Learn. I have known for a long time that the K-means clustering code was suboptimal, but my version was written in C++, which does not seem to be a language used in Scikit-Learn. Whether the K-means clustering implementation of Scikit-Learn could be made faster within the constraints of the community was an open question to me, as I am not a Python programmer.

I am going to focus on Intel architectures. The benchmarks I've done compare the vanilla Scikit-Learn (the latest one on Anaconda), the Intel Scikit-Learn (the latest one from the intel channel, which is linked against DAAL) and a code I've written using Cython and compiled with the Intel compiler.

Here are the speed differences for clustering 270 000 points into 1024 clusters in dimension 3. This example comes from the color quantization example in the Scikit-Learn documentation, which uses an image of 270 000 pixels. The reference speed is the vanilla Scikit-Learn on a MacBook Pro with 4 cores. Any higher number represents a speedup over this configuration:

MacBookPro: 4 cores, Haswell (AVX2)
Scikit-Learn - Vanilla: 1
Scikit-Learn - Intel: x 52
InsideLoop - Icc 18: x 234 (About 129 GFlops/s)

Dual Xeon: 36 cores, Skylake (AVX512)
Scikit-Learn - Vanilla: x 2
Scikit-Learn - Intel: x 85
InsideLoop - Gcc 7.3: x 276
InsideLoop - Icc 18: x 1382 (About 754 GFlops/s)

This shows that huge improvements are possible, even beyond the Intel version. There are a few things that give my code an advantage:

  • It knows at compile time the dimension, which is 3 here
  • It turns out that k-means clustering in low dimension is handled really badly by both the vanilla and the Intel version of Scikit-Learn
  • It has been compiled with the Intel compiler for my platform. Therefore, I have asked the compiler to generate AVX2 instructions on the MacBook Pro, and AVX512 instructions on the Xeon Skylake

Here is the code if you are interested. A few things are missing, such as the handling of empty clusters, but I have never seen that being a hotspot for k-means clustering. Also, bear in mind that I am more of a C/C++ programmer: this is my first Cython code and there might be things that are not really "appropriate". To get the speedup, I leverage both parallelisation (using OpenMP) and vectorization (thanks to memoryviews and the Intel compiler).

import time
import numpy as np
import sklearn.cluster
from sklearn.cluster import KMeans
from sklearn.datasets import load_sample_image
from cython.parallel import prange
cimport cython


cdef int width = 427
cdef int height = 640
cdef int nb_points = width * height
cdef int nb_clusters = 1024
cdef int nb_iterations_insideloop = 100
cdef int nb_iterations_scikitlearn = 3
cdef int nb_iterations_intel = 20

pixel = np.random.rand(width * height, 3).astype('float32')

pr = np.array(pixel[:, 0])
pg = np.array(pixel[:, 1])
pb = np.array(pixel[:, 2])
cluster = np.arange(0, nb_points, dtype = 'int32')
cluster = cluster % nb_clusters
cr = np.zeros(nb_clusters, dtype = 'float32')
cg = np.zeros(nb_clusters, dtype = 'float32')
cb = np.zeros(nb_clusters, dtype = 'float32')
pop = np.zeros(nb_clusters, dtype = 'int32')

cdef float[::1] pr_view = pr
cdef float[::1] pg_view = pg
cdef float[::1] pb_view = pb
cdef int[::1] cluster_view = cluster
cdef float[::1] cr_view = cr
cdef float[::1] cg_view = cg
cdef float[::1] cb_view = cb
cdef int[::1] pop_view = pop

t0 = time.time()
kmeans(pr_view, pg_view, pb_view,
       cluster_view,
       cr_view, cg_view, cb_view,
       pop_view, nb_iterations_insideloop)
t1 = time.time()
gflops = 1.0e-9 * width * height * nb_iterations_insideloop * nb_clusters * 8 / (t1 - t0)
print('  Time for KMeans clustering, InsideLoop: {} s'.format((t1 - t0) / nb_iterations_insideloop))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = sklearn.cluster.k_means(pixel, nb_clusters, init = 'random',
    n_init = 1, tol = 1.0e-16, max_iter = nb_iterations_scikitlearn, return_n_iter = True)
t1 = time.time()
gflops = 1.0e-9 * width * height * nb_iterations_scikitlearn * nb_clusters * 8 / (t1 - t0)
print('Time for KMeans clustering, Scikit-Learn: {} s'.format((t1 - t0) / nb_iterations_scikitlearn))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = KMeans(init = 'random', n_init = 1, tol = 1.0e-16, n_clusters = nb_clusters,
    random_state = 0, max_iter = nb_iterations_intel).fit(pixel)
t1 = time.time()
gflops = 1.0e-9 * width * height * nb_iterations_intel * nb_clusters * 8 / (t1 - t0)
print('       Time for KMeans clustering, Intel: {} s'.format((t1 - t0) / nb_iterations_intel))
print('                             Performance: {} GFlops/s'.format(gflops))


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void kmeans(float[::1] pr, float[::1] pg, float[::1] pb,
                  int[::1] cluster,
                  float[::1] cr, float[::1] cg, float[::1] cb,
                  int[::1] pop, int nb_iterations) nogil:
  cdef int m = len(cluster) # m is the number of points
  cdef int n = len(pop)     # n is the number of clusters
  cdef int i, j, k, best_j
  cdef float alpha, distance, best_distance
  cdef float x, y, z

  for k in range(nb_iterations):
    for j in range(n):
      cr[j] = 0.0
      cg[j] = 0.0
      cb[j] = 0.0
      pop[j] = 0
    for i in range(m):
      j = cluster[i]
      cr[j] += pr[i]
      cg[j] += pg[i]
      cb[j] += pb[i]
      pop[j] += 1
    for j in range(n):
      if pop[j] > 0:
        alpha = 1.0 / pop[j]
        cr[j] *= alpha
        cg[j] *= alpha
        cb[j] *= alpha

    for i in prange(m):
      best_distance = 3.0 + 1.0  # any squared distance between points in [0, 1)^3 is <= 3
      best_j = 0
      for j in range(n):
        x = pr[i] - cr[j]
        y = pg[i] - cg[j]
        z = pb[i] - cb[j]
        distance = x * x + y * y + z * z
        if distance < best_distance:
          best_distance = distance
          best_j = j
      cluster[i] = best_j
@FrancoisFayard changed the title from KMeans Performance to K-Means clustering performance improvements on Mar 2, 2018
@jnothman
Member
jnothman commented Mar 2, 2018 via email

@FrancoisFayard
Author
FrancoisFayard commented Mar 2, 2018

One can get large performance improvements in any dimension. Here is a benchmark running with:
nb_points = 100000
nb_clusters = 1024
dim = 128

MacBookPro: 4 cores, Haswell (AVX2)
Scikit-Learn - Vanilla: 1
Scikit-Learn - Intel: x 49
InsideLoop - Icc 18: x 52 (About 82 GFlops/s)

Dual Xeon: 36 cores, Skylake (AVX512)
Scikit-Learn - Vanilla: x 0.8
Scikit-Learn - Intel: x 23
InsideLoop - Clang 5: x 47
InsideLoop - Gcc 7.3: x 43
InsideLoop - Icc 18: x 84 (About 140 GFlops/s)

There is still room for improvement, but even in high dimension one can substantially increase the performance.

import time
import numpy as np
import sklearn.cluster
from sklearn.cluster import KMeans
from cython.parallel import prange
cimport cython

cdef int nb_points = 100000
cdef int nb_clusters = 1024
cdef int dim = 128
cdef int nb_iterations_insideloop = 1
cdef int nb_iterations_scikitlearn = 1
cdef int nb_iterations_intel = 1

point_il = np.random.rand(nb_points * dim).astype('float32')
point = np.random.rand(nb_points, dim).astype('float32')

cluster = np.arange(0, nb_points, dtype = 'int32') % nb_clusters

centroid_il = np.zeros(nb_clusters * dim, dtype = 'float32')
pop_il = np.zeros(nb_clusters, dtype = 'int32')
cdef float[::1] point_view = point_il
cdef int[::1] cluster_view = cluster
cdef float[::1] centroid_view = centroid_il
cdef int[::1] pop_view = pop_il

t0 = time.time()
kmeans(dim, point_view, cluster_view, centroid_view, pop_view, nb_iterations_insideloop)
t1 = time.time()
gflops = 1.0e-9 * nb_points * nb_iterations_insideloop * nb_clusters * 3 * dim / (t1 - t0)
print('  Time for KMeans clustering, InsideLoop: {} s'.format((t1 - t0) / nb_iterations_insideloop))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = sklearn.cluster.k_means(point, nb_clusters, init = 'random',
    n_init = 1, tol = 1.0e-16, max_iter = nb_iterations_scikitlearn, return_n_iter = True)
t1 = time.time()
gflops =  1.0e-9 * nb_points * nb_iterations_scikitlearn * nb_clusters * 3 * dim / (t1 - t0)
print('Time for KMeans clustering, Scikit-Learn: {} s'.format((t1 - t0) / nb_iterations_scikitlearn))
print('                             Performance: {} GFlops/s'.format(gflops))

t0 = time.time()
res = KMeans(init = 'random', n_init = 1, tol = 1.0e-16, n_clusters = nb_clusters,
    random_state = 0, max_iter = nb_iterations_intel).fit(point)
t1 = time.time()
gflops = 1.0e-9 * nb_points * nb_iterations_intel * nb_clusters * 3 * dim / (t1 - t0)
print('       Time for KMeans clustering, Intel: {} s'.format((t1 - t0) / nb_iterations_intel))
print('                             Performance: {} GFlops/s'.format(gflops))


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void kmeans(int dim,
                  float[::1] point, int[::1] cluster,
                  float[::1] centroid, int[::1] pop,
                  int nb_iterations) nogil:
  cdef int m = len(point) // dim    # m is the number of points
  cdef int n = len(centroid) // dim # n is the number of clusters
  cdef int i, j, k, l, best_j
  cdef float alpha, x, distance, best_distance

  for k in range(nb_iterations):
    for j in range(n):
      for l in range(dim):
        centroid[j * dim + l] = 0.0
      pop[j] = 0
    for i in range(m):
      j = cluster[i]
      for l in range(dim):
        centroid[j * dim + l] += point[i * dim + l]
      pop[j] += 1
    for j in range(n):
      if pop[j] > 0:
        alpha = 1.0 / pop[j]
        for l in range(dim):
          centroid[j * dim + l] *= alpha

    for i in prange(m):
      best_distance = dim + 1.0  # any squared distance between points in [0, 1)^dim is <= dim
      best_j = -1
      for j in range(n):
        distance = 0.0
        for l in range(dim):
          x = point[i * dim + l] - centroid[j * dim + l]
          distance = distance + x * x
        if distance < best_distance:
          best_distance = distance
          best_j = j
      cluster[i] = best_j

@rth
Member
rth commented Mar 2, 2018

Interesting benchmarks, thank you.

To use something like that, there are currently two blockers, as far as I know: one is OpenMP usage (see #7650) and the other is memoryviews (#10624 , also related #10663, and probably some older issues) . Looks like both could get resolved; @ogrisel and @lesteve would know more about it...

There are a few things where my code has an advantage:
It knows at compile time the dimension, which is 3 here

Wouldn't that be an unfair advantage in these benchmarks, depending on what the compiler optimizes for? Maybe importing the Cython function from a pure Python script and running the timings there would be safer, just to be sure?
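Something like this, for instance (a rough sketch; kmeans_il is a placeholder name for the compiled extension module exposing the kmeans function above):

import time
import numpy as np

import kmeans_il  # hypothetical: the .pyx above compiled into an extension module

rng = np.random.RandomState(0)
nb_points, nb_clusters, nb_iterations = 427 * 640, 1024, 100

pixel = rng.rand(nb_points, 3).astype('float32')
pr = np.ascontiguousarray(pixel[:, 0])
pg = np.ascontiguousarray(pixel[:, 1])
pb = np.ascontiguousarray(pixel[:, 2])
cluster = (np.arange(nb_points) % nb_clusters).astype('int32')
cr = np.zeros(nb_clusters, dtype='float32')
cg = np.zeros(nb_clusters, dtype='float32')
cb = np.zeros(nb_clusters, dtype='float32')
pop = np.zeros(nb_clusters, dtype='int32')

t0 = time.time()
kmeans_il.kmeans(pr, pg, pb, cluster, cr, cg, cb, pop, nb_iterations)
t1 = time.time()
print('Time per iteration: {} s'.format((t1 - t0) / nb_iterations))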

@FrancoisFayard
Author
FrancoisFayard commented Mar 2, 2018

Concerning the case where the dimension is known at compile time, it was just to show that this information can lead to large speedups. It was the first code I tried, but it is not the main point I would like to focus on.

Concerning OpenMP, I don't know how you handle multithreading in Scikit-Learn (apart from using a multithreaded BLAS). I've heard from Alexandre Gramfort that OpenMP might cause problems, as some compilers (think of Apple and their version of clang) do not support OpenMP out of the box. I know that Intel has been working hard to integrate TBB (Threading Building Blocks) into their Python distribution. It is an open-source library (Apache 2), works on many platforms (including ARM), and has the huge advantage of handling nested parallelism. Moreover, the Intel MKL supports both OpenMP and TBB. I don't know if this is an option you have considered.

Concerning memoryviews, they are extremely important if you want vectorization. With NumPy, the problem is that the strides can be whatever you want: as far as I know, there is no guarantee that a[i] and a[i + 1] are next to each other in memory. This alone kills performance and prevents vectorization when writing loops. With AVX-512 and 32-bit floating point numbers, you are leaving a potential performance factor of 16 on the table. Even on your laptop with AVX2, this is a performance factor of 8.
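A small NumPy illustration of what I mean:

import numpy as np

a = np.random.rand(1000, 3).astype('float32')

col = a[:, 0]                       # a view of the first column: 12-byte stride
print(col.strides)                  # (12,) -> col[i] and col[i + 1] are 3 floats apart
print(col.flags['C_CONTIGUOUS'])    # False: a float[::1] memoryview would reject it

col_c = np.ascontiguousarray(col)   # copy into a packed buffer
print(col_c.strides)                # (4,) -> unit stride, ready for vectorized loops
print(col_c.flags['C_CONTIGUOUS'])  # True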

As I said, I am a newcomer to the Python community and I don't know your requirements and constraints well, but I would be glad to learn more, and I have some good knowledge of multithreading and vectorization on CPUs. I can help and give you a hand on those subjects.

@FrancoisFayard
Author
FrancoisFayard commented Mar 2, 2018

It is also interesting to notice that multithreading and vectorization do not explain everything. On my MacBook Pro with 4 cores and AVX2 (256-bit registers, 32-bit floats), we could hope for a theoretical speedup of 4 x (256 / 32) = 32. Usually you are happy when you get a speedup of 16 in those cases. Here, the speedup is 234. So not only is the current code not exploiting multithreading and vectorization correctly, it is also doing something that makes things worse.

In your current implementation, the hotspot is here:

    for sample_idx in range(n_samples):
        min_dist = -1
        for center_idx in range(n_clusters):
            dist = 0.0
            # hardcoded: minimize euclidean distance to cluster center:
            # ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
            dist += dot(n_features, &X[sample_idx, 0], x_stride,
                        &centers[center_idx, 0], center_stride)
            dist *= -2
            dist += center_squared_norms[center_idx]
            dist += x_squared_norms[sample_idx]
            if min_dist == -1 or dist < min_dist:
                min_dist = dist
                labels[sample_idx] = center_idx

        if store_distances:
            distances[sample_idx] = min_dist
        inertia += min_dist

This code is the hotspot of your current implementation, in the file _k_means.pyx. It is quite interesting, as everything is in place to slow the code down:

  • Multithreading: There is no multithreading in this code, except inside the dot function, which calls the BLAS linked to Scikit-Learn. This is a BLAS level 1 operation that computes the scalar product of two vectors. Such operations are bandwidth bound, and throwing many threads at them will make things slower than just using one thread. I am sure the MKL has some mechanism to limit this kind of problem, but when you call dot, you are really asking for some parallelism.
  • Vectorization: When you call dot, your BLAS is going to run vectorized code. For fp32 on AVX2, as on most platforms today, it will process elements in batches of 8. Unfortunately, all BLAS implementations, including the MKL, are optimized for large arrays. As a consequence, in order to limit false sharing, a BLAS first peels the loop and aligns arrays to a cache line (512 bits). Moreover, a scalar product is a reduction, and you really need large arrays to feel the speed of vectorization. I would not hope for anything good before dimension 64.

Now imagine the disaster when you call dot to compute the scalar product of two vectors in dimension 3: you are only slowing things down.

  • (a-b)^2 = a^2 + b^2 - 2 a b: I have seen this trick many times, and I never really understood it. There are cases where it makes sense in languages such as Matlab and Python, where loops are slow: when you have n vectors a and m vectors b in dimension k, the scalar products a b can all be computed as one large matrix multiplication, which is great if you have a good BLAS. Using it for only two vectors not only slows things down, it also makes your algorithm numerically unstable, as a^2 + b^2 - 2 a b can become negative because of round-off errors. I guess this is the reason why you tend to center all your points around 0 at the beginning of your code. With fp16 floating point around the corner, I believe that this code can be improved.
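For instance, with two nearby points in single precision (a tiny sketch, not taken from the Scikit-Learn code):

import numpy as np

a = np.float32(1000.0)
b = np.float32(1000.0001)                 # stored as the nearest float32, ~1000.000122

direct = (a - b) * (a - b)                # ~1.5e-08, correct to float32 precision
expanded = a * a + b * b - np.float32(2.0) * a * b
print(direct, expanded)                   # the expanded form collapses to 0.0 here;
                                          # on longer vectors it can even come out negative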

As a consequence, I believe that many things can be done to make your code faster and more stable. I would do, in order:

  • Remove the (a-b)^2 trick
  • Remove the call to a BLAS and replace it with pure Cython with memoryview
    Most compilers will not vectorize the code, and I believe that this is good news for low dimensions (less than 20 I would say).

Now comes the question of multithreading. There is a saying in HPC: vectorize the innermost loop and use multithreading for the outermost loop. Unfortunately, the Python language pushes people to let multithreading be handled by low-level libraries and therefore to use multithreading in innermost loops. If there is a loop that needs multithreading here, it is for sample_idx in range(n_samples):. But if you think more out of the box, the best loop for multithreading is the loop that starts the k-means clustering with different initializations. This is where multithreading should be done first, in the Python code, not in the Cython one. If you have a multithreading library that can handle nested parallelism (such as TBB, and unlike OpenMP), you can multithread both, but I believe that only OpenMP is available in Cython as of today. Intel has worked on that: http://conference.scipy.org/proceedings/scipy2016/pdfs/anton_malakhov.pdf . It might be useful for Scikit-Learn.
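To be concrete, the outer loop over initializations can already be parallelized at the Python level with joblib, something like this (a rough sketch of the idea, not the actual Scikit-Learn code; sizes are arbitrary):

from joblib import Parallel, delayed
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=10000, n_features=3, centers=64, random_state=0)

def one_init(seed):
    # one full k-means run from a single random initialization
    km = KMeans(n_clusters=64, init='random', n_init=1, max_iter=20,
                random_state=seed).fit(X)
    return km.inertia_, km

# the embarrassingly parallel loop: independent initializations, keep the best one
runs = Parallel(n_jobs=4)(delayed(one_init)(seed) for seed in range(8))
best_inertia, best_km = min(runs, key=lambda run: run[0])
print(best_inertia)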

These are my 2 cents on this implementation. I can help implement something without the (a-b)^2 trick and without the call to the BLAS. It should make things faster and make it much more efficient to parallelize the different initializations. But I would really need memoryviews. I don't understand your problem regarding that; can someone give me some pointers on the problems it raises?

François

@jnothman
Member
jnothman commented Mar 3, 2018 via email

@FrancoisFayard
Author
FrancoisFayard commented Mar 4, 2018

Hi Joel. In my first post on this thread, the benchmark needs to know the number of features at compile time. But in my second post, the number of features is not known at compile time and we still get large speedups.
In my last post, I explain what I think is wrong with the current implementation, in terms of both performance and numerical stability.

PS1: Note that I have managed to make a few changes (register blocking) and got K-means clustering with 3 features (known at compile time) 4200 times faster than Scikit-Learn on the dual-Xeon Skylake with 18 cores per CPU. That begins to be quite a gap. I think we can manage to transfer most of this speed to Cython.
PS2: In the Scikit-Learn documentation for color quantization, one can see that even the example code is "cheating" and fits on only the first 1000 pixels of the image, because if you take the full image, which is about 450 x 640 and is not that big, it takes half an hour to complete. The exact same computation can be done in half a second.
Ref: http://scikit-learn.org/stable/auto_examples/cluster/plot_color_quantization.html

@jnothman
Member
jnothman commented Mar 4, 2018

Thanks for the summary. Very interesting analysis. The Euclidean computation trick has bitten us elsewhere. I'm +1 for the improvement to that and dot at a minimum.

Our docs aren't cheating so much as trying to stay fast to build.

@rth
Member
rth commented Mar 6, 2018

Thanks for this in-depth analysis @FrancoisFayard! It's valuable to have a fresh look at this.

Below are a few comments and questions (cc @massich).

I don't know how you handle multithreading in Scikit-Learn (apart from using a multithreaded BLAS)

Outside of Cython we use joblib.Parallel's threading backend.

Moreover, Intel MKL support both OpenMP and TBB. I don't know if this is an option you have considered.

Regarding TBB: no, as far as the GitHub issue history goes. BTW, #9429 might be somewhat related to this issue in general. Also note that scikit-learn is also frequently used in combination with OpenBLAS, and I'm not sure how much TBB is used in that context (e.g. OpenMathLib/OpenBLAS#1282).

Concerning memoryviews, it is extremely important if you want vectorization. With Numpy, the problem is that the strides can be whatever you want. [..] Can someone give me some pointers on the problems it raises?

The issue was that until recently Cython did not support read-only memoryviews (cf. #10624 (comment)), which is currently necessary for parallel computations with joblib. Therefore most of the scikit-learn Cython code currently uses the array buffer interface. It also allows forcing array contiguity and alignment (e.g. example here), but from what I understand, indexing goes through the Python C API, one can't release the GIL, and there is some additional performance cost.

In the upcoming Cython 0.28 there will be support for read-only memory views, but it remains to be seen how well this will work with the existing code base cython/cython#1869 (comment) (hopefully it will).

This is a BLAS level 1 operation that compute the scalar product of 2 vectors. Those operations are bandwidth bound and throwing many threads at this task will make things slower than just using one thread. I am sure the MKL has some mechanism to limit this kind of problem, but when you call dot, you are really asking for some parallelism.

Interesting. If that's true in general, why do MKL / OpenBLAS multi-thread those? (I imagine MKL_DYNAMIC=true might have an impact.)

InsideLoop - Gcc 7.3: x 43

How did you install the default scikit-learn version ("Scikit-Learn - Vanilla" in your benchmarks)?

Bear in mind that for portability reasons, binary distributions of scikit-learn may have been compiled with intentionally outdated versions of gcc. For instance, the manylinux1 policy used to produce the binary wheels uploaded to PyPI relies on gcc 4.x. This can have some impact, I imagine (e.g. squeaky-pl/japronto#52). Similarly, until recently conda used gcc 4.x. As far as I could see, on the Travis CI setup used to build wheels the default compilation flag is -O3; I'm not sure whether, without any -march/-mtune compile flags, this will produce e.g. AVX2 instructions.
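For reference, when building locally one can opt in to architecture-specific instructions explicitly, e.g. with gcc (hypothetical setup.py sketch; kmeans.pyx is a placeholder file name):

from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy

# -march=native lets gcc emit AVX2/AVX-512 for the build machine; portable
# wheels cannot do this because the resulting binary would not run everywhere.
extensions = [Extension(
    "kmeans",
    sources=["kmeans.pyx"],
    extra_compile_args=["-O3", "-march=native", "-fopenmp"],
    extra_link_args=["-fopenmp"],
    include_dirs=[numpy.get_include()],
)]

setup(ext_modules=cythonize(extensions))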

I can help on implementing something without the (a-b)^2 trick and without the call to the BLAS.

As Joel said, your help would be appreciated. I don't know if it's possible to get some of those speedups with the Cython buffer interface instead of memoryviews for now.

@FrancoisFayard
Author
FrancoisFayard commented Mar 7, 2018

Thanks for your reply @rth

  • Parallelism: Multithreading is a major pain today for many reasons. Today's standard is OpenMP, as it has wide adoption in the HPC community, but I believe that OpenMP is something from the past. Although it is still the best performer when the workload is well balanced, it suffers from the fact that it does not support nested parallelism: you cannot create a parallel region within a parallel region. TBB, which Intel has been developing for the last 10 years, is better designed and allows that kind of thing. It is open source, works on many platforms (including ARM) and is not specific to Intel. Apart from some very specific libraries, I don't know of any competitor.
    The problem with multithreading is that when you begin to mix libraries, it works, but you get suboptimal performance. That's why one of the nice features of the MKL is that it comes in both an OpenMP version and a TBB version. Most other libraries, such as OpenBLAS and Eigen, only provide OpenMP versions.
    As of today, Intel has plugged TBB into joblib in their Intel Python Distribution, but TBB is not available from Cython. That might change, though.

  • MemoryView: Maybe the best thing would be to wait until const memoryviews are available. I don't mind working on something that will not land in the next few months.

  • BLAS: All BLAS level 1 operations (dot product of two vectors) and BLAS level 2 operations (matrix-vector product) are memory bound. When I said that it is useless to throw many threads at them, that is not completely true: if you take the dot product of two vectors of length 1'000'000, it is better to use several threads (though usually useless to throw too many of them at it). But here, the length of the vectors is the number of features, and I believe that most of the time this is below 100. In that case, one thread is more than enough.
    Moreover, although the MKL is the fastest library on Intel architectures for large BLAS operations (think of sizes larger than 100), it is not optimal for small BLAS operations. There has been a lot of work in the past to make it faster, but this is a very difficult task, and I believe that code must be generated for the given size to be optimal. Usually the size is not known at compile time, so you need a JIT. The library libxsmm ( https://github.com/hfp/libxsmm ) is an open-source library by Intel attempting to solve this problem.
    Finally, one of the advantages of the MKL over other BLAS libraries is CPU dispatch: at run time, the library dispatches the call to the version of the function that has been optimized for your CPU generation (SSE, AVX, AVX2, AVX-512, etc.). No other BLAS provides this functionality, and this is a major problem if you want to get the best out of your chip. For instance TensorFlow, which uses Eigen and not the MKL by default, has just announced that they will stop releasing binaries compiled for SSE and will use AVX in the binary releases of TensorFlow from now on. But if you have AVX2 or AVX-512 CPUs, they are not going to be exploited correctly.

This is to say that I don't think that using an old version of gcc (I have seen worse than gcc 4.x) is a major problem. The bigger problem would be to generate a single binary from your Cython code that works optimally on all Intel architectures, whether they are SSE, AVX, AVX2 or AVX-512. There are some solutions to do that in C/C++, but they would need to be integrated into Cython if you want to benefit from them in Cython.

  • (a-b)^2 trick: I gave it some more thought. It is still clear that for numerical stability it is better to compute (a-b)^2 than a^2 + b^2 - 2 a b. But it is true that for a large number of features, you can reduce the number of flops by 33% using this trick. It is also important to notice that when you compute a b, you are computing a dot product, and modern processors have FMA instructions to make this faster. To keep things simple: if you compute Sum (x_k - y_k)^2, you cannot expect the same number of floating point operations per unit of time as when you compute Sum x_k . y_k; on most Intel hardware, you'll get 2 times more flops with the second operation. So theoretically, one can expect a speedup of 3 from the trick (3/2 from the number of operations, 2 from the usage of FMA). But as I said, the current code is sometimes a factor of 4000 below the roofline, so I don't think a theoretical factor of 3 is that important today.

I'll try to optimize my C++ code to see how it compares against Intel's DAAL, and I'll then try to make the Cython version with memoryviews just as fast. It might take some time though, as I am quite busy right now. I'll come back to you, and I'll even try to give a talk at INRIA about my findings.

@amueller
Member
amueller commented Mar 9, 2018

Thank you for a great analysis.

Not sure if that has been addressed elsewhere but

But if you think more out of the box, the best loop for multithreading is the loop that starts the kmeans clustering with different initializations

That's exactly what we do if n_jobs > 1.

I think the (a-b)^2 trick is there so that we can make a BLAS call, which can use multithreading / a fast BLAS.
This is a slow-down for 3 dims, but I'm pretty sure it's a speed-up for 1024 dims.
It would be interesting to see how much faster/slower not doing the trick is for high-dimensional spaces.
A big reason to do this trick is probably that it is the "only way" to access multithreading right now, which is something that we need to fix. I'm not sure how the work on loky fits in here; maybe @ogrisel can comment on it. This is definitely not an area where I'm an expert, but we do need a solution.
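A quick way to get a feel for it at the NumPy level (a rough sketch, not the Cython hot loop; sizes are arbitrary):

import time
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_clusters, n_features = 2000, 100, 128
X = rng.rand(n_samples, n_features)
C = rng.rand(n_clusters, n_features)

t0 = time.time()
d_direct = ((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2)   # explicit (a - b)^2
t1 = time.time()
d_trick = ((X ** 2).sum(axis=1)[:, None] + (C ** 2).sum(axis=1)[None, :]
           - 2 * X @ C.T)                                       # ||a||^2 + ||b||^2 - 2 <a, b>
t2 = time.time()
print('direct: {:.3f} s, trick: {:.3f} s, max abs diff: {:.2e}'.format(
    t1 - t0, t2 - t1, np.abs(d_direct - d_trick).max()))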

@amueller
Member
amueller commented Mar 9, 2018

We can only consider TBB if it works well with openblas and if we can access it from cython.

@lesteve
Member
lesteve commented Mar 10, 2018

I'm not sure how the work on loky fits in here, and maybe @ogrisel can comment on it.

I think what loky allows here (#7650 has more details) is to use OpenMP through Cython while avoiding freezes due to the interaction of Python multiprocessing with OpenMP (in short, the bad interaction of fork without exec with libraries maintaining their own thread pool, see this for more details). Note that loky is only part of the joblib development version and so is not bundled in scikit-learn at the time of writing. There have been quite a few changes in joblib and it may take a little bit of time until the next joblib release.

While I am here, a comment more to the point of this issue: until now it seems like the focus is how fast we can make KMeans in an ideal world. I think it would be good at some point to also look at how these ideas can be integrated into the existing scikit-learn code and to have some benchmarks of a scikit-learn PR vs scikit-learn master. There is a benchmark in benchmarks/bench_plot_fastkmeans.py; maybe some of it can be reused.
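Something as simple as the script below, run once on master and once on the PR branch, would already be informative (a sketch; sizes are arbitrary):

import time
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=100000, n_features=32, centers=128, random_state=0)
X = X.astype(np.float32)

km = KMeans(n_clusters=128, init='k-means++', n_init=1, max_iter=20, random_state=0)
t0 = time.time()
km.fit(X)
print('fit time: {:.2f} s, inertia: {:.4e}'.format(time.time() - t0, km.inertia_))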

@FrancoisFayard
Author
FrancoisFayard commented Mar 10, 2018

@amueller: I agree that the (a-b)^2 trick is a nice way to use a BLAS, but it can be used better. I can propose a replacement that could be much better, especially when the number of features is large.
Let A be the nb_points x nb_features matrix containing the coordinates of the points and B the nb_features x nb_clusters matrix containing the coordinates of the centroids. The matrix product A.B contains all the scalar products you are looking for. Moreover, there is quite some work to do here, so the CPU will be happy launching threads. If all those numbers are big enough you'll get a lot of parallelism, and when you call a level 3 BLAS function you know you'll get huge speedups over any other implementation.
There is still a problem, which is memory, because we would create a matrix of size nb_points x nb_clusters, which might be huge. To solve this problem, all you need is to handle the points in batches (maybe chunks of size nb_features). That way, one should get something close to peak performance when the number of features is large. I am very busy these days, but I'll come back to you when I have some benchmarks.
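Here is a NumPy-level sketch of the assignment step I have in mind (just to show the idea, with the chunk size left as a free parameter; this is not a concrete implementation proposal):

import numpy as np

def assign_labels_chunked(X, centers, chunk_size=1024):
    # E-step of k-means: nearest centroid for each point, processing the points
    # by chunks so that the (chunk_size x nb_clusters) distance block stays small.
    n_samples = X.shape[0]
    labels = np.empty(n_samples, dtype=np.int32)
    c_sq = (centers ** 2).sum(axis=1)              # ||b||^2, computed once
    for start in range(0, n_samples, chunk_size):
        stop = min(start + chunk_size, n_samples)
        chunk = X[start:stop]
        # ||b||^2 - 2 a.b for the whole chunk: the a.b term is one GEMM call.
        # The ||a||^2 term is constant per row, so it does not change the argmin.
        d = c_sq[None, :] - 2.0 * (chunk @ centers.T)
        labels[start:stop] = np.argmin(d, axis=1)
    return labels

rng = np.random.RandomState(0)
X = rng.rand(100000, 128).astype(np.float32)
centers = rng.rand(1024, 128).astype(np.float32)
labels = assign_labels_chunked(X, centers)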
By the way, could any expert in Machine Learning give us an idea of the distribution of nb_features used by Scikit-Learn users? Even a guess would be nice.

Concerning TBB, I have talked to some Intel people, and they might consider supporting it in Cython if there is demand. But I don't think you could get TBB into OpenBLAS unless you do it yourself. Still, it is possible to mix multithreading libraries and to call an OpenMP-multithreaded library from a TBB-multithreaded library. It is not ideal, but it can be done.

@jeremiedbb
Member
jeremiedbb commented Jun 27, 2018

@FrancoisFayard Hi, I don't know if you are still working on this, but if you have some time I've got a few questions about your benchmark. I'm trying to reproduce your benchmarks (with n_dim unknown at compile time) but I'm not getting anywhere close to your numbers.
Here is my setup.py for icc compilation ('-O3' is used by default):

extensions = [Extension("kmeans",
                        sources=["kmeans.pyx"],
                        extra_compile_args=['-qopenmp', '-march=core_avx2'],
                        extra_link_args=['-qopenmp'],
                        include_dirs=[numpy.get_include()])]

setup(
    ext_modules = cythonize(extensions, annotate=True)
)

Could you tell me which flags you used to compile with icc? Also, did you try to compile with gcc on your laptop? And finally, can you give me more precise info about your laptop's CPU? Are you close to the peak?
