8000 Fix error when ssrc less than 0 · numpy/numpy@db6b36c · GitHub
[go: up one dir, main page]

Skip to content

Commit db6b36c

Browse files
committed
Fix error when ssrc less than 0
1 parent 4154cf1 commit db6b36c

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

numpy/_core/src/umath/loops_unary_fp.dispatch.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
#include "fast_loop_macros.h"
1212
#include "loops_utils.h"
1313
#include <hwy/highway.h>
14-
#include "hwy/print-inl.h"
1514
#include <hwy/aligned_allocator.h>
1615

16+
1717
namespace hn = hwy::HWY_NAMESPACE;
1818
HWY_BEFORE_NAMESPACE();
1919
template <typename T>
@@ -193,7 +193,16 @@ HWY_INLINE void Super(char** args,
193193
auto store_index = hn::Mul(hn::Iota(di, 0), hn::Set(di, sdst));
194194
size_t full = size & -hn::Lanes(d);
195195
size_t remainder = size - full;
196-
if (sdst == 1 && ssrc != 1) {
196+
if (ssrc < 0) {
197+
auto load_index_reverse =
198+
hn::Reverse(di, hn::Add(hn::Mul(hn::Iota(di, 0), hn::Set(di, -ssrc)), hn::Set(di, -ssrc)));
199+
for (size_t i = 0; i < full; i += hn::Lanes(d)) {
200+
const auto in = hn::GatherIndex(d, input_array + (i + hn::Lanes(d)) * ssrc,
201+
load_index_reverse);
202+
auto x = op(in);
203+
hn::ScatterIndex(x, d, output_array + i * sdst, store_index);
204+
}
205+
} else if (sdst == 1 && ssrc != 1) {
197206
for (size_t i = 0; i < full; i += hn::Lanes(d)) {
198207
const auto in = hn::GatherIndex(d, input_array + i * ssrc, load_index);
199208
auto x = op(in);

0 commit comments

Comments
 (0)
0