rewrite loop rint · numpy/numpy@347c72b · GitHub
[go: up one dir, main page]

Skip to content

Commit 347c72b

Browse files
committed
rewrite loop rint
1 parent 6f6be04 commit 347c72b

File tree

4 files changed

+155
-8
lines changed

4 files changed

+155
-8
lines changed

numpy/_core/code_generators/generate_umath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ def english_upper(s):
988988
docstrings.get('numpy._core.umath.rint'),
989989
None,
990990
TD('e', f='rint', astype={'e': 'f'}),
991-
TD(inexactvec, dispatch=[('loops_unary_fp', 'fd')]),
991+
TD(inexactvec, cfunc_alias='rint'),
992992
TD('fdg' + cmplx, f='rint'),
993993
TD(P, f='rint'),
994994
),

numpy/_core/meson.build

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ if use_highway
105105
highway_lib = static_library('highway',
106106
[
107107
# required for hwy::Abort symbol
108-
'src/highway/hwy/abort.cc'
108+
'src/highway/hwy/abort.cc',
109+
'src/highway/hwy/per_target.cc'
109110
],
110111
cpp_args: '-DTOOLCHAIN_MISS_ASM_HWCAP_H',
111112
include_directories: ['src/highway'],
@@ -1141,6 +1142,7 @@ src_umath = umath_gen_headers + [
11411142
src_file.process('src/umath/matmul.c.src'),
11421143
src_file.process('src/umath/matmul.h.src'),
11431144
'src/umath/ufunc_type_resolution.c',
1145+
'src/umath/loop_unary_fp.cpp',
11441146
'src/umath/clip.cpp',
11451147
'src/umath/clip.h',
11461148
'src/umath/dispatching.c',
@@ -1214,6 +1216,7 @@ py.extension_module('_multiarray_umath',
12141216
'src/multiarray',
12151217
'src/npymath',
12161218
'src/umath',
1219+
'src/highway',
12171220
],
12181221
dependencies: [blas_dep],
12191222
link_with: [npymath_lib, multiarray_umath_mtargets.static_lib('_multiarray_umath_mtargets')] + highway_lib,
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#define _UMATHMODULE
2+
#define _MULTIARRAYMODULE
3+
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
4+
5+
#define PY_SSIZE_T_CLEAN
6+
#include <Python.h>
7+
8+
#include "numpy/ndarraytypes.h"
9+
#include "numpy/npy_common.h"
10+
#include "numpy/npy_math.h"
11+
#include "numpy/utils.h"
12+
13+
#include "fast_loop_macros.h"
14+
#include "loops_utils.h"
15+
#undef HWY_TARGET_INCLUDE
16+
#define HWY_TARGET_INCLUDE "loop_unary_fp.cpp" // this file
17+
#include <hwy/foreach_target.h> // must come before highway.h
18+
#include <hwy/highway.h>
19+
#include <hwy/aligned_allocator.h>
20+
21+
22+
namespace numpy {
23+
namespace HWY_NAMESPACE { // required: unique per target
24+
25+
// Can skip hn:: prefixes if already inside hwy::HWY_NAMESPACE.
26+
namespace hn = hwy::HWY_NAMESPACE;
27+
28+
// Alternative to per-function HWY_ATTR: see HWY_BEFORE_NAMESPACE
29+
// SUPER(NAME, FUNC) expands to
//
//   template <typename T>
//   void Super##NAME(char** args, npy_intp const* dimensions,
//                    npy_intp const* steps);
//
// a unary element-wise loop that applies the Highway vector operation FUNC
// to dimensions[0] elements of type T.
//
//   args[0] / steps[0]  - input base pointer and its stride in bytes
//   args[1] / steps[1]  - output base pointer and its stride in bytes
//
// Three paths are taken, in this order:
//   1. operands flagged by is_mem_overlap(): one element at a time, walking
//      both operands by their byte strides;
//   2. both strides equal sizeof(T): unaligned vector loads/stores, unrolled
//      by four, then single vectors, then a partial vector for the tail;
//   3. anything else: gather/scatter with per-lane element indices.
//
// NOTE(review): path 3 converts byte strides to element strides by integer
// division, so it assumes both strides are multiples of sizeof(T); the lane
// indices are held in hwy::MakeSigned<T>, so lane * stride must also fit in
// that type. Confirm that callers guarantee both.
//
// Comments inside the macro body must use /* */: a // comment would swallow
// the following backslash-continued lines.
#define SUPER(NAME, FUNC)                                                      \
  template <typename T>                                                        \
  HWY_ATTR void Super##NAME(char** args, npy_intp const* dimensions,           \
                            npy_intp const* steps) {                           \
    const size_t size = dimensions[0];                                         \
    const npy_intp src_step = steps[0];                                        \
    const npy_intp dst_step = steps[1];                                        \
    const npy_intp elsize = static_cast<npy_intp>(sizeof(T));                  \
    const hn::ScalableTag<T> d;                                                \
    const size_t lanes = hn::Lanes(d);                                         \
                                                                               \
    if (is_mem_overlap(args[0], src_step, args[1], dst_step, size)) {          \
      /* Overlapping operands: process one element per iteration and */        \
      /* advance by the byte strides, so strided views are honoured. */        \
      const char* src = args[0];                                               \
      char* dst = args[1];                                                     \
      for (size_t i = 0; i < size; ++i, src += src_step, dst += dst_step) {    \
        const auto in = hn::LoadN(d, reinterpret_cast<const T*>(src), 1);      \
        hn::StoreN(FUNC(in), d, reinterpret_cast<T*>(dst), 1);                 \
      }                                                                        \
      return;                                                                  \
    }                                                                          \
                                                                               \
    const T* HWY_RESTRICT input_array = reinterpret_cast<const T*>(args[0]);   \
    T* HWY_RESTRICT output_array = reinterpret_cast<T*>(args[1]);              \
                                                                               \
    /* Compare against sizeof(T) explicitly; handing pointer variables to */   \
    /* a sizeof-based contiguity macro would test sizeof(T*) instead. */       \
    if (src_step == elsize && dst_step == elsize) {                            \
      /* A single running offset is shared by all three stages so each */      \
      /* stage resumes exactly where the previous one stopped. */              \
      size_t i = 0;                                                            \
      for (; size - i >= 4 * lanes; i += 4 * lanes) {                          \
        const auto in0 = hn::LoadU(d, input_array + i);                        \
        const auto in1 = hn::LoadU(d, input_array + i + lanes);                \
        const auto in2 = hn::LoadU(d, input_array + i + 2 * lanes);            \
        const auto in3 = hn::LoadU(d, input_array + i + 3 * lanes);            \
        const auto x0 = FUNC(in0);                                             \
        const auto x1 = FUNC(in1);                                             \
        const auto x2 = FUNC(in2);                                             \
        const auto x3 = FUNC(in3);                                             \
        hn::StoreU(x0, d, output_array + i);                                   \
        hn::StoreU(x1, d, output_array + i + lanes);                           \
        hn::StoreU(x2, d, output_array + i + 2 * lanes);                       \
        hn::StoreU(x3, d, output_array + i + 3 * lanes);                       \
      }                                                                        \
      for (; size - i >= lanes; i += lanes) {                                  \
        const auto in = hn::LoadU(d, input_array + i);                         \
        hn::StoreU(FUNC(in), d, output_array + i);                             \
      }                                                                        \
      if (i < size) {                                                          \
        const size_t tail = size - i;                                          \
        const auto in = hn::LoadN(d, input_array + i, tail);                   \
        hn::StoreN(FUNC(in), d, output_array + i, tail);                       \
      }                                                                        \
      return;                                                                  \
    }                                                                          \
                                                                               \
    /* Strided operands: lane k accesses element k * stride from the base. */  \
    using TI = hwy::MakeSigned<T>;                                             \
    const hn::Rebind<TI, hn::ScalableTag<T>> di;                               \
    const npy_intp ssrc = src_step / elsize;                                   \
    const npy_intp sdst = dst_step / elsize;                                   \
    const auto lane_ids = hn::Iota(di, 0);                                     \
    const auto load_index =                                                    \
        hn::Mul(lane_ids, hn::Set(di, static_cast<TI>(ssrc)));                 \
    const auto store_index =                                                   \
        hn::Mul(lane_ids, hn::Set(di, static_cast<TI>(sdst)));                 \
    size_t i = 0;                                                              \
    for (; size - i >= lanes; i += lanes) {                                    \
      /* Signed offset: strides may be negative. */                            \
      const npy_intp pos = static_cast<npy_intp>(i);                           \
      const auto in =                                                          \
          hn::GatherIndex(d, input_array + pos * ssrc, load_index);            \
      hn::ScatterIndex(FUNC(in), d, output_array + pos * sdst, store_index);   \
    }                                                                          \
    if (i < size) {                                                            \
      const size_t tail = size - i;                                            \
      const npy_intp pos = static_cast<npy_intp>(i);                           \
      const auto in = hn::GatherIndexN(d, input_array + pos * ssrc,            \
                                       load_index, tail);                      \
      hn::ScatterIndexN(FUNC(in), d, output_array + pos * sdst, store_index,   \
                        tail);                                                 \
    }                                                                          \
  }
105+
106+
// Instantiate the SUPER loop template as SuperRint<T>, with hn::Round as the
// per-vector operation.
// NOTE(review): rint needs round-half-to-even; confirm hn::Round's tie
// behaviour against the Highway reference.
SUPER(Rint, hn::Round)
107+
108+
// Per-target entry point for npy_double: forwards the loop arguments
// unchanged to SuperRint<npy_double>. Compiled once per HWY_NAMESPACE and
// registered with HWY_EXPORT further down in this file.
HWY_ATTR void DOUBLE_HWRint(char **args, npy_intp const *dimensions, npy_intp const *steps) {
109+
SuperRint<npy_double>(args, dimensions, steps);
110+
}
111+
112+
// Per-target entry point for npy_float: forwards the loop arguments
// unchanged to SuperRint<npy_float>. Compiled once per HWY_NAMESPACE and
// registered with HWY_EXPORT further down in this file.
HWY_ATTR void FLOAT_HWRint(char **args, npy_intp const *dimensions, npy_intp const *steps) {
113+
SuperRint<npy_float>(args, dimensions, steps);
114+
}
115+
116+
}
117+
}
118+
119+
#if HWY_ONCE
120+
namespace numpy {
121+
122+
// Register the per-target FLOAT_HWRint / DOUBLE_HWRint functions so that
// HWY_DYNAMIC_POINTER below can look them up. Guarded by HWY_ONCE, so this
// is emitted a single time even though foreach_target.h re-includes the file.
HWY_EXPORT(FLOAT_HWRint);
123+
HWY_EXPORT(DOUBLE_HWRint);
124+
125+
extern "C" {
126+
127+
// C-linkage loop function DOUBLE_rint: fetches the function pointer that
// HWY_DYNAMIC_POINTER yields for DOUBLE_HWRint and calls it with
// (args, dimensions, steps). The trailing `func` argument is unused.
// NOTE(review): presumably the pointer targets the best implementation for
// the running CPU - confirm against the Highway dispatch documentation.
NPY_NO_EXPORT void
128+
DOUBLE_rint(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
129+
{
130+
auto dispatcher = HWY_DYNAMIC_POINTER(DOUBLE_HWRint);
131+
return dispatcher(args, dimensions, steps);
132+
}
133+
134+
// C-linkage loop function FLOAT_rint: fetches the function pointer that
// HWY_DYNAMIC_POINTER yields for FLOAT_HWRint and calls it with
// (args, dimensions, steps). The trailing `func` argument is unused.
// NOTE(review): presumably the pointer targets the best implementation for
// the running CPU - confirm against the Highway dispatch documentation.
NPY_NO_EXPORT void
135+
FLOAT_rint(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
136+
{
137+
auto dispatcher = HWY_DYNAMIC_POINTER(FLOAT_HWRint);
138+
return dispatcher(args, dimensions, steps);
139+
}
140+
141+
} // extern "C"
142+
} // numpy
143+
#endif
144+

numpy/_core/src/umath/loops_unary_fp.dispatch.c.src

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ NPY_FINLINE double c_square_f64(double a)
100100
*/
101101
#if @VCHK@
102102
/**begin repeat1
103-
* #kind = rint, floor, ceil, trunc, sqrt, absolute, square, reciprocal#
104-
* #intr = rint, floor, ceil, trunc, sqrt, abs, square, recip#
105-
* #repl_0w1 = 0*7, 1#
103+
* #kind = floor, ceil, trunc, sqrt, absolute, square, reciprocal#
104+
* #intr = floor, ceil, trunc, sqrt, abs, square, recip#
105+
* #repl_0w1 = 0*6, 1#
106106
*/
107107
/**begin repeat2
108108
* #STYPE = CONTIG, NCONTIG, CONTIG, NCONTIG#
@@ -199,9 +199,9 @@ static void simd_@TYPE@_@kind@_@STYPE@_@DTYPE@
199199
* #VCHK = NPY_SIMD_F32, NPY_SIMD_F64#
200200
*/
201201
/**begin repeat1
202-
* #kind = rint, floor, ceil, trunc, sqrt, absolute, square, reciprocal#
203-
* #intr = rint, floor, ceil, trunc, sqrt, abs, square, recip#
204-
* #clear = 0, 0, 0, 0, 0, 1, 0, 0#
202+
* #kind = floor, ceil, trunc, sqrt, absolute, square, reciprocal#
203+
* #intr = floor, ceil, trunc, sqrt, abs, square, recip#
204+
* #clear = 0, 0, 0, 0, 1, 0, 0#
205205
*/
206206
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@kind@)
207207
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))

0 commit comments

Comments
 (0)
0