Add remaining method and tests · pytorch/pytorch@7ab6014 · GitHub

Commit 7ab6014

Add remaining method and tests

ghstack-source-id: 2457960
Pull Request resolved: #140057

1 parent 713255d · commit 7ab6014

File tree

6 files changed: +277 -41 lines changed

test/inductor/test_op_dtype_prop.py (new file)
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
# Owner(s): ["module: inductor"]
import importlib
import os
import sys

import torch
from torch._dynamo.utils import disable_cache_limit
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)


importlib.import_module("functorch")
importlib.import_module("filelock")


from torch._inductor.lowering import lowerings
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.inductor_utils import HAS_GPU


unique_pointwise_op_names = set()

for op in lowerings:
    if not isinstance(op, torch._ops.OpOverload):
        continue

    if torch.Tag.pointwise not in op.tags:
        continue

    if op._schema.is_mutable:
        continue

    op_name = (op.name().split("::")[-1]).split(".")[0]
    unique_pointwise_op_names.add(op_name)

pointwise_ops = [
    op
    for op in op_db
    if op.name in unique_pointwise_op_names and "reduction" not in op.variant_test_name
]


class TestCase(InductorTestCase):
    @ops(
        pointwise_ops,
        allowed_dtypes=(
            torch.float32,
            torch.float64,
            torch.int32,
            torch.int64,
            torch.bool,
        ),
    )
    # @config.patch("triton.codegen_upcast_to_fp32", False)  # TODO enable
    @config.patch("test_configs.runtime_triton_dtype_assert", True)
    @disable_cache_limit()
    def test_op_dtype_propagation(self, op, dtype):
        def run(op, args, kwargs):
            return op(*args, **kwargs)

        sample_inputs_itr = op.sample_inputs("cuda", dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            args = (sample_input.input,) + sample_input.args
            kwargs = sample_input.kwargs
            out = run(op.get_op(), args, kwargs)
            out_c = torch.compile(run)(op.get_op(), args, kwargs)
            self.assertEqual(out, out_c)


instantiate_device_type_tests(TestCase, globals(), only_for=("cuda",))

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests(needs="filelock")
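The new test sweeps every non-mutating pointwise lowering that also appears in op_db, across five dtypes, and compares eager against compiled output while the runtime dtype asserts are active. A minimal standalone sketch of the same check, assuming a CUDA device and using torch.sigmoid as a stand-in op (this repro is illustrative, not part of the commit):

import torch
from torch._inductor import config

def run(fn, *args):
    return fn(*args)

# config.patch works as a context manager as well as a decorator.
with config.patch("test_configs.runtime_triton_dtype_assert", True):
    x = torch.randn(8, device="cuda")
    eager = run(torch.sigmoid, x)
    compiled = torch.compile(run)(torch.sigmoid, x)
    # A dtype-propagation mismatch would fail the generated
    # tl.static_assert during Triton compilation, before this check runs.
    torch.testing.assert_close(eager, compiled)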

torch/_inductor/codegen/common.py
Lines changed: 16 additions & 2 deletions

@@ -23,10 +23,10 @@
 )

 import sympy
-from sympy.printing.printer import Printer

 import torch
 import torch.fx
+from sympy.printing.printer import Printer
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
 from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils import _pytree as pytree
@@ -2010,12 +2010,13 @@ def inner(*args, **kwargs):
     value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
     dtype_handler = DtypePropagationOpsHandler()

+    idx = 0
+
     def do_cse(v):
         # TODO - throw on default
         output_dtype = getattr(
             dtype_handler,
             name,
-            dtype_handler.default_handler,
         )(*args)

         csevar = V.kernel.cse.generate(
@@ -2024,7 +2025,20 @@ def do_cse(v):
             bounds=bounds,
             dtype=output_dtype,
         )
+
+        nonlocal idx
+        if config.test_configs.runtime_triton_dtype_assert:
+            from torch._inductor.codegen.triton import triton_type
+
+            if isinstance(output_dtype, (list, tuple)):
+                output_dtype = output_dtype[idx]
+            V.kernel.compute.writeline(
+                f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
+            )
+            idx += 1
+
         csevar.update_on_args(name, args, kwargs)
+
         return csevar

     return pytree.tree_map(do_cse, value)
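With the flag enabled, every CSE'd op result in the generated Triton kernel is followed by a static assert that pins its dtype to what DtypePropagationOpsHandler predicted. Because tl.static_assert is evaluated when Triton compiles the kernel, a mispredicted dtype surfaces as a compile-time error rather than silently wrong codegen. An illustrative excerpt of what a generated kernel body might then contain (variable names are made up, not taken from the commit):

tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tl.sigmoid(tmp0)
tl.static_assert(tmp1.dtype == tl.float32)
tmp2 = tmp1 + 1.0
tl.static_assert(tmp2.dtype == tl.float32)
tl.store(out_ptr0 + (x0), tmp2, xmask)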

torch/_inductor/codegen/triton.py
Lines changed: 3 additions & 1 deletion

@@ -1087,7 +1087,9 @@ def sigmoid(x):
     @staticmethod
     def signbit(x):
         # XX: This is wrong for the value -0.0 in floating point
-        return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"
+        return (
+            f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0"
+        )

     @staticmethod
     def fmod(a, b):
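The added != 0 is what makes signbit survive the new asserts: libdevice.signbit returns an integer value, not a boolean, so the two branches of the ternary previously produced different dtypes (an integer on the float32 path, tl.int1 from x < 0 on the other). A small sketch of the normalized branch, assuming libdevice is importable from triton.language.extra (the import path varies across Triton versions):

import triton
import triton.language as tl
from triton.language.extra import libdevice  # import path varies by Triton version

@triton.jit
def signbit_kernel(in_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(in_ptr + offs)
    # Comparing against 0 turns libdevice.signbit's integer result into
    # tl.int1, matching what (x < 0) produces on the non-float32 branch.
    sb = libdevice.signbit(x) != 0
    tl.static_assert(sb.dtype == tl.int1)
    tl.store(out_ptr + offs, sb)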

torch/_inductor/config.py
Lines changed: 2 additions & 0 deletions

@@ -1314,6 +1314,8 @@ class trace:
 class test_configs:
     force_extern_kernel_in_multi_template = False

+    runtime_triton_dtype_assert = False
+

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
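The flag defaults to False, so the asserts cost nothing unless opted into. The test enables it with the @config.patch decorator shown above; for ad-hoc debugging it can also be set directly (a sketch; the scoped patch form is preferable in tests):

from torch._inductor import config

# Enable the compile-time dtype asserts for all subsequent compilations.
config.test_configs.runtime_triton_dtype_assert = True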

0 commit comments