diff --git a/sklearn/__init__.py b/sklearn/__init__.py index a61a2afde8855..b61d4a865fc97 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -82,6 +82,7 @@ _distributor_init, # noqa: F401 ) from .base import clone + from .utils._openmp_helpers import set_openblas_openmp_callback from .utils._show_versions import show_versions __all__ = [ @@ -138,6 +139,12 @@ except ModuleNotFoundError: pass + # Register an OpenMP backend for OpenBLAS such that OpenBLAS uses the OpenMP + # thread pool which prevents active spin-waiting when making successive quick + # alternating calls to OpenBLAS and OpenMP. + # see https://github.com/OpenMathLib/OpenBLAS/issues/3187 for more details + set_openblas_openmp_callback() + def setup_module(module): """Fixture for the tests to assure globally controllable seeding of RNGs""" diff --git a/sklearn/utils/_openmp_helpers.pyx b/sklearn/utils/_openmp_helpers.pyx index 88dca51089c56..492d16bba20a3 100644 --- a/sklearn/utils/_openmp_helpers.pyx +++ b/sklearn/utils/_openmp_helpers.pyx @@ -1,5 +1,9 @@ import os from joblib import cpu_count +from cython.parallel import prange +from ctypes import addressof + +from.parallel import _get_threadpool_controller # Module level cache for cpu_count as we do not expect this to change during @@ -75,3 +79,52 @@ cpdef _openmp_effective_n_threads(n_threads=None, only_physical_cores=True): return max(1, max_n_threads + n_threads + 1) return n_threads + + +ctypedef void (*openblas_dojob_callback)(int, void*, int) noexcept nogil +ctypedef void (*openblas_threads_callback)(int, openblas_dojob_callback, int, size_t, void*, int) +ctypedef void (*openblas_set_threads_callback_function_type)(openblas_threads_callback) + + +# Callback for OpenBLAS to make it use an OpenMP backend, took from +# https://github.com/OpenMathLib/OpenBLAS/pull/4577#issue-2204960832 +cdef void openblas_openmp_callback( + int sync, + openblas_dojob_callback dojob, + int numjobs, + size_t jobdata_elsize, + void *jobdata, + int dojob_data +) noexcept nogil: + cdef int i + cdef void *element_adrr + + for i in prange(numjobs, nogil=True): + element_adrr = (((jobdata) + (i) * jobdata_elsize)) + dojob(i, element_adrr, dojob_data) + + +def set_openblas_openmp_callback(): + controller = _get_threadpool_controller() + openblas_controllers = controller.select(internal_api="openblas").lib_controllers + + cdef openblas_set_threads_callback_function_type f_ptr + + for ct in openblas_controllers: + if not hasattr(ct, "dynlib"): + # too old version of threadpoolctl + continue + + lib = ct.dynlib + + if not hasattr(lib, "openblas_set_threads_callback_function"): + # openblas_set_threads_callback_function is available since v0.3.28 + continue + + func = lib.openblas_set_threads_callback_function + + # cast to the correct function pointer type + f_ptr = (addressof(func))[0] + f_ptr(openblas_openmp_callback) + + print("OpenBLAS OpenMP backend callback set")