8000 [cpu][vec] support reduce ops for add and max (#144065) · pytorch/pytorch@a1ae8fa · GitHub
[go: up one dir, main page]

Skip to content

Commit a1ae8fa

Browse files
Valentine233 authored and pytorchmergebot committed
[cpu][vec] support reduce ops for add and max (#144065)
### Description

While adding INT8 SDPA support (pytorch/ao#1372), we found that `at::vec::vec_reduce_all<int32_t>` falls into the slow scalar path when computing sum and max. This change therefore adds the two reduction ops `reduce_add` and `reduce_max` for `vec512` and `vec256`, using the Sequence instructions.

### Details

- Support vectorized `reduce_add` and `reduce_max` for dtypes `int32` and `float32`, using the Sequence instructions;
- Implement the scalar version for the fallback path in vec base;
- Add the `reduce` operator in vec base, in order to simplify the code.

Pull Request resolved: #144065
Approved by: https://github.com/mingfeima
1 parent 55dc61d commit a1ae8fa

File tree

5 files changed

+95
-0
lines changed

5 files changed

+95
-0
lines changed

aten/src/ATen/cpu/vec/vec256/vec256_float.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,32 @@ template <> class Vectorized<float> {
380380
// Element-wise power: wraps the result of Sleef_powf8_u10(values, b) in a new
// Vectorized<float>, with `values` as the first argument and `b` as the second.
// NOTE(review): the "u10" suffix presumably selects Sleef's 1.0-ULP accuracy
// variant — confirm against the Sleef documentation.
Vectorized<float> pow(const Vectorized<float> &b) const {
381381
return Vectorized<float>(Sleef_powf8_u10(values, b));
382382
}
383+
float reduce_add() const {
384+
auto v = values;
385+
// 128-bit shuffle
386+
auto v1 = _mm256_permute2f128_ps(v, v, 0x1);
387+
v = _mm256_add_ps(v, v1);
388+
// 64-bit shuffle
389+
v1 = _mm256_shuffle_ps(v, v, 0x4E);
390+
v = _mm256_add_ps(v, v1);
391+
// 32-bit shuffle
392+
v1 = _mm256_shuffle_ps(v, v, 0xB1);
393+
v = _mm256_add_ps(v, v1);
394+
return _mm256_cvtss_f32(v);
395+
}
396+
float reduce_max() const {
397+
auto v = values;
398+
// 128-bit shuffle
399+
auto v1 = _mm256_permute2f128_ps(v, v, 0x1);
400+
v = _mm256_max_ps(v, v1);
401+
// 64-bit shuffle
402+
v1 = _mm256_shuffle_ps(v, v, 0x4E);
403+
v = _mm256_max_ps(v, v1);
404+
// 32-bit shuffle
405+
v1 = _mm256_shuffle_ps(v, v, 0xB1);
406+
v = _mm256_max_ps(v, v1);
407+
return _mm256_cvtss_f32(v);
408+
}
383409
// Comparison using the _CMP_**_OQ predicate.
384410
// `O`: get false if an operand is NaN
385411
// `Q`: do not raise if an operand is NaN

aten/src/ATen/cpu/vec/vec256/vec256_int.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,34 @@ class Vectorized<int32_t> : public Vectorizedi {
251251
return *this;
252252
}
253253
Vectorized<int32_t> neg() const;
254+
int32_t reduce_add() const {
255+
auto v = values;
256+
// 128-bit shuffle
257+
auto v1 = _mm256_permute2f128_si256(v, v, 0x1);
258+
v = _mm256_add_epi32(v, v1);
259+
// 64-bit shuffle
260+
v1 = _mm256_shuffle_epi32(v, 0x4E);
261+
v = _mm256_add_epi32(v, v1);
262+
// 32-bit shuffle
263+
v1 = _mm256_shuffle_epi32(v, 0xB1);
264+
v = _mm256_add_epi32(v, v1);
265+
__m128i lo = _mm256_castsi256_si128(v);
266+
return _mm_cvtsi128_si32(lo);
267+
}
268+
int32_t reduce_max() const {
269+
auto v = values;
270+
// 128-bit shuffle
271+
auto v1 = _mm256_permute2f128_si256(v, v, 0x1);
272+
v = _mm256_max_epi32(v, v1);
273+
// 64-bit shuffle
274+
v1 = _mm256_shuffle_epi32(v, 0x4E);
275+
v = _mm256_max_epi32(v, v1);
276+
// 32-bit shuffle
277+
v1 = _mm256_shuffle_epi32(v, 0xB1);
278+
v = _mm256_max_epi32(v, v1);
279+
__m128i lo = _mm256_castsi256_si128(v);
280+
return _mm_cvtsi128_si32(lo);
281+
}
254282
// Lane-wise equality: returns the _mm256_cmpeq_epi32 comparison of `values`
// against `other.values`, implicitly converted to Vectorized<int32_t>.
// Per the intrinsic's documentation, lanes that compare equal are all-ones
// and all other lanes are zero.
Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const {
255283
return _mm256_cmpeq_epi32(values, other.values);
256284
}

aten/src/ATen/cpu/vec/vec512/vec512_float.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,12 @@ template <> class Vectorized<float> {
403403
// Element-wise power: wraps the result of Sleef_powf16_u10(values, b) in a new
// Vectorized<float>, with `values` as the first argument and `b` as the second.
// NOTE(review): the "u10" suffix presumably selects Sleef's 1.0-ULP accuracy
// variant — confirm against the Sleef documentation.
Vectorized<float> pow(const Vectorized<float> &b) const {
404404
return Vectorized<float>(Sleef_powf16_u10(values, b));
405405
}
406+
// Horizontal sum: returns the scalar produced by _mm512_reduce_add_ps(values),
// i.e. the sum of the packed float lanes in `values`.
float reduce_add() const {
407+
return _mm512_reduce_add_ps(values);
408+
}
409+
// Horizontal maximum: returns the scalar produced by
// _mm512_reduce_max_ps(values), i.e. the largest packed float lane in `values`.
float reduce_max() const {
410+
return _mm512_reduce_max_ps(values);
411+
}
406412
// Comparison using the _CMP_**_OQ predicate.
407413
// `O`: get false if an operand is NaN
408414
// `Q`: do not raise if an operand is NaN

aten/src/ATen/cpu/vec/vec512/vec512_int.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,12 @@ class Vectorized<int32_t> : public Vectorizedi {
277277
return *this;
278278
}
279279
Vectorized<int32_t> neg() const;
280+
// Horizontal sum: returns the scalar produced by
// _mm512_reduce_add_epi32(values), i.e. the sum of the packed 32-bit integer
// lanes in `values`.
int32_t reduce_add() const {
281+
return _mm512_reduce_add_epi32(values);
282+
}
283+
// Horizontal maximum: returns the scalar produced by
// _mm512_reduce_max_epi32(values), i.e. the largest packed signed 32-bit
// integer lane in `values`.
int32_t reduce_max() const {
284+
return _mm512_reduce_max_epi32(values);
285+
}
280286
Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const {
281287
auto mask = _mm512_cmpeq_epi32_mask(values, other.values);
282288
return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF);

aten/src/ATen/cpu/vec/vec_base.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,15 @@ struct Vectorized {
294294
}
295295
return ret;
296296
}
297+
// Scalar fallback reduction: folds every lane of `values` into a single
// scalar by repeatedly applying the binary function `f`.
//
// @param f  two-argument combiner, called as f(accumulator, lane).
// @return   the accumulated value after all lanes have been combined.
//
// The accumulator starts from the first lane rather than 0, so reductions
// whose identity element is not 0 (e.g. max over all-negative lanes) are
// correct. Assumes size() >= 1 — TODO confirm for every instantiation.
//
// NOTE(review): the loop consumes two lanes per iteration. That shape is kept
// because it is presumably deliberate for the toolchain selected by the
// enclosing #if (condition not visible here) — confirm before simplifying.
T reduce(T (*const f)(T, T)) const {
  T ret = values[0];
  for (int64_t i = 1; i < size(); i++) {
    ret = f(ret, values[i]);
    if (++i < size())
      ret = f(ret, values[i]);
  }
  return ret;
}
297306
#else
298307
Vectorized<T> map(T (*const f)(T)) const {
299308
Vectorized<T> ret;
@@ -302,6 +311,13 @@ struct Vectorized {
302311
}
303312
return ret;
304313
}
314+
// Scalar fallback reduction: folds every lane of `values` into a single
// scalar by repeatedly applying the binary function `f`.
//
// @param f  two-argument combiner, called as f(accumulator, lane).
// @return   the accumulated value after all lanes have been combined.
//
// The accumulator starts from the first lane rather than 0, so reductions
// whose identity element is not 0 (e.g. max over all-negative lanes) are
// correct. Assumes size() >= 1 — TODO confirm for every instantiation.
T reduce(T (*const f)(T, T)) const {
  T ret = values[0];
  for (int64_t i = 1; i != size(); i++) {
    ret = f(ret, values[i]);
  }
  return ret;
}
305321
#endif
306322
Vectorized<T> map(T (*const f)(const T &)) const {
307323
Vectorized<T> ret;
@@ -310,6 +326,13 @@ struct Vectorized {
310326
}
311327
return ret;
312328
}
329+
// Scalar fallback reduction for combiners that take their operands by const
// reference: folds every lane of `values` into a single scalar by repeatedly
// applying the binary function `f`.
//
// @param f  two-argument combiner, called as f(accumulator, lane).
// @return   the accumulated value after all lanes have been combined.
//
// The accumulator starts from the first lane rather than 0, so reductions
// whose identity element is not 0 (e.g. max over all-negative lanes) are
// correct. Assumes size() >= 1 — TODO confirm for every instantiation.
T reduce(T (*const f)(const T &, const T &)) const {
  T ret = values[0];
  for (int64_t i = 1; i != size(); i++) {
    ret = f(ret, values[i]);
  }
  return ret;
}
313336
template <typename other_t_abs = T,
314337
typename std::enable_if_t<!is_floating_point_v<other_t_abs> && !c10::is_complex<other_t_abs>::value, int> = 0>
315338
Vectorized<T> abs() const {
@@ -585,6 +608,12 @@ struct Vectorized {
585608
}
586609
return ret;
587610
}
611+
// Sum of all lanes, computed by handing a capture-less addition lambda to the
// generic reduce() helper.
T reduce_add() const {
  const auto add_lanes = [](T x, T y) -> T { return x + y; };
  return reduce(add_lanes);
}
614+
// Largest lane, computed by handing a capture-less max lambda to the generic
// reduce() helper.
//
// A lambda is used instead of the bare name `std::max`: `std::max` is an
// overloaded function template whose instances return `const T&`, so the name
// cannot be converted to a function pointer that returns `T` by value.
//
// NOTE(review): the result is only correct if reduce() does not start from a
// seed that can exceed the lanes (e.g. 0 with all-negative input) — confirm.
T reduce_max() const {
  return reduce([](T x, T y) -> T { return std::max(x, y); });
}
588617
private:
589618
template <typename Op>
590619
inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {

0 commit comments

Comments
 (0)
0