8000 Merge pull request #29235 from charris/backport-29223 · numpy/numpy@0481076 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0481076

Browse files
authored
Merge pull request #29235 from charris/backport-29223
BUG: Address interaction between SME and FPSR (#29223)
2 parents f37c80e + ad6d919 commit 0481076

File tree

8 files changed

+217
-8
lines changed

8 files changed

+217
-8
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,7 @@ src_multiarray_umath_common = [
11171117
]
11181118
if have_blas
11191119
src_multiarray_umath_common += [
1120+
'src/common/blas_utils.c',
11201121
'src/common/cblasfuncs.c',
11211122
'src/common/python_xerbla.c',
11221123
]

numpy/_core/src/common/blas_utils.c

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#include <stdbool.h>
8000 2+
#include <stdio.h>
3+
#include <stdlib.h>
4+
5+
#ifdef __APPLE__
6+
#include <sys/sysctl.h>
7+
#endif
8+
9+
#include "numpy/numpyconfig.h" // NPY_VISIBILITY_HIDDEN
10+
#include "numpy/npy_math.h" // npy_get_floatstatus_barrier
11+
#include "blas_utils.h"
12+
13+
#if NPY_BLAS_CHECK_FPE_SUPPORT
14+
15+
/* Return whether we're running on macOS 15.4 or later
16+
*/
17+
static inline bool
18+
is_macOS_version_15_4_or_later(void){
19+
#if !defined(__APPLE__)
20+
return false;
21+
#else
22+
char *osProductVersion = NULL;
23+
size_t size = 0;
24+
bool ret = false;
25+
26+
// Query how large OS version string should be
27+
if(-1 == sysctlbyname("kern.osproductversion", NULL, &size, NULL, 0)){
28+
goto cleanup;
29+
}
30+
31+
osProductVersion = malloc(size + 1);
32+
33+
// Get the OS version string
34+
if(-1 == sysctlbyname("kern.osproductversion", osProductVersion, &size, NULL, 0)){
35+
goto cleanup;
36+
}
37+
38+
osProductVersion[size] = '\0';
39+
40+
// Parse the version string
41+
int major = 0, minor = 0;
42+
if(2 > sscanf(osProductVersion, "%d.%d", &major, &minor)) {
43 9E88 +
goto cleanup;
44+
}
45+
46+
if(major > 15 || (major == 15 && minor >= 4)) {
47+
ret = true;
48+
}
49+
50+
cleanup:
51+
if(osProductVersion){
52+
free(osProductVersion);
53+
}
54+
55+
return ret;
56+
#endif
57+
}
58+
59+
/* ARM Scalable Matrix Extension (SME) raises all floating-point error flags
60+
* when it's used regardless of values or operations. As a consequence,
61+
* when SME is used, all FPE state is lost and special handling is needed.
62+
*
63+
* For NumPy, SME is not currently used directly, but can be used via
64+
* BLAS / LAPACK libraries. This function does a runtime check for whether
65+
* BLAS / LAPACK can use SME and special handling around FPE is required.
66+
*/
67+
static inline bool
68+
BLAS_can_use_ARM_SME(void)
69+
{
70+
#if defined(__APPLE__) && defined(__aarch64__) && defined(ACCELERATE_NEW_LAPACK)
71+
// ARM SME can be used by Apple's Accelerate framework for BLAS / LAPACK
72+
// - macOS 15.4+
73+
// - Apple silicon M4+
74+
75+
// Does OS / Accelerate support ARM SME?
76+
if(!is_macOS_version_15_4_or_later()){
77+
return false;
78+
}
79+
80+
// Does hardware support SME?
81+
int has_SME = 0;
82+
size_t size = sizeof(has_SME);
83+
if(-1 == sysctlbyname("hw.optional.arm.FEAT_SME", &has_SME, &size, NULL, 0)){
84+
return false;
85+
}
86+
87+
if(has_SME){
88+
return true;
89+
}
90+
#endif
91+
92+
// default assume SME is not used
93+
return false;
94+
}
95+
96+
/* Static variable to cache runtime check of BLAS FPE support.
97+
*/
98+
static bool blas_supports_fpe = true;
99+
100+
#endif // NPY_BLAS_CHECK_FPE_SUPPORT
101+
102+
103+
NPY_VISIBILITY_HIDDEN bool
104+
npy_blas_supports_fpe(void)
105+
{
106+
#if NPY_BLAS_CHECK_FPE_SUPPORT
107+
return blas_supports_fpe;
108+
#else
109+
return true;
110+
#endif
111+
}
112+
113+
NPY_VISIBILITY_HIDDEN void
114+
npy_blas_init(void)
115+
{
116+
#if NPY_BLAS_CHECK_FPE_SUPPORT
117+
blas_supports_fpe = !BLAS_can_use_ARM_SME();
118+
#endif
119+
}
120+
121+
NPY_VISIBILITY_HIDDEN int
122+
npy_get_floatstatus_after_blas(void)
123+
{
124+
#if NPY_BLAS_CHECK_FPE_SUPPORT
125+
if(!blas_supports_fpe){
126+
// BLAS does not support FPE and we need to return FPE state.
127+
// Instead of clearing and then grabbing state, just return
128+
// that no flags are set.
129+
return 0;
130+
}
131+
#endif
132+
char *param = NULL;
133+
return npy_get_floatstatus_barrier(param);
134+
}

numpy/_core/src/common/blas_utils.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <stdbool.h>
2+
3+
#include "numpy/numpyconfig.h" // for NPY_VISIBILITY_HIDDEN
4+
5+
/* NPY_BLAS_CHECK_FPE_SUPPORT controls whether we need a runtime check
6+
* for floating-point error (FPE) support in BLAS.
7+
*/
8+
#if defined(__APPLE__) && defined(__aarch64__) && defined(ACCELERATE_NEW_LAPACK)
9+
#define NPY_BLAS_CHECK_FPE_SUPPORT 1
10+
#else
11+
#define NPY_BLAS_CHECK_FPE_SUPPORT 0
12+
#endif
13+
14+
/* Initialize BLAS environment, if needed
15+
*/
16+
NPY_VISIBILITY_HIDDEN void
17+
npy_blas_init(void);
18+
19+
/* Runtime check if BLAS supports floating-point errors.
20+
* true - BLAS supports FPE and one can rely on them to indicate errors
21+
* false - BLAS does not support FPE. Special F987 handling needed for FPE state
22+
*/
23+
NPY_VISIBILITY_HIDDEN bool
24+
npy_blas_supports_fpe(void);
25+
26+
/* If BLAS supports FPE, exactly the same as npy_get_floatstatus_barrier().
27+
* Otherwise, we can't rely on FPE state and need special handling.
28+
*/
29+
NPY_VISIBILITY_HIDDEN int
30+
npy_get_floatstatus_after_blas(void);

numpy/_core/src/common/cblasfuncs.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "numpy/arrayobject.h"
1313
#include "numpy/npy_math.h"
1414
#include "numpy/ufuncobject.h"
15+
#include "blas_utils.h"
1516
#include "npy_cblas.h"
1617
#include "arraytypes.h"
1718
#include "common.h"
@@ -693,7 +694,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
693694
NPY_END_ALLOW_THREADS;
694695
}
695696

696-
int fpes = npy_get_floatstatus_barrier((char *) result);
697+
int fpes = npy_get_floatstatus_after_blas();
697698
if (fpes && PyUFunc_GiveFloatingpointErrors("dot", fpes) < 0) {
698699
goto fail;
699700
}

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
4343
#include "arraytypes.h"
4444
#include "arrayobject.h"
4545
#include "array_converter.h"
46+
#include "blas_utils.h"
4647
#include "hashdescr.h"
4748
#include "descriptor.h"
4849
#include "dragon4.h"
@@ -4805,6 +4806,10 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
48054806
goto err;
48064807
}
48074808

4809+
#if NPY_BLAS_CHECK_FPE_SUPPORT
4810+
npy_blas_init();
4811+
#endif
4812+
48084813
#if defined(MS_WIN64) && defined(__GNUC__)
48094814
PyErr_WarnEx(PyExc_Warning,
48104815
"Numpy built with MINGW-W64 on Windows 64 bits is experimental, " \

numpy/_core/src/umath/matmul.c.src

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818

19+
#include "blas_utils.h"
1920
#include "npy_cblas.h"
2021
#include "arraytypes.h" /* For TYPE_dot functions */
2122

@@ -120,7 +121,7 @@ static inline void
120121
}
121122
}
122123

123-
NPY_NO_EXPORT void
124+
static void
124125
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
125126
void *ip2, npy_intp is2_n,
126127
void *op, npy_intp op_m,
@@ -156,7 +157,7 @@ NPY_NO_EXPORT void
156157
is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
157158
}
158159

159-
NPY_NO_EXPORT void
160+
static void
160161
@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
161162
void *ip2, npy_intp is2_n, npy_intp is2_p,
162163
void *op, npy_intp os_m, npy_intp os_p,
@@ -260,7 +261,7 @@ NPY_NO_EXPORT void
260261
* #IS_HALF = 0, 0, 0, 1, 0*13#
261262
*/
262263

263-
NPY_NO_EXPORT void
264+
static void
264265
@TYPE@_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
265266
void *_ip2, npy_intp is2_n, npy_intp is2_p,
266267
void *_op, npy_intp os_m, npy_intp os_p,
@@ -318,7 +319,7 @@ NPY_NO_EXPORT void
318319
}
319320

320321
/**end repeat**/
321-
NPY_NO_EXPORT void
322+
static void
322323
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
323324
void *_ip2, npy_intp is2_n, npy_intp is2_p,
324325
void *_op, npy_intp os_m, npy_intp os_p,
@@ -357,7 +358,7 @@ BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
357358
}
358359
}
359360

360-
NPY_NO_EXPORT void
361+
static void
361362
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
362363
void *_ip2, npy_intp is2_n, npy_intp is2_p,
363364
void *_op, npy_intp os_m, npy_intp os_p,
@@ -629,6 +630,11 @@ NPY_NO_EXPORT void
629630
#endif
630631
}
631632
#if @USEBLAS@ && defined(HAVE_CBLAS)
633+
#if NPY_BLAS_CHECK_FPE_SUPPORT
634+
if (!npy_blas_supports_fpe()) {
635+
npy_clear_floatstatus_barrier((char*)args);
636+
}
637+
#endif
632638
if (allocate_buffer) free(tmp_ip12op);
633639
#endif
634640
}
@@ -653,7 +659,7 @@ NPY_NO_EXPORT void
653659
* #prefix = c, z, 0#
654660
* #USE_BLAS = 1, 1, 0#
655661
*/
656-
NPY_NO_EXPORT void
662+
static void
657663
@name@_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
658664
char *op, npy_intp n, void *NPY_UNUSED(ignore))
659665
{
@@ -749,6 +755,7 @@ OBJECT_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp
749755
* CFLOAT, CDOUBLE, CLONGDOUBLE, OBJECT#
750756
* #DOT = dot*15, dotc*4#
751757
* #CHECK_PYERR = 0*18, 1#
758+
* #CHECK_BLAS = 1*2, 0*13, 1*2, 0*2#
752759
*/
753760
NPY_NO_EXPORT void
754761
@TYPE@_vecdot(char **args, npy_intp const *dimensions, npy_intp const *steps,
@@ -772,6 +779,11 @@ NPY_NO_EXPORT void
772779
}
773780
#endif
774781
}
782+
#if @CHECK_BLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
783+
if (!npy_blas_supports_fpe()) {
784+
npy_clear_floatstatus_barrier((char*)args);
785+
}
786+
#endif
775787
}
776788
/**end repeat**/
777789

@@ -787,7 +799,7 @@ NPY_NO_EXPORT void
787799
* #step1 = &oneF, &oneD#
788800
* #step0 = &zeroF, &zeroD#
789801
*/
790-
NPY_NO_EXPORT void
802+
static void
791803
@name@_vecmat_via_gemm(void *ip1, npy_intp is1_n,
792804
void *ip2, npy_intp is2_n, npy_intp is2_m,
793805
void *op, npy_intp os_m,
@@ -878,6 +890,11 @@ NPY_NO_EXPORT void
878890
#endif
879891
}
880892
}
893+
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
894+
if (!npy_blas_supports_fpe()) {
895+
npy_clear_floatstatus_barrier((char*)args);
896+
}
897+
#endif
881898
}
882899
/**end repeat**/
883900

@@ -943,5 +960,10 @@ NPY_NO_EXPORT void
943960
#endif
944961
}
945962
}
963+
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
964+
if (!npy_blas_supports_fpe()) {
965+
npy_clear_floatstatus_barrier((char*)args);
966+
}
967+
#endif
946968
}
947969
/**end repeat**/

numpy/_core/tests/test_multiarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from numpy.exceptions import AxisError, ComplexWarning
3232
from numpy.lib.recfunctions import repack_fields
3333
from numpy.testing import (
34+
BLAS_SUPPORTS_FPE,
3435
HAS_REFCOUNT,
3536
IS_64BIT,
3637
IS_PYPY,
@@ -3363,6 +3364,11 @@ def test_dot(self):
33633364
@pytest.mark.parametrize("dtype", [np.half, np.double, np.longdouble])
33643365
@pytest.mark.skipif(IS_WASM, reason="no wasm fp exception support")
33653366
def test_dot_errstate(self, dtype):
3367+
# Some dtypes use BLAS for 'dot' operation and
3368+
# not all BLAS support floating-point errors.
3369+
if not BLAS_SUPPORTS_FPE and dtype == np.double:
3370+
pytest.skip("BLAS does not support FPE")
3371+
33663372
a = np.array([1, 1], dtype=dtype)
33673373
b = np.array([-np.inf, np.inf], dtype=dtype)
33683374

numpy/testing/_private/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
4343
'IS_MUSL', 'check_support_sve', 'NOGIL_BUILD',
4444
'IS_EDITABLE', 'IS_INSTALLED', 'NUMPY_ROOT', 'run_threaded', 'IS_64BIT',
45+
'BLAS_SUPPORTS_FPE',
4546
]
4647

4748

@@ -89,6 +90,15 @@ class KnownFailureException(Exception):
8990
IS_PYPY = sys.implementation.name == 'pypy'
9091
IS_PYSTON = hasattr(sys, "pyston_version_info")
9192
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
93+
BLAS_SUPPORTS_FPE = True
94+
if platform.system() == 'Darwin' or platform.machine() == 'arm64':
95+
try:
96+
blas = np.__config__.CONFIG['Build Dependencies']['blas']
97+
if blas['name'] == 'accelerate':
98+
BLAS_SUPPORTS_FPE = False
99+
except KeyError:
100+
pass
101+
92102
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
93103

94104
IS_MUSL = False

0 commit comments

Comments
 (0)
0