8000 Merge pull request #7999 from juliantaylor/inplace-opt · numpy/numpy@8eedd3e · GitHub
[go: up one dir, main page]

Skip to content

Commit 8eedd3e

Browse files
authored
Merge pull request #7999 from juliantaylor/inplace-opt
ENH: add inplace cases to fast ufunc loop macros
2 parents a93d9f7 + d555a0a commit 8eedd3e

File tree

5 files changed

+103
-43
lines changed

5 files changed

+103
-43
lines changed

benchmarks/benchmarks/bench_ufunc.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,45 @@ def time_or_bool(self):
7777
(self.b | self.b)
7878

7979

80+
class CustomInplace(Benchmark):
81+
def setup(self):
82+
self.c = np.ones(500000, dtype=np.int8)
83+
self.i = np.ones(150000, dtype=np.int32)
84+
self.f = np.zeros(150000, dtype=np.float32)
85+
self.d = np.zeros(75000, dtype=np.float64)
86+
# fault memory
87+
self.f *= 1.
88+
self.d *= 1.
89+
90+
def time_char_or(self):
91+
np.bitwise_or(self.c, 0, out=self.c)
92+
np.bitwise_or(0, self.c, out=self.c)
93+
94+
def time_char_or_temp(self):
95+
0 | self.c | 0
96+
97+
def time_int_or(self):
98+
np.bitwise_or(self.i, 0, out=self.i)
99+
np.bitwise_or(0, self.i, out=self.i)
100+
101+
def time_int_or_temp(self):
102+
0 | self.i | 0
103+
104+
def time_float_add(self):
105+
np.add(self.f, 1., out=self.f)
106+
np.add(1., self.f, out=self.f)
107+
108+
def time_float_add_temp(self):
109+
1. + self.f + 1.
110+
111+
def time_double_add(self):
112+
np.add(self.d, 1., out=self.d)
113+
np.add(1., self.d, out=self.d)
114+
115+
def time_double_add_temp(self):
116+
1. + self.d + 1.
117+
118+
80119
class CustomScalar(Benchmark):
81120
params = [np.float32, np.float64]
82121
param_names = ['dtype']

numpy/core/src/umath/loops.c.src

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,25 @@
8787
* combine with NPY_GCC_OPT_3 to allow autovectorization
8888
* should only be used where its worthwhile to avoid code bloat
8989
*/
90+
#define BASE_UNARY_LOOP(tin, tout, op) \
91+
UNARY_LOOP { \
92+
const tin in = *(tin *)ip1; \
93+
tout * out = (tout *)op1; \
94+
op; \
95+
}
9096
#define UNARY_LOOP_FAST(tin, tout, op) \
9197
do { \
9298
/* condition allows compiler to optimize the generic macro */ \
9399
if (IS_UNARY_CONT(tin, tout)) { \
94-
UNARY_LOOP { \
95-
const tin in = *(tin *)ip1; \
96-
tout * out = (tout *)op1; \
97-
op; \
100+
if (args[0] == args[1]) { \
101+
BASE_UNARY_LOOP(tin, tout, op) \
102+
} \
103+
else { \
104+
BASE_UNARY_LOOP(tin, tout, op) \
98105
} \
99106
} \
100107
else { \
101-
UNARY_LOOP { \
102-
const tin in = *(tin *)ip1; \
103-
tout * out = (tout *)op1; \
104-
op; \
105-
} \
108+
BASE_UNARY_LOOP(tin, tout, op) \
106109
} \
107110
} \
108111
while (0)
@@ -128,40 +131,52 @@
128131
* combine with NPY_GCC_OPT_3 to allow autovectorization
129132
* should only be used where its worthwhile to avoid code bloat
130133
*/
134+
#define BASE_BINARY_LOOP(tin, tout, op) \
135+
BINARY_LOOP { \
136+
const tin in1 = *(tin *)ip1; \
137+
const tin in2 = *(tin *)ip2; \
138+
tout * out = (tout *)op1; \
139+
op; \
140+
}
141+
#define BASE_BINARY_LOOP_S(tin, tout, cin, cinp, vin, vinp, op) \
142+
const tin cin = *(tin *)cinp; \
143+
BINARY_LOOP { \
144+
const tin vin = *(tin *)vinp; \
145+
tout * out = (tout *)op1; \
146+
op; \
147+
}
131148
#define BINARY_LOOP_FAST(tin, tout, op) \
132149
do { \
133150
/* condition allows compiler to optimize the generic macro */ \
134151
if (IS_BINARY_CONT(tin, tout)) { \
135-
BINARY_LOOP { \
136-
const tin in1 = *(tin *)ip1; \
137-
const tin in2 = *(tin *)ip2; \
138-
tout * out = (tout *)op1; \
139-
op; \
152+
if (args[2] == args[0]) { \
153+
BASE_BINARY_LOOP(tin, tout, op) \
154+
} \
155+
else if (args[2] == args[1]) { \
156+
BASE_BINARY_LOOP(tin, tout, op) \
157+
} \
158+
else { \
159+
BASE_BINARY_LOOP(tin, tout, op) \
140160
} \
141161
} \
142162
else if (IS_BINARY_CONT_S1(tin, tout)) { \
143-
const tin in1 = *(tin *)args[0]; \
144-
BINARY_LOOP { \
145-
const tin in2 = *(tin *)ip2; \
146-
tout * out = (tout *)op1; \
147-
op; \
163+
if (args[1] == args[2]) { \
164+
BASE_BINARY_LOOP_S(tin, tout, in1, args[0], in2, ip2, op) \
165+
} \
166+
else { \
167+
BASE_BINARY_LOOP_S(tin, tout, in1, args[0], in2, ip2, op) \
148168
} \
149169
} \
150170
else if (IS_BINARY_CONT_S2(tin, tout)) { \
151-
const tin in2 = *(tin *)args[1]; \
152-
BINARY_LOOP { \
153-
const tin in1 = *(tin *)ip1; \
154-
tout * out = (tout *)op1; \
155-
op; \
171+
if (args[0] == args[2]) { \
172+
BASE_BINARY_LOOP_S(tin, tout, in2, args[1], in1, ip1, op) \
156173
} \
174+
else { \
175+
BASE_BINARY_LOOP_S(tin, tout, in2, args[1], in1, ip1, op) \
176+
}\
157177
} \
158178
else { \
159-
BINARY_LOOP { \
160-
const tin in1 = *(tin *)ip1; \
161-
const tin in2 = *(tin *)ip2; \
162-
tout * out = (tout *)op1; \
163-
op; \
164-
} \
179+
BASE_BINARY_LOOP(tin, tout, op) \
165180
} \
166181
} \
167182
while (0)

numpy/core/tests/test_scalarmath.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ class TestBaseMath(TestCase):
6565
def test_blocked(self):
6666
# test alignments offsets for simd instructions
6767
# alignments for vz + 2 * (vs - 1) + 1
68-
for dt, sz in [(np.float32, 11), (np.float64, 7)]:
68+
for dt, sz in [(np.float32, 11), (np.float64, 7), (np.int32, 11)]:
6969
for out, inp1, inp2, msg in _gen_alignment_data(dtype=dt,
7070
type='binary',
7171
max_size=sz):
7272
exp1 = np.ones_like(inp1)
7373
inp1[...] = np.ones_like(inp1)
7474
inp2[...] = np.zeros_like(inp2)
7575
assert_almost_equal(np.add(inp1, inp2), exp1, err_msg=msg)
76-
assert_almost_equal(np.add(inp1, 1), exp1 + 1, err_msg=msg)
76+
assert_almost_equal(np.add(inp1, 2), exp1 + 2, err_msg=msg)
7777
assert_almost_equal(np.add(1, inp2), exp1, err_msg=msg)
7878

7979
np.add(inp1, inp2, out=out)
@@ -82,15 +82,17 @@ def test_blocked(self):
8282
inp2[...] += np.arange(inp2.size, dtype=dt) + 1
8383
assert_almost_equal(np.square(inp2),
8484
np.multiply(inp2, inp2), err_msg=msg)
85-
assert_almost_equal(np.reciprocal(inp2),
86-
np.divide(1, inp2), err_msg=msg)
85+
# skip true divide for ints
86+
if dt != np.int32 or sys.version_info.major < 3:
87+
assert_almost_equal(np.reciprocal(inp2),
88+
np.divide(1, inp2), err_msg=msg)
8789

8890
inp1[...] = np.ones_like(inp1)
89-
inp2[...] = np.zeros_like(inp2)
90-
np.add(inp1, 1, out=out)
91-
assert_almost_equal(out, exp1 + 1, err_msg=msg)
92-
np.add(1, inp2, out=out)
93-
assert_almost_equal(out, exp1, err_msg=msg)
91+
np.add(inp1, 2, out=out)
92+
assert_almost_equal(out, exp1 + 2, err_msg=msg)
93+
inp2[...] = np.ones_like(inp2)
94+
np.add(2, inp2, out=out)
95+
assert_almost_equal(out, exp1 + 2, err_msg=msg)
9496

9597
def test_lower_align(self):
9698
# check data that is not aligned to element size

numpy/core/tests/test_umath.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,8 +1202,9 @@ def test_abs_neg_blocked(self):
12021202
assert_array_equal(out, d, err_msg=msg)
12031203

12041204
assert_array_equal(-inp, -1*inp, err_msg=msg)
1205+
d = -1 * inp
12051206
np.negative(inp, out=out)
1206-
assert_array_equal(out, -1*inp, err_msg=msg)
1207+
assert_array_equal(out, d, err_msg=msg)
12071208

12081209
def test_lower_align(self):
12091210
# check data that is not aligned to element size

numpy/testing/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,8 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
17941794
inp = lambda: arange(s, dtype=dtype)[o:]
17951795
out = empty((s,), dtype=dtype)[o:]
17961796
yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
1797-
yield inp(), inp(), ufmt % (o, o, s, dtype, 'in place')
1797+
d = inp()
1798+
yield d, d, ufmt % (o, o, s, dtype, 'in place')
17981799
yield out[1:], inp()[:-1], ufmt % \
17991800
(o + 1, o, s - 1, dtype, 'out of place')
18001801
yield out[:-1], inp()[1:], ufmt % \
@@ -1809,9 +1810,11 @@ def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
18091810
out = empty((s,), dtype=dtype)[o:]
18101811
yield out, inp1(), inp2(), bfmt % \
18111812
(o, o, o, s, dtype, 'out of place')
1812-
yield inp1(), inp1(), inp2(), bfmt % \
1813+
d = inp1()
1814+
yield d, d, inp2(), bfmt % \
18131815
(o, o, o, s, dtype, 'in place1')
1814-
yield inp2(), inp1(), inp2(), bfmt % \
1816+
d = inp2()
1817+
yield d, inp1(), d, bfmt % \
18151818
(o, o, o, s, dtype, 'in place2')
18161819
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
18171820
(o + 1, o, o, s - 1, dtype, 'out of place')

0 commit comments

Comments
 (0)
0