8000 add is_vec_specialized_for (#152365) · pytorch/pytorch@a2e2f90 · GitHub
[go: up one dir, main page]

Skip to content

Commit a2e2f90

Browse files
swolchokpytorchmergebot
authored andcommitted
add is_vec_specialized_for (#152365)
Let people detect at compile time whether Vectorized is specialized for a given type. See vec_base.h. Differential Revision: [D73802129](https://our.internmc.facebook.com/intern/diff/D73802129/) Pull Request resolved: #152365 Approved by: https://github.com/jgong5, https://github.com/malfet
1 parent ae0e8f0 commit a2e2f90

35 files changed

+198
-1
lines changed

aten/src/ATen/cpu/vec/sve/vec_bfloat16.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ inline namespace CPU_CAPABILITY {
2020

2121
#if defined(CPU_CAPABILITY_SVE256) && defined(__ARM_FEATURE_BF16)
2222

23+
template <>
24+
struct is_vec_specialized_for<BFloat16> : std::bool_constant<true> {};
25+
2326
template <>
2427
class Vectorized<BFloat16> {
2528
private:

aten/src/ATen/cpu/vec/sve/vec_double.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ inline namespace CPU_CAPABILITY {
2424

2525
#if defined(CPU_CAPABILITY_SVE)
2626

27+
template <>
28+
struct is_vec_specialized_for<double> : std::bool_constant<true> {};
29+
2730
template <>
2831
class Vectorized<double> {
2932
private:

aten/src/ATen/cpu/vec/sve/vec_float.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ inline namespace CPU_CAPABILITY {
2424

2525
#if defined(CPU_CAPABILITY_SVE)
2626

27+
template <>
28+
struct is_vec_specialized_for<float> : std::bool_constant<true> {};
29+
2730
template <>
2831
class Vectorized<float> {
2932
private:

aten/src/ATen/cpu/vec/sve/vec_int.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ inline namespace CPU_CAPABILITY {
1919

2020
#define VEC_INT_SVE_TEMPLATE(vl, bit) \
2121
template <> \
22+
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
23+
\
24+
template <> \
2225
class Vectorized<int##bit##_t> { \
2326
private: \
2427
vls_int##bit##_t values; \

aten/src/ATen/cpu/vec/sve/vec_qint.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ struct VectorizedQuantizedConverter {
142142
VectorizedQuantizedConverter() {}
143143
};
144144

145+
template <>
146+
struct is_vec_specialized_for<c10::qint32> : std::bool_constant<true> {};
147+
145148
template <>
146149
struct Vectorized<c10::qint32> : public VectorizedQuantizedConverter<
147150
c10::qint32,
@@ -302,6 +305,9 @@ Vectorized<c10::qint32> inline operator+(
302305
return retval;
303306
}
304307

308+
template <>
309+
struct is_vec_specialized_for<c10::qint8> : std::bool_constant<true> {};
310+
305311
template <>
306312
struct Vectorized<c10::qint8> : public VectorizedQuantizedConverter<
307313
c10::qint8,
@@ -442,6 +448,9 @@ Vectorized<c10::qint8> inline maximum(
442448
return a.maximum(b);
443449
}
444450

451+
template <>
452+
struct is_vec_specialized_for<c10::quint8> : std::bool_constant<true> {};
453+
445454
template <>
446455
struct Vectorized<c10::quint8> : public VectorizedQuantizedConverter<
447456
c10::quint8,

aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ struct BlendBFloat16Regs<index, false> {
128128
}
129129
};
130130

131+
template <>
132+
struct is_vec_specialized_for<c10::BFloat16> : std::bool_constant<true> {};
133+
131134
template <>
132135
class Vectorized<c10::BFloat16> : public Vectorized16<
133136
at_bfloat16x8_t,

aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ struct BlendRegs<index, false> {
6969
}
7070
};
7171

72+
template <>
73+
struct is_vec_specialized_for<float> : std::bool_constant<true> {};
74+
7275
template <>
7376
class Vectorized<float> {
7477
private:

aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ struct BlendHalfRegs<index, false> {
5858
}
5959
};
6060

61+
template <>
62+
struct is_vec_specialized_for<c10::Half> : std::bool_constant<true> {};
63+
6164
// On ARM, Half type supports float16_t->Half constructor and Half->float16_t
6265
// conversion
6366
template <>

aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ inline namespace CPU_CAPABILITY {
1212

1313
#if defined(CPU_CAPABILITY_AVX2)
1414

15+
template <>
16+
struct is_vec_specialized_for<BFloat16> : std::bool_constant<true> {};
17+
1518
template <>
1619
class Vectorized<BFloat16> : public Vectorized16<BFloat16> {
1720
public:

aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ inline namespace CPU_CAPABILITY {
1919

2020
#if defined(CPU_CAPABILITY_AVX2)
2121

22+
template <>
23+
struct is_vec_specialized_for<c10::complex<double>> : std::bool_constant<true> {
24+
};
25+
2226
template <>
2327
class Vectorized<c10::complex<double>> {
2428
private:

0 commit comments

Comments
 (0)
0