Rebase torch.compile SVE logic on torch/main · pytorch/pytorch@4f7aafa · GitHub
[go: up one dir, main page]

Skip to content

Commit 4f7aafa

Browse files
committed
Rebase torch.compile SVE logic on torch/main
Change-Id: I549769e756603280e6866224563de2b256266f9a
1 parent b85f21f commit 4f7aafa

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,30 @@ struct VecReduceAllSIMD<float, Op> {
107107
};
108108
#endif // defined(__aarch64__)
109109

110+
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)
111+
template <typename Op>
112+
struct VecReduceAllSIMD<float, Op> {
113+
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
114+
using Vec = Vectorized<float>;
115+
Vec v = acc_vec;
116+
// 128-bit shuffle
117+
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
118+
Vec v1 = svtbl_f32(v, ind);
119+
v = vec_fun(v, v1);
120+
// 64-bit shuffle
121+
ind = svdupq_n_u32(2, 3, 0, 1);
122+
v1 = svtbl_f32(v, ind);
123+
v = vec_fun(v, v1);
124+
// 32-bit shuffle
125+
ind = svdupq_n_u32(1, 0, 2, 3);
126+
v1 = svtbl_f32(v, ind);
127+
v = vec_fun(v, v1);
128+
return svlasta(svpfalse(), v);
129+
}
130+
};
131+
#endif // defined(__aarch64__)
132+
133+
110134
template <typename scalar_t, typename Op>
111135
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
112136
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);

torch/_inductor/codegen/cpp_prefix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include <c10/util/TypeCast.h>
2929
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
3030

31-
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
31+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)
3232
#define INDUCTOR_USE_VECTOR_TYPES() 1
3333
#else
3434
#define INDUCTOR_USE_VECTOR_TYPES() 0

torch/_inductor/cpu_vec_isa.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class VecISA:
5252
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
5353
# making the runtime check unnecessary.
5454
_avx_code = """
55-
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
55+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE)
5656
#include <ATen/cpu/vec/functional.h>
5757
#include <ATen/cpu/vec/vec.h>
5858
#endif
@@ -160,6 +160,20 @@ def __str__(self) -> str:
160160

161161
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
162162

163+
@dataclasses.dataclass
164+
class VecSVE(VecISA):
165+
# this function can be repurposed for SVE with variable vec length
166+
_bit_width = 256
167+
_macro = ["CPU_CAPABILITY_SVE", "CPU_CAPABILITY_SVE256"]
168+
169+
_arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"
170+
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
171+
172+
def __str__(self) -> str:
173+
return "asimd"
174+
175+
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
176+
163177

164178
@dataclasses.dataclass
165179
class VecAVX512(VecISA):
@@ -304,9 +318,24 @@ def _check_and_append_supported_isa(
304318

305319
return supported_isa
306320

321+
@functools.lru_cache(maxsize=None)
322+
def _is_arm_neoverse_v1() -> bool:
323+
# reference: https://github.com/ARM-software/ComputeLibrary/blob/main/src/common/cpuinfo/CpuModel.cpp
324+
try:
325+
with open("/proc/cpuinfo") as _cpuinfo:
326+
line = _cpuinfo.readline()
327+
while line:
328+
is_v1 = re.match(r"^CPU\spart(.*):\s0xd40(\n)$", line)
329+
if is_v1:
330+
return True
331+
line = _cpuinfo.readline()
332+
return False
333+
except:
334+
return False
335+
307336

308337
invalid_vec_isa = InvalidVecISA()
309-
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
338+
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
310339

311340

312341
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
@@ -338,7 +367,10 @@ def valid_vec_isa_list() -> List[VecISA]:
338367
elif arch == "ppc64le":
339368
isa_list.append(VecVSX())
340369
elif arch == "aarch64":
341-
isa_list.append(VecNEON())
370+
if _is_arm_neoverse_v1():
371+
isa_list.append(VecSVE())
372+
else:
373+
isa_list.append(VecNEON())
342374
elif arch in ["x86_64", "AMD64"]:
343375
"""
344376
arch value is x86_64 on Linux, and the value is AMD64 on Windows.

0 commit comments

Comments
 (0)
0