[CpuInductor] Enable NEON ISA detection on Linux ARM (#129075) · pytorch/pytorch@b2a9b8d · GitHub
[go: up one dir, main page]

Skip to content

Commit b2a9b8d

Browse files
malfet authored and pytorchmergebot committed
[CpuInductor] Enable NEON ISA detection on Linux ARM (#129075)
Also, clean up the code a bit to use `x in [y, z]` instead of `x == y or x == z`, and do not redefine `__at_align__`, but instead use `alignas(64)`, as was suggested in https://github.com/pytorch/pytorch/pull/128686/files#r1639365978. Test plan: `python3 -c "import torch._inductor.codecache as cc; isa = cc.valid_vec_isa_list()[0];print(str(isa), bool(isa))"`. Pull Request resolved: #129075. Approved by: https://github.com/jansel
1 parent e0aa992 commit b2a9b8d

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

torch/_inductor/codecache.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,18 +1306,7 @@ class VecISA:
13061306
#include <ATen/cpu/vec/vec.h>
13071307
#endif
13081308
1309-
#ifdef __APPLE__
1310-
// Fix Mac OS UT failed.
1311-
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
1312-
#else
1313-
#if defined(_WIN32)
1314-
#define __at_align__ __declspec(align(64))
1315-
#else
1316-
#define __at_align__ __attribute__((aligned(64)))
1317-
#endif
1318-
1319-
__at_align__ float in_out_ptr0[16] = {0.0};
1320-
#endif
1309+
alignas(64) float in_out_ptr0[16] = {0.0};
13211310
13221311
extern "C" void __avx_chk_kernel() {
13231312
auto tmp0 = at::vec::Vectorized<float>(1);
@@ -1510,12 +1499,11 @@ def _check_and_append_supported_isa(
15101499
# we only cache some key isa information.
15111500
@functools.lru_cache(None)
15121501
def valid_vec_isa_list() -> List[VecISA]:
1502+
isa_list: List[VecISA] = []
15131503
if sys.platform == "darwin" and platform.processor() == "arm":
1514-
return [VecNEON()]
1504+
isa_list.append(VecNEON())
15151505

1516-
isa_list: List[VecISA] = []
1517-
cur_os = sys.platform
1518-
if cur_os != "linux" and cur_os != "win32":
1506+
if sys.platform not in ["linux", "win32"]:
15191507
return isa_list
15201508

15211509
arch = platform.machine()
@@ -1532,17 +1520,16 @@ def valid_vec_isa_list() -> List[VecISA]:
15321520
if re.search(r"[\^ ]+vxe[\$ ]+", group):
15331521
isa_list.append(VecZVECTOR())
15341522
break
1535-
return isa_list
1536-
1537-
if arch == "x86_64" or arch == "AMD64":
1523+
elif arch == "aarch64":
1524+
isa_list.append(VecNEON())
1525+
elif arch in ["x86_64", "AMD64"]:
15381526
"""
15391527
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
15401528
"""
15411529
_cpu_supported_x86_isa = x86_isa_checker()
15421530
for isa in supported_vec_isa_list:
15431531
if str(isa) in _cpu_supported_x86_isa and isa:
15441532
isa_list.append(isa)
1545-
return isa_list
15461533

15471534
return isa_list
15481535

0 commit comments

Comments
 (0)
0