PERF Force OpenBLAS to use an OpenMP backend by jeremiedbb · Pull Request #29403 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

PERF Force OpenBLAS to use an OpenMP backend #29403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
_distributor_init, # noqa: F401
)
from .base import clone
from .utils._openmp_helpers import set_openblas_openmp_callback
from .utils._show_versions import show_versions

__all__ = [
Expand Down Expand Up @@ -138,6 +139,12 @@
except ModuleNotFoundError:
pass

# Register an OpenMP backend for OpenBLAS such that OpenBLAS uses the OpenMP
# thread pool which prevents active spin-waiting when making successive quick
# alternating calls to OpenBLAS and OpenMP.
# see https://github.com/OpenMathLib/OpenBLAS/issues/3187 for more details
set_openblas_openmp_callback()


def setup_module(module):
"""Fixture for the tests to assure globally controllable seeding of RNGs"""
Expand Down
53 changes: 53 additions & 0 deletions sklearn/utils/_openmp_helpers.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
from joblib import cpu_count
from cython.parallel import prange
from ctypes import addressof

from.parallel import _get_threadpool_controller


# Module level cache for cpu_count as we do not expect this to change during
Expand Down Expand Up @@ -75,3 +79,52 @@ cpdef _openmp_effective_n_threads(n_threads=None, only_physical_cores=True):
return max(1, max_n_threads + n_threads + 1)

return n_threads


# Function-pointer types describing the OpenBLAS threading-callback interface.
#
# openblas_dojob_callback: worker routine invoked once per job; it receives an
# int, a pointer and another int, and is callable without the GIL.
ctypedef void (*openblas_dojob_callback)(int, void*, int) noexcept nogil
# openblas_threads_callback: signature of the scheduling callback itself; its
# parameters line up with those of `openblas_openmp_callback` below
# (sync, dojob, numjobs, jobdata_elsize, jobdata, dojob_data).
ctypedef void (*openblas_threads_callback)(int, openblas_dojob_callback, int, size_t, void*, int)
# openblas_set_threads_callback_function_type: signature of the setter that
# accepts an `openblas_threads_callback`.
ctypedef void (*openblas_set_threads_callback_function_type)(openblas_threads_callback)


# Callback for OpenBLAS to make it use an OpenMP backend, taken from
# https://github.com/OpenMathLib/OpenBLAS/pull/4577#issue-2204960832
cdef void openblas_openmp_callback(
    int sync,
    openblas_dojob_callback dojob,
    int numjobs,
    size_t jobdata_elsize,
    void *jobdata,
    int dojob_data
) noexcept nogil:
    # Run `numjobs` jobs through a Cython `prange` loop.
    #
    # `jobdata` is treated as a contiguous array of `numjobs` records, each
    # `jobdata_elsize` bytes wide. For every job index, `dojob` is called with
    # the index, a pointer to the matching record and `dojob_data`.
    # `sync` is accepted to match the callback signature but is not used here.
    cdef int job_idx
    cdef void *job_ptr
    # View the payload as raw bytes so that offsets are expressed in bytes.
    cdef char *jobdata_bytes = <char *>jobdata

    for job_idx in prange(numjobs, nogil=True):
        # Assigned inside the prange body, so each thread gets its own copy.
        job_ptr = <void *>(jobdata_bytes + (<unsigned>job_idx) * jobdata_elsize)
        dojob(job_idx, job_ptr, dojob_data)


def set_openblas_openmp_callback():
    """Register `openblas_openmp_callback` with every loaded OpenBLAS library.

    The OpenBLAS controllers reported by threadpoolctl are inspected one by
    one. A controller is skipped silently when either:

    - it does not expose a ``dynlib`` attribute (threadpoolctl too old), or
    - its library does not export ``openblas_set_threads_callback_function``
      (available since OpenBLAS v0.3.28).

    This function is called at import time, so it must not write to stdout.

    Returns
    -------
    None
    """
    cdef openblas_set_threads_callback_function_type f_ptr

    controller = _get_threadpool_controller()
    openblas_controllers = controller.select(internal_api="openblas").lib_controllers

    for ct in openblas_controllers:
        lib = getattr(ct, "dynlib", None)
        if lib is None:
            # too old version of threadpoolctl
            continue

        func = getattr(lib, "openblas_set_threads_callback_function", None)
        if func is None:
            # openblas_set_threads_callback_function is available since v0.3.28
            continue

        # `func` is a ctypes function object: `addressof(func)` is the address
        # of the buffer holding the C function pointer, so reinterpret that
        # buffer as a pointer-to-function-pointer and dereference it.
        f_ptr = (<openblas_set_threads_callback_function_type*><size_t>addressof(func))[0]
        f_ptr(openblas_openmp_callback)
0