10000 Merge pull request #23351 from r-devulap/avx512fp16 · numpy/numpy@1d0edc1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1d0edc1

Browse files
authored
Merge pull request #23351 from r-devulap/avx512fp16
ENH: Use AVX512-FP16 SVML content for float16 umath functions
2 parents 09eb0ce + bc27299 commit 1d0edc1

File tree

6 files changed

+86
-6
lines changed

6 files changed

+86
-6
lines changed

numpy/_core/meson.build

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,13 @@ src_umath = umath_gen_headers + [
11391139
# here. Note that this migration is desirable; we then get the performance
11401140
# benefits for all platforms rather than only for AVX512 on 64-bit Linux, and
11411141
# may be able to avoid the accuracy regressions in SVML.
1142+
#
1143+
CPU_FEATURES_NAMES = CPU_BASELINE_NAMES + CPU_DISPATCH_NAMES
1144+
svml_file_suffix = ['d_la', 's_la', 'd_ha', 's_la']
1145+
if CPU_FEATURES_NAMES.contains('AVX512_SPR')
1146+
svml_file_suffix += ['h_la']
1147+
endif
1148+
11421149
svml_objects = []
11431150
if use_svml
11441151
foreach svml_func : [
@@ -1153,7 +1160,7 @@ if use_svml
11531160
'pow', 'sin', 'sinh', 'tan',
11541161
'tanh'
11551162
]
1156-
foreach svml_sfx : ['d_la', 's_la', 'd_ha', 's_la']
1163+
foreach svml_sfx : svml_file_suffix
11571164
svml_objects += [
11581165
'src/umath/svml/linux/avx512/svml_z0_'+svml_func+'_'+svml_sfx+'.s'
11591166
]

numpy/_core/src/common/npy_svml.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,27 @@
1+
#if NPY_SIMD && defined(NPY_HAVE_AVX512_SPR) && defined(NPY_CAN_LINK_SVML)
2+
extern void __svml_exps32(const npy_half*, npy_half*, npy_intp);
3+
extern void __svml_exp2s32(const npy_half*, npy_half*, npy_intp);
4+
extern void __svml_logs32(const npy_half*, npy_half*, npy_intp);
5+
extern void __svml_log2s32(const npy_half*, npy_half*, npy_intp);
6+
extern void __svml_log10s32(const npy_half*, npy_half*, npy_intp);
7+
extern void __svml_expm1s32(const npy_half*, npy_half*, npy_intp);
8+
extern void __svml_log1ps32(const npy_half*, npy_half*, npy_intp);
9+
extern void __svml_cbrts32(const npy_half*, npy_half*, npy_intp);
10+
extern void __svml_sins32(const npy_half*, npy_half*, npy_intp);
11+
extern void __svml_coss32(const npy_half*, npy_half*, npy_intp);
12+
extern void __svml_tans32(const npy_half*, npy_half*, npy_intp);
13+
extern void __svml_asins32(const npy_half*, npy_half*, npy_intp);
14+
extern void __svml_acoss32(const npy_half*, npy_half*, npy_intp);
15+
extern void __svml_atans32(const npy_half*, npy_half*, npy_intp);
16+
extern void __svml_atan2s32(const npy_half*, npy_half*, npy_intp);
17+
extern void __svml_sinhs32(const npy_half*, npy_half*, npy_intp);
18+
extern void __svml_coshs32(const npy_half*, npy_half*, npy_intp);
19+
extern void __svml_tanhs32(const npy_half*, npy_half*, npy_intp);
20+
extern void __svml_asinhs32(const npy_half*, npy_half*, npy_intp);
21+
extern void __svml_acoshs32(const npy_half*, npy_half*, npy_intp);
22+
extern void __svml_atanhs32(const npy_half*, npy_half*, npy_intp);
23+
#endif
24+
125
#if NPY_SIMD && defined(NPY_HAVE_AVX512_SKX) && defined(NPY_CAN_LINK_SVML)
226
extern __m512 __svml_expf16(__m512 x);
327
extern __m512 __svml_exp2f16(__m512 x);

numpy/_core/src/umath/loops_umath_fp.dispatch.c.src

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*@targets
2-
** $maxopt baseline avx512_skx
2+
** $maxopt baseline avx512_skx avx512_spr
33
*/
44
#include "numpy/npy_math.h"
55
#include "simd/simd.h"
@@ -156,7 +156,8 @@ avx512_@func@_f16(const npy_half *src, npy_half *dst, npy_intp len)
156156
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(HALF_@func@)
157157
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(data))
158158
{
159-
#if NPY_SIMD && defined(NPY_HAVE_AVX512_SKX) && defined(NPY_CAN_LINK_SVML)
159+
#if defined(NPY_HAVE_AVX512_SPR) || defined(NPY_HAVE_AVX512_SKX)
160+
#if NPY_SIMD && defined(NPY_CAN_LINK_SVML)
160161
const npy_half *src = (npy_half*)args[0];
161162
npy_half *dst = (npy_half*)args[1];
162163
const int lsize = sizeof(src[0]);
@@ -166,10 +167,17 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(HALF_@func@)
166167
if (!is_mem_overlap(src, steps[0], dst, steps[1], len) &&
167168
(ssrc == 1) &&
168169
(sdst == 1)) {
170+
#if defined(NPY_HAVE_AVX512_SPR)
171+
__svml_@intrin@s32(src, dst, len);
172+
return;
173+
#endif
174+
#if defined(NPY_HAVE_AVX512_SKX)
169175
avx512_@intrin@_f16(src, dst, len);
170176
return;
171-
}
172177
#endif
178+
}
179+
#endif // NPY_SIMD && NPY_CAN_LINK_SVML
180+
#endif // SPR or SKX
173181
UNARY_LOOP {
174182
const npy_float in1 = npy_half_to_float(*(npy_half *)ip1);
175183
*((npy_half *)op1) = npy_float_to_half(npy_@intrin@f(in1));

numpy/_core/tests/test_umath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1667,7 +1667,7 @@ def test_tanh(self):
16671667
for dt in ['e', 'f', 'd']:
16681668
in_arr = np.array(in_, dtype=dt)
16691669
out_arr = np.array(out, dtype=dt)
1670-
assert_equal(np.tanh(in_arr), out_arr)
1670+
assert_array_max_ulp(np.tanh(in_arr), out_arr, 3)
16711671

16721672
def test_arcsinh(self):
16731673
in_ = [np.nan, -np.nan, np.inf, -np.inf]

numpy/_core/tests/test_umath_accuracy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
IS_AVX = __cpu_features__.get('AVX512F', False) or \
2020
(__cpu_features__.get('FMA3', False) and __cpu_features__.get('AVX2', False))
21+
22+
IS_AVX512FP16 = __cpu_features__.get('AVX512FP16', False)
23+
2124
# only run on linux with AVX, also avoid old glibc (numpy/numpy#20448).
2225
runtest = (sys.platform.startswith('linux')
2326
and IS_AVX and not _glibc_older_than("2.17"))
@@ -68,6 +71,8 @@ def test_validate_transcendentals(self):
6871
maxulperr = data_subset['ulperr'].max()
6972
assert_array_max_ulp(npfunc(inval), outval, maxulperr)
7073

74+
@pytest.mark.skipif(IS_AVX512FP16,
75+
reason = "SVML FP16 have slightly higher ULP errors")
7176
@pytest.mark.parametrize("ufunc", UNARY_OBJECT_UFUNCS)
7277
def test_validate_fp16_transcendentals(self, ufunc):
7378
with np.errstate(all='ignore'):
@@ -76,3 +81,39 @@ def test_validate_fp16_transcendentals(self, ufunc):
7681
datafp32 = datafp16.astype(np.float32)
7782
assert_array_max_ulp(ufunc(datafp16), ufunc(datafp32),
7883
maxulp=1, dtype=np.float16)
84+
85+
@pytest.mark.skipif(not IS_AVX512FP16,
86+
reason="lower ULP only apply for SVML FP16")
87+
def test_validate_svml_fp16(self):
88+
max_ulp_err = {
89+
"arccos": 2.54,
90+
"arccosh": 2.09,
91+
"arcsin": 3.06,
92+
"arcsinh": 1.51,
93+
"arctan": 2.61,
94+
"arctanh": 1.88,
95+
"cbrt": 1.57,
96+
"cos": 1.43,
97+
"cosh": 1.33,
98+
"exp2": 1.33,
99+
"exp": 1.27,
100+
"expm1": 0.53,
101+
"log": 1.80,
102+
"log10": 1.27,
103+
"log1p": 1.88,
104+
"log2": 1.80,
105+
"sin": 1.88,
106+
"sinh": 2.05,
107+
"tan": 2.26,
108+
"tanh": 3.00,
109+
}
110+
111+
with np.errstate(all='ignore'):
112+
arr = np.arange(65536, dtype=np.int16)
113+
datafp16 = np.frombuffer(arr.tobytes(), dtype=np.float16)
114+
datafp32 = datafp16.astype(np.float32)
115+
for func in max_ulp_err:
116+
ufunc = getattr(np, func)
117+
ulp = np.ceil(max_ulp_err[func])
118+
assert_array_max_ulp(ufunc(datafp16), ufunc(datafp32),
119+
maxulp=ulp, dtype=np.float16)

0 commit comments

Comments
 (0)
0