8000 Merge pull request #28294 from abhishek-iitmadras/abhishek_sve_tanh · numpy/numpy@74f6905 · GitHub
[go: up one dir, main page]

Skip to content

Commit 74f6905

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #28294 from abhishek-iitmadras/abhishek_sve_tanh
MAINT: Enable building tanh on vector length agnostic architectures
2 parents 4e37e74 + b46c458 commit 74f6905

File tree

1 file changed

+39
-42
lines changed

1 file changed

+39
-42
lines changed

numpy/_core/src/umath/loops_hyperbolic.dispatch.cpp.src

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ store_vector(vtype vec, type_t* dst, npy_intp sdst, npy_intp len){
152152
#if NPY_SIMD_F64
153153

154154
[[maybe_unused]] HWY_ATTR NPY_FINLINE vec_f64 lut_16_f64(const double * lut, vec_u64 idx){
155-
if constexpr(hn::Lanes(f64) == 8){
155+
if constexpr(hn::MaxLanes(f64) == 8){
156156
const vec_f64 lut0 = hn::Load(f64, lut);
157157
const vec_f64 lut1 = hn::Load(f64, lut + 8);
158158
return hn::TwoTablesLookupLanes(f64, lut0, lut1, hn::IndicesFromVec(f64, idx));
159-
}else if constexpr (hn::Lanes(f64) == 4){
159+
}else if constexpr (hn::MaxLanes(f64) == 4){
160160
const vec_f64 lut0 = hn::Load(f64, lut);
161161
const vec_f64 lut1 = hn::Load(f64, lut + 4);
162162
const vec_f64 lut2 = hn::Load(f64, lut + 8);
@@ -392,22 +392,19 @@ simd_tanh_f64(const double *src, npy_intp ssrc, double *dst, npy_intp sdst, npy_
392392
// implemented so we require `npyv_nlanes_f64` == 2.
393393
vec_f64 b, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16;
394394
if constexpr(hn::MaxLanes(f64) == 2){
395-
vec_f64 e0e1[hn::Lanes(f64)];
395+
vec_f64 e0e1_0, e0e1_1;
396396
uint64_t index[hn::Lanes(f64)];
397397
hn::StoreU(idx, u64, index);
398398

399399
/**begin repeat
400-
* #off= 0, 2, 4, 6, 8, 10, 12, 14, 16#
401-
* #e0 = b, c1, c3, c5, c7, c9, c11, c13, c15#
402-
* #e1 = c0,c2, c4, c6, c8, c10,c12, c14, c16#
403-
*/
404-
/**begin repeat1
405-
* #lane = 0, 1#
406-
*/
407-
e0e1[@lane@] = hn::LoadU(f64, (const double*)lut18x16 + index[@lane@] * 18 + @off@);
408-
/**end repeat1**/
409-
@e0@ = hn::ConcatLowerLower(f64, e0e1[1], e0e1[0]);
410-
@e1@ = hn::ConcatUpperUpper(f64, e0e1[1], e0e1[0]);
400+
* #off = 0, 2, 4, 6, 8, 10, 12, 14, 16#
401+
* #e0 = b, c1, c3, c5, c7, c9, c11,c13,c15#
402+
* #e1 = c0, c2, c4, c6, c8, c10,c12,c14,c16#
403+
*/
404+
e0e1_0 = hn::LoadU(f64, (const double*)lut18x16 + index[0] * 18 + @off@);
405+
e0e1_1 = hn::LoadU(f64, (const double*)lut18x16 + index[1] * 18 + @off@);
406+
@e0@ = hn::ConcatLowerLower(f64, e0e1_1, e0e1_0);
407+
@e1@ = hn::ConcatUpperUpper(f64, e0e1_1, e0e1_0);
411408
/**end repeat**/
412409
} else {
413410
b = lut_16_f64((const double*)lut16x18 + 16*0, idx);
@@ -464,23 +461,17 @@ simd_tanh_f64(const double *src, npy_intp ssrc, double *dst, npy_intp sdst, npy_
464461

465462
#if NPY_SIMD_F32
466463

467-
struct hwy_f32x2 {
468-
vec_f32 val[2];
469-
};
470-
471-
HWY_ATTR NPY_FINLINE hwy_f32x2 zip_f32(vec_f32 a, vec_f32 b){
472-
hwy_f32x2 res;
473-
res.val[0] = hn::InterleaveLower(f32, a, b);
474-
res.val[1] = hn::InterleaveUpper(f32, a, b);
475-
return res;
464+
HWY_ATTR NPY_FINLINE void zip_f32_lanes(vec_f32 a, vec_f32 b, vec_f32& lower, vec_f32& upper) {
465+
lower = hn::InterleaveLower(f32, a, b);
466+
upper = hn::InterleaveUpper(f32, a, b);
476467
}
477468

478469
[[maybe_unused]] HWY_ATTR NPY_FINLINE vec_f32 lut_32_f32(const float * lut, vec_u32 idx){
479-
if constexpr(hn::Lanes(f32) == 16){
470+
if constexpr(hn::MaxLanes(f32) == 16){
480471
const vec_f32 lut0 = hn::Load(f32, lut);
481472
const vec_f32 lut1 = hn::Load(f32, lut + 16);
482473
return hn::TwoTablesLookupLanes(f32, lut0, lut1, hn::IndicesFromVec(f32, idx));
483-
}else if constexpr (hn::Lanes(f32) == 8){
474+
}else if constexpr (hn::MaxLanes(f32) == 8){
484475
const vec_f32 lut0 = hn::Load(f32, lut);
485476
const vec_f32 lut1 = hn::Load(f32, lut + 8);
486477
const vec_f32 lut2 = hn::L 10000 oad(f32, lut + 16);
@@ -608,16 +599,16 @@ simd_tanh_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst, npy_in
608599
// supported so we require `npyv_nlanes_f32` == 4.
609600
vec_f32 b, c0, c1, c2, c3, c4, c5, c6;
610601
if constexpr(hn::MaxLanes(f32) == 4 && HWY_TARGET >= HWY_SSE4){
611-
vec_f32 c6543[npyv_nlanes_f32];
612-
vec_f32 c210b[npyv_nlanes_f32];
602+
vec_f32 c6543_0, c6543_1, c6543_2, c6543_3;
603+
vec_f32 c210b_0, c210b_1, c210b_2, c210b_3;
613604
npyv_lanetype_u32 index[npyv_nlanes_f32];
614605

615606
/**begin repeat
616607
* #lane = 0, 1, 2, 3#
617608
*/
618609
index[@lane@] = hn::ExtractLane(idx, @lane@);
619-
c6543[@lane@] = hn::LoadU(f32, (const float*)lut8x32 + index[@lane@] * 8);
620-
c210b[@lane@] = hn::LoadU(f32, (const float*)lut8x32 + index[@lane@] * 8 + 4);
610+
c6543_@lane@ = hn::LoadU(f32, (const float*)lut8x32 + index[@lane@] * 8);
611+
c210b_@lane@ = hn::LoadU(f32, (const float*)lut8x32 + index[@lane@] * 8 + 4);
621612
/**end repeat**/
622613

623614
// lane0: {c6, c5, c4, c3}, {c2, c1, c0, b}
@@ -635,19 +626,25 @@ simd_tanh_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst, npy_in
635626
// c0: {lane0, lane1, lane2, lane3}
636627
// b : {lane0, lane1, lane2, lane3}
637628

638-
hwy_f32x2 c6543_l01 = zip_f32(c6543[0], c6543[1]);
639-
hwy_f32x2 c6543_l23 = zip_f32(c6543[2], c6543[3]);
640-
c6 = hn::ConcatLowerLower(f32, c6543_l23.val[0], c6543_l01.val[0]);
641-
c5 = hn::ConcatUpperUpper(f32, c6543_l23.val[0], c6543_l01.val[0]);
642-
c4 = hn::ConcatLowerLower(f32, c6543_l23.val[1], c6543_l01.val[1]);
643-
c3 = hn::ConcatUpperUpper(f32, c6543_l23.val[1], c6543_l01.val[1]);
644-
645-
hwy_f32x2 c210b_l01 = zip_f32(c210b[0], c210b[1]);
646-
hwy_f32x2 c210b_l23 = zip_f32(c210b[2], c210b[3]);
647-
c2 = hn::ConcatLowerLower(f32, c210b_l23.val[0], c210b_l01.val[0]);
648-
c1 = hn::ConcatUpperUpper(f32, c210b_l23.val[0], c210b_l01.val[0]);
649-
c0 = hn::ConcatLowerLower(f32, c210b_l23.val[1], c210b_l01.val[1]);
650-
b = hn::ConcatUpperUpper(f32, c210b_l23.val[1], c210b_l01.val[1]);
629+
vec_f32 c6543_l01_low, c6543_l01_high;
630+
vec_f32 c6543_l23_low, c6543_l23_high;
631+
zip_f32_lanes(c6543_0, c6543_1, c6543_l01_low, c6543_l01_high);
632+
zip_f32_lanes(c6543_2, c6543_3, c6543_l23_low, c6543_l23_high);
633+
634+
c6 = hn::ConcatLowerLower(f32, c6543_l23_low, c6543_l01_low);
635+
c5 = hn::ConcatUpperUpper(f32, c6543_l23_low, c6543_l01_low);
636+
c4 = hn::ConcatLowerLower(f32, c6543_l23_high, c6543_l01_high);
637+
c3 = hn::ConcatUpperUpper(f32, c6543_l23_high, c6543_l01_high);
638+
639+
vec_f32 c210b_l01_low, c210b_l01_high;
640+
vec_f32 c210b_l23_low, c210b_l23_high;
641+
zip_f32_lanes(c210b_0, c210b_1, c210b_l01_low, c210b_l01_high);
642+
zip_f32_lanes(c210b_2, c210b_3, c210b_l23_low, c210b_l23_high);
643+
644+
c2 = hn::ConcatLowerLower(f32, c210b_l23_low, c210b_l01_low);
645+
c1 = hn::ConcatUpperUpper(f32, c210b_l23_low, c210b_l01_low);
646+
c0 = hn::ConcatLowerLower(f32, c210b_l23_high, c210b_l01_high);
647+
b = hn::ConcatUpperUpper(f32, c210b_l23_high, c210b_l01_high);
651648
} else {
652649
b = lut_32_f32((const float*)lut32x8 + 32*0, idx);
653650
c0 = lut_32_f32((const float*)lut32x8 + 32*1, idx);

0 commit comments

Comments
 (0)
0