@@ -52,7 +52,7 @@ class VecISA:
52
52
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
53
53
# making the runtime check unnecessary.
54
54
_avx_code = """
55
- #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
55
+ #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE)
56
56
#include <ATen/cpu/vec/functional.h>
57
57
#include <ATen/cpu/vec/vec.h>
58
58
#endif
@@ -160,6 +160,20 @@ def __str__(self) -> str:
160
160
161
161
__hash__ : Callable [[VecISA ], Any ] = VecISA .__hash__
162
162
163
@dataclasses.dataclass
class VecSVE(VecISA):
    """Vector ISA descriptor for Arm SVE compiled with a fixed 256-bit vector length."""

    # NOTE: this class can be repurposed for SVE with variable vec length.
    # The attributes below are deliberately left unannotated: the dataclass
    # decorator only turns *annotated* class attributes into fields.

    # Compiler flags enabling SVE and pinning the vector length to 256 bits.
    _arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"

    # Vector register width, in bits.
    _bit_width = 256

    # Preprocessor macros associated with this ISA.
    _macro = [
        "CPU_CAPABILITY_SVE",
        "CPU_CAPABILITY_SVE256",
    ]

    # Elements per 256-bit vector for each dtype (256 / element bit size).
    _dtype_nelements = {
        torch.float: 8,
        torch.bfloat16: 16,
        torch.float16: 16,
    }

    def __str__(self) -> str:
        """Return the string identifier of this ISA."""
        # NOTE(review): "asimd" is the Advanced SIMD (NEON) name rather than an
        # SVE-specific one -- presumably kept to match how callers look the ISA
        # up; confirm against the users of str(isa).
        return "asimd"

    # A dataclass with eq=True sets __hash__ to None unless the class defines
    # it explicitly, so re-bind the base-class implementation to stay hashable.
    __hash__: Callable[[VecISA], Any] = VecISA.__hash__
176
+
163
177
164
178
@dataclasses .dataclass
165
179
class VecAVX512 (VecISA ):
@@ -304,9 +318,24 @@ def _check_and_append_supported_isa(
304
318
305
319
return supported_isa
306
320
321
@functools.lru_cache(maxsize=None)
def _is_arm_neoverse_v1(cpuinfo_path: str = "/proc/cpuinfo") -> bool:
    """Report whether the CPU description file lists an Arm Neoverse V1 core.

    The file is scanned line by line for a ``CPU part`` entry whose value is
    ``0xd40``. The result is memoized per ``cpuinfo_path`` so the file is read
    at most once for each distinct path.

    Args:
        cpuinfo_path: Path of the cpuinfo-style text file to inspect. Defaults
            to ``/proc/cpuinfo``; overriding it is mainly useful for tests.

    Returns:
        True if a matching ``CPU part`` line is found. False if no line
        matches, or if the file cannot be opened or read (detection is
        best-effort and never raises an ordinary exception).
    """
    # reference: https://github.com/ARM-software/ComputeLibrary/blob/main/src/common/cpuinfo/CpuModel.cpp
    try:
        with open(cpuinfo_path) as cpuinfo:
            # any() stops at the first matching line, like an early return.
            return any(
                re.match(r"^CPU\spart(.*):\s0xd40(\n)$", line) is not None
                for line in cpuinfo
            )
    except Exception:
        # Missing or unreadable file (e.g. non-Linux host): treat as "not V1".
        # Unlike a bare ``except``, this lets KeyboardInterrupt and SystemExit
        # propagate.
        return False
335
+
307
336
308
337
invalid_vec_isa = InvalidVecISA ()
309
- supported_vec_isa_list = [VecAMX (), VecAVX512 (), VecAVX2 (), VecNEON ()]
338
+ supported_vec_isa_list = [VecAMX (), VecAVX512 (), VecAVX2 (), VecNEON (), VecSVE () ]
310
339
311
340
312
341
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
@@ -338,7 +367,10 @@ def valid_vec_isa_list() -> List[VecISA]:
338
367
elif arch == "ppc64le" :
339
368
isa_list .append (VecVSX ())
340
369
elif arch == "aarch64" :
341
- isa_list .append (VecNEON ())
370
+ if _is_arm_neoverse_v1 ():
371
+ isa_list .append (VecSVE ())
372
+ else :
373
+ isa_list .append (VecNEON ())
342
374
elif arch in ["x86_64" , "AMD64" ]:
343
375
"""
344
376
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
0 commit comments