Fix error when ssrc is less than 0 · numpy/numpy@b33d8dd · GitHub
[go: up one dir, main page]

Skip to content

Commit b33d8dd

Browse files
committed
Fix error when ssrc is less than 0
1 parent 4154cf1 commit b33d8dd

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

numpy/_core/src/umath/loops_unary_fp.dispatch.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111
#include "fast_loop_macros.h"
1212
#include "loops_utils.h"
1313
#include <hwy/highway.h>
14-
#include "hwy/print-inl.h"
1514
#include <hwy/aligned_allocator.h>
1615

17-
namespace hn = hwy::HWY_NAMESPACE;
18-
HWY_BEFORE_NAMESPACE();
1916
template <typename T>
2017
struct OpRound {
2118
HWY_INLINE hn::VFromD<hn::ScalableTag<T>> operator()(
@@ -193,7 +190,16 @@ HWY_INLINE void Super(char** args,
193190
auto store_index = hn::Mul(hn::Iota(di, 0), hn::Set(di, sdst));
194191
size_t full = size & -hn::Lanes(d);
195192
size_t remainder = size - full;
196-
if (sdst == 1 && ssrc != 1) {
193+
if (ssrc < 0) {
194+
auto load_index_reverse =
195+
hn::Reverse(di, hn::Add(hn::Mul(hn::Iota(di, 0), hn::Set(di, -ssrc)), hn::Set(di, -ssrc)));
196+
for (size_t i = 0; i < full; i += hn::Lanes(d)) {
197+
const auto in = hn::GatherIndex(d, input_array + (i + hn::Lanes(d)) * ssrc,
198+
load_index_reverse);
199+
auto x = op(in);
200+
hn::ScatterIndex(x, d, output_array + i * sdst, store_index);
201+
}
202+
} else if (sdst == 1 && ssrc != 1) {
197203
for (size_t i = 0; i < full; i += hn::Lanes(d)) {
198204
const auto in = hn::GatherIndex(d, input_array + i * ssrc, load_index);
199205
auto x = op(in);

0 commit comments

Comments
 (0)
0