8000 Set simdlen based on ATEN_CPU_CAPABILITY (#123514) · pytorch/pytorch@b66e3f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit b66e3f0

Browse files
CaoEpytorchmergebot
authored andcommitted
Set simdlen based on ATEN_CPU_CAPABILITY (#123514)
It is part of #123224. Set simdlen based on the environment ATEN_CPU_CAPABILITY to control CPU vec ISA like eager. Pull Request resolved: #123514 Approved by: https://github.com/jgong5, https://github.com/peterbell10
1 parent df43d58 commit b66e3f0

File tree

6 files changed

+205
-19
lines changed

6 files changed

+205
-19
lines changed

test/inductor/test_cpu_repro.py

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import itertools
66
import math
7+
import os
78
import platform
89
import sys
910
import unittest
@@ -66,12 +67,13 @@
6667
check_model = test_torchinductor.check_model
6768

6869
requires_vectorization = unittest.skipUnless(
69-
codecache.valid_vec_isa_list(), "Does not support vectorization"
70+
codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default",
71+
"Does not support vectorization",
7072
)
7173

7274

7375
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
74-
if codecache.valid_vec_isa_list():
76+
if codecache.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default":
7577
assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
7678

7779

@@ -1580,6 +1582,71 @@ def fn(x):
15801582
metrics.reset()
15811583
self.common(fn, (value,))
15821584

1585+
@unittest.skipIf(
1586+
not codecache.valid_vec_isa_list()
1587+
or "avx2" in [str(vec_isa) for vec_isa in codecache.valid_vec_isa_list()],
1588+
"Does not support vectorization or not s390x/neon machine",
1589+
)
1590+
@patch("torch.cuda.is_available", lambda: False)
1591+
def test_auto_zvec_neon_simd(self):
1592+
vec_zvec_neon = codecache.valid_vec_isa_list()[0]
1593+
self.assertTrue(vec_zvec_neon.bit_width() == 256)
1594+
1595+
with config.patch({"cpp.simdlen": 0}):
1596+
isa = codecache.pick_vec_isa()
1597+
self.assertFalse(isa)
1598+
1599+
with config.patch({"cpp.simdlen": 1}):
1600+
isa = codecache.pick_vec_isa()
1601+
self.assertFalse(isa)
1602+
1603+
with config.patch({"cpp.simdlen": 257}):
1604+
isa = codecache.pick_vec_isa()
1605+
self.assertFalse(isa)
1606+
1607+
with config.patch({"cpp.simdlen": 256}):
1608+
isa = codecache.pick_vec_isa()
1609+
self.assertTrue(isa == vec_zvec_neon)
1610+
1611+
pre_var = os.getenv("ATEN_CPU_CAPABILITY")
1612+
if pre_var:
1613+
os.environ.pop("ATEN_CPU_CAPABILITY")
1614+
1615+
try:
1616+
with config.patch({"cpp.simdlen": None}):
1617+
isa = codecache.pick_vec_isa()
1618+
self.assertTrue(isa == vec_zvec_neon)
1619+
1620+
with config.patch({"cpp.simdlen": None}):
1621+
os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
1622+
isa = codecache.pick_vec_isa()
1623+
self.assertTrue(isa == vec_zvec_neon)
1624+
1625+
with config.patch({"cpp.simdlen": None}):
1626+
os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
1627+
isa = codecache.pick_vec_isa()
1628+
self.assertTrue(isa == vec_zvec_neon)
1629+
1630+
with config.patch({"cpp.simdlen": None}):
1631+
os.environ["ATEN_CPU_CAPABILITY"] = "default"
1632+
isa = codecache.pick_vec_isa()
1633+
self.assertFalse(isa)
1634+
1635+
with config.patch({"cpp.simdlen": None}):
1636+
os.environ["ATEN_CPU_CAPABILITY"] = "neon"
1637+
isa = codecache.pick_vec_isa()
1638+
self.assertTrue(isa == vec_zvec_neon)
1639+
1640+
with config.patch({"cpp.simdlen": None}):
1641+
os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
1642+
isa = codecache.pick_vec_isa()
1643+
self.assertTrue(isa == vec_zvec_neon)
1644+
finally:
1645+
if pre_var:
1646+
os.environ["ATEN_CPU_CAPABILITY"] = pre_var
1647+
elif os.getenv("ATEN_CPU_CAPABILITY"):
1648+
os.environ.pop("ATEN_CPU_CAPABILITY")
1649+
15831650
@unittest.skipIf(
15841651
platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
15851652
"Does not support vectorization or not x86_64 machine",
@@ -1595,13 +1662,6 @@ def test_auto_simd(self):
15951662
self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
15961663
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
15971664

1598-
with config.patch({"cpp.simdlen": None}):
1599-
isa = codecache.pick_vec_isa()
1600-
if vec_avx512 in codecache.valid_vec_isa_list():
1601-
self.assertTrue(isa == vec_avx512)
1602-
else:
1603-
self.assertTrue(isa == vec_avx2)
1604-
16051665
with config.patch({"cpp.simdlen": 0}):
16061666
isa = codecache.pick_vec_isa()
16071667
self.assertFalse(isa)
@@ -1631,6 +1691,60 @@ def test_auto_simd(self):
16311691
isa = codecache.pick_vec_isa()
16321692
self.assertTrue(isa == vec_avx2)
16331693

1694+
pre_var = os.getenv("ATEN_CPU_CAPABILITY")
1695+
if pre_var:
1696+
os.environ.pop("ATEN_CPU_CAPABILITY")
1697+
1698+
try:
1699+
with config.patch({"cpp.simdlen": None}):
1700+
isa = codecache.pick_vec_isa()
1701+
if vec_avx512 in codecache.valid_vec_isa_list():
1702+
self.assertTrue(isa == vec_avx512)
1703+
else:
1704+
self.assertTrue(isa == vec_avx2)
1705+
1706+
with config.patch({"cpp.simdlen": None}):
1707+
os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
1708+
isa = codecache.pick_vec_isa()
1709+
if vec_avx512 in codecache.valid_vec_isa_list():
1710+
self.assertTrue(isa == vec_avx2)
1711+
elif vec_avx2 in codecache.valid_vec_isa_list():
1712+
self.assertTrue(isa == vec_avx2)
1713+
1714+
with config.patch({"cpp.simdlen": None}):
1715+
os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
1716+
isa = codecache.pick_vec_isa()
1717+
if vec_avx512 in codecache.valid_vec_isa_list():
1718+
self.assertTrue(isa == vec_avx512)
1719+
else:
1720+
self.assertTrue(isa == vec_avx2)
1721+
1722+
with config.patch({"cpp.simdlen": None}):
1723+
os.environ["ATEN_CPU_CAPABILITY"] = "default"
1724+
isa = codecache.pick_vec_isa()
1725+
self.assertFalse(isa)
1726+
1727+
with config.patch({"cpp.simdlen": None}):
1728+
os.environ["ATEN_CPU_CAPABILITY"] = "neon"
1729+
isa = codecache.pick_vec_isa()
1730+
if vec_avx512 in codecache.valid_vec_isa_list():
1731+
self.assertTrue(isa == vec_avx512)
1732+
else:
1733+
self.assertTrue(isa == vec_avx2)
1734+
1735+
with config.patch({"cpp.simdlen": None}):
1736+
os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
1737+
isa = codecache.pick_vec_isa()
1738+
if vec_avx512 in codecache.valid_vec_isa_list():
1739+
self.assertTrue(isa == vec_avx512)
1740+
else:
1741+
self.assertTrue(isa == vec_avx2)
1742+
finally:
1743+
if pre_var:
1744+
os.environ["ATEN_CPU_CAPABILITY"] = pre_var
1745+
elif os.getenv("ATEN_CPU_CAPABILITY"):
1746+
os.environ.pop("ATEN_CPU_CAPABILITY")
1747+
16341748
@requires_vectorization
16351749
@patch("torch.cuda.is_available", lambda: False)
16361750
def test_masked_fill_softmax(self):
@@ -3371,6 +3485,7 @@ def forward(self, idx, x):
33713485
self.common(m, (idx, x))
33723486
check_metrics_vec_kernel_count(1)
33733487

3488+
@requires_vectorization
33743489
def test_embedding_vec_bf16(self):
33753490
class M(torch.nn.Module):
33763491
def __init__(self):
@@ -3655,7 +3770,7 @@ def fn(x):
36553770
x = torch.randint(0, 100, (819,), dtype=torch.int64)
36563771
metrics.reset()
36573772
self.common(fn, (x,))
3658-
assert metrics.generated_cpp_vec_kernel_count == 1
3773+
check_metrics_vec_kernel_count(1)
36593774

36603775
def test_reduction_float_to_int64(self):
36613776
# https://github.com/pytorch/pytorch/issues/124821
@@ -3665,7 +3780,7 @@ def fn(x):
36653780
x = torch.randint(0, 100, (22, 51), dtype=torch.int64)
36663781
metrics.reset()
36673782
self.common(fn, (x,))
3668-
assert metrics.generated_cpp_vec_kernel_count == 1
3783+
check_metrics_vec_kernel_count(1)
36693784

36703785
@config.patch({"cpp.dynamic_threads": True})
36713786
def test_reduction_with_dynamic_threads(self):

test/inductor/test_extension_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch._dynamo
99
import torch.utils.cpp_extension
1010
from torch._C import FileCheck
11+
from torch._dynamo.testing import expectedFailureScalar
1112

1213
try:
1314
from extension_backends.cpp.extension_codegen_backend import (
@@ -103,6 +104,9 @@ def tearDown(self):
103104
# return the working directory (see setUp)
104105
os.chdir(self.old_working_dir)
105106

107+
# Fails when testing the scalar version
108+
# See https://github.com/pytorch/pytorch/issues/126372.
109+
@expectedFailureScalar
106110
def test_open_device_registration(self):
107111
torch.utils.rename_privateuse1_backend("extension_device")
108112
torch._register_device_module("extension_device", self.module)

test/inductor/test_torchinductor.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from torch._dynamo.testing import (
3535
CompileCounterWithBackend,
3636
expectedFailureCodegenDynamic,
37+
expectedFailureScalar,
3738
rand_strided,
3839
same,
3940
skipIfPy312,
@@ -1315,6 +1316,9 @@ def fn(a):
13151316

13161317
self.common(fn, (torch.randn(1024),))
13171318

1319+
# Fails when testing the scalar version
1320+
# See https://github.com/pytorch/pytorch/issues/128029.
1321+
@expectedFailureScalar
13181322
@skipIfRocm
13191323
@config.patch(debug_index_asserts=False)
13201324
def test_neg_index(self):
@@ -1577,16 +1581,40 @@ def test_multilayer_var(self):
15771581
def fn(a):
15781582
return torch.var(a)
15791583

1580-
self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)))
1581-
self.common(fn, ((torch.rand((14923), dtype=torch.float32),)))
1584+
atol = None
1585+
rtol = None
1586+
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
1587+
atol = 1e-4
1588+
rtol = 1e-4
1589+
self.common(
1590+
fn,
1591+
((torch.rand((10, 3, 352, 352), dtype=torch.float32),)),
1592+
rtol=rtol,
1593+
atol=atol,
1594+
)
1595+
self.common(
1596+
fn, ((torch.rand((14923), dtype=torch.float32),)), rtol=rtol, atol=atol
1597+
)
15821598

15831599
@skipCPUIf(IS_MACOS, "fails on macos")
15841600
def test_multilayer_var_lowp(self):
15851601
def fn(a):
15861602
return torch.var(a)
15871603

1588-
self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
1589-
self.common(fn, (torch.rand((14923), dtype=torch.float16),))
1604+
atol = None
1605+
rtol = None
1606+
if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
1607+
atol = 1e-3
1608+
rtol = 1e-3
1609+
self.common(
1610+
fn,
1611+
(torch.rand((16, 16, 352, 352), dtype=torch.float16),),
1612+
rtol=rtol,
1613+
atol=atol,
1614+
)
1615+
self.common(
1616+
fn, (torch.rand((14923), dtype=torch.float16),), rtol=rtol, atol=atol
1617+
)
15901618

15911619
def test_split_cumsum(self):
15921620
def fn(a):
@@ -8199,7 +8227,7 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
81998227
rand_strided(shape, stride, dtype).requires_grad_(True).add(1)
82008228
for shape, stride, dtype in args
82018229
]
8202-
self.common(forward, args)
8230+
self.common(forward, args, atol=1e-05, rtol=1e-05)
82038231

82048232
@requires_gpu()
82058233
def test_tmp_not_defined_issue3(self):
@@ -9281,6 +9309,7 @@ def func(arg0_1):
92819309
# To support this behavior, we need to allow const-propping tensors that store symint data.
92829310
# For now, dynamo will explicitly graph break when it encounters user code with this behavior.
92839311
@expectedFailureCodegenDynamic
9312+
@expectedFailureScalar
92849313
def test_AllenaiLongformerBase_repro(self):
92859314
def fn(query, scores, window_overlap):
92869315
batch_size, seq_len, num_heads, _ = query.size()
@@ -9316,6 +9345,9 @@ def fn(query, scores, window_overlap):
93169345
opt_fn = torch._dynamo.optimize("inductor")(fn)
93179346
_, code = run_and_get_cpp_code(opt_fn, *args)
93189347
print(code)
9348+
# When testing the scalar version, i.e., ATEN_CPU_CAPABILITY=default,
9349+
# static_cast<int>(256) is not found, but static_cast<int64_t>(256).
9350+
# See https://github.com/pytorch/pytorch/issues/126262.
93199351
FileCheck().check_count(
93209352
"static_cast<int32_t>(256)",
93219353
1,

torch/_dynamo/testing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,12 @@ def expectedFailureDynamicWrapper(fn):
381381
return fn
382382

383383

384+
def expectedFailureScalar(fn):
385+
if os.getenv("ATEN_CPU_CAPABILITY") == "default":
386+
return unittest.expectedFailure(fn)
387+
return fn
388+
389+
384390
def reset_rng_state(use_xla=False):
385391
torch.manual_seed(1337)
386392
random.seed(1337)

torch/_inductor/codecache.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,31 @@ def _check_and_append_supported_isa(
14651465
supported_vec_isa_list = [VecAVX512(), VecAVX2(), VecNEON()]
14661466

14671467

1468+
def get_isa_from_cpu_capability(
1469+
capability: str | None, vec_isa_list: List[VecISA], invalid_vec_isa: InvalidVecISA
1470+
):
1471+
# VSX is not supported in inductor
1472+
capability_to_isa_str = {
1473+
"default": "INVALID_VEC_ISA",
1474+
"neon": "asimd",
1475+
"zvector": "zvector",
1476+
"avx2": "avx2",
1477+
"avx512": "avx512",
1478+
}
1479+
if capability in capability_to_isa_str.keys():
1480+
isa_str = capability_to_isa_str[capability]
1481+
if isa_str == "INVALID_VEC_ISA":
1482+
return invalid_vec_isa
1483+
for vec_isa in vec_isa_list:
1484+
if isa_str == str(vec_isa):
1485+
return vec_isa
1486+
1487+
if capability:
1488+
warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}")
1489+
1490+
return vec_isa_list[0]
1491+
1492+
14681493
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
14691494
# might have too much redundant content that is useless for ISA check. Hence,
14701495
# we only cache some key isa information.
@@ -1507,10 +1532,13 @@ def pick_vec_isa() -> VecISA:
15071532
if not _valid_vec_isa_list:
15081533
return invalid_vec_isa
15091534

1510-
# If the simdlen is None, it indicates determine the vectorization length automatically
1535+
# If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY
1536+
# to control CPU vec ISA
1537+
15111538
if config.cpp.simdlen is None:
1512-
assert _valid_vec_isa_list
1513-
return _valid_vec_isa_list[0]
1539+
return get_isa_from_cpu_capability(
1540+
os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa
1541+
)
15141542

15151543
for isa in _valid_vec_isa_list:
15161544
if config.cpp.simdlen == isa.bit_width():

torch/_inductor/codegen/cpp_prefix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <c10/util/generic_math.h>
2525
#include <c10/util/Half.h>
2626
#include <c10/util/TypeCast.h>
27+
#include <ATen/native/Math.h>
2728

2829
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
2930
#define INDUCTOR_USE_VECTOR_TYPES() 1

0 commit comments

Comments
 (0)
0