[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (… · pytorch/pytorch@80ca6dd · GitHub

Commit 80ca6dd

blaine-rister and eellison authored and committed
[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)
# Feature

Previously, only the codegen for `torch.sqrt` was dtype aware. This PR updates most of the `libdevice`/`tl.math` ops to support dtype-aware codegen as well. This is often necessary to get correct code when `config.triton.codegen_upcast_to_fp32=False`, as most Triton math ops do not support float16/bfloat16.

This PR enables dtype-aware codegen via the `maybe_upcast_float32` decorator. This wraps `TritonOverrides` macros to upcast arguments to float32 and downcast the result back to the original dtype. The exception is ops that return booleans, in which case we set `convert_output=False` and skip the output cast.

# Test Plan

Added CI tests for all the new ops. The list of ops to test is automatically generated based on uses of the `maybe_upcast_float32` decorator, and stored in the new `OpDtypeSupport` class. In each new test, we search the generated code for upcasts/downcasts using a regex. Also added a unit test for `OpDtypeSupport` which checks that we have correct dtype info for ops that require upcasts.

This PR also moves some existing tests around, to collect all the dtype-aware codegen tests in one file.

Pull Request resolved: #140864
Approved by: https://github.com/eellison, https://github.com/arui-meta

Co-authored-by: eellison <elias.ellison@gmail.com>
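As a rough illustration of the decorator pattern described above: the following is a minimal, string-based sketch only, not the actual implementation in `torch/_inductor/codegen/triton.py` (the real decorator operates on Inductor's codegen/CSE variables and dtype propagation, and registers each op in `OpDtypeSupport`).

```python
from functools import wraps


def maybe_upcast_float32(convert_output: bool = True):
    """Sketch of the idea: upcast fp16/bf16 operands to fp32 and cast the
    result back, since most tl.math/libdevice ops lack fp16/bf16 support."""

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, dtype="tl.float32"):
            # Hypothetical string-based codegen, for illustration only.
            if dtype in ("tl.float16", "tl.bfloat16"):
                args = [f"{a}.to(tl.float32)" for a in args]
                expr = fn(*args)
                # Boolean-valued ops (convert_output=False) skip the downcast.
                return f"{expr}.to({dtype})" if convert_output else expr
            return fn(*args)

        return wrapper

    return decorator


@maybe_upcast_float32()
def floor(x):
    return f"libdevice.floor({x})"


@maybe_upcast_float32(convert_output=False)
def isnan(x):
    return f"libdevice.isnan({x})"


print(floor("tmp3", dtype="tl.float16"))
# libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
print(isnan("tmp3", dtype="tl.float16"))
# libdevice.isnan(tmp3.to(tl.float32))
```

With `convert_output=False` (e.g. for boolean-valued ops such as `isnan`), the trailing cast back to the original dtype is skipped.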
1 parent 0602676 commit 80ca6dd

File tree

3 files changed, +263 -40 lines changed

test/inductor/test_op_dtype_prop.py

Lines changed: 137 additions & 1 deletion
@@ -1,16 +1,21 @@
 # Owner(s): ["module: inductor"]
 import importlib
 import os
+import re
 import sys

 import torch
 from torch._dynamo.utils import disable_cache_limit
 from torch._inductor import config
+from torch._inductor.codegen.triton import OpDtypeSupport
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import run_and_get_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
+from torch.fx.operator_schemas import get_signature_for_torch_op
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_utils import parametrize
+from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


 # Make the helper files in test/ importable
@@ -75,6 +80,137 @@ def run(op, args, kwargs):
         out_c = torch.compile(run)(op.get_op(), args, kwargs)
         self.assertEqual(out, out_c)

+    @requires_gpu()
+    @parametrize("upcast_to_fp32", [False, True])
+    @config.patch("triton.use_block_ptr", True)
+    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+        @torch.compile
+        def func(a, b, c, d):
+            return a * b * c * d
+
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
+        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+            func_opt = torch._dynamo.optimize("inductor")(func)
+            code = run_and_get_triton_code(func_opt, *inps)
+            fp32_cast_in_code = "to(tl.float32)" in code
+            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
+    def test_op_dtype_support(self):
+        """
+        Triton codegen upcasts values to float32 for certain ops.
+        Check that those ops have accurate dtype information.
+        """
+
+        for op_name in [
+            "rsqrt",
+            "sqrt",
+            "isnan",
+            "floor",
+            "ceil",
+            "tan",
+            "atan",
+            "atanh",
+            "sigmoid",
+            "log2",
+            "log10",
+            "cosh",
+            "sinh",
+            "acosh",
+            "asinh",
+            "asin",
+            "acos",
+            "asinh",
+            "erf",
+            "lgamma",
+            "sin",
+            "cos",
+            "exp",
+            "expm1",
+            "exp2",
+            "abs",
+            "hypot",
+            "nextafter",
+        ]:
+            # These ops do not support float16 and bfloat16.
+            supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
+            self.assertNotIn(torch.float16, supported_dtypes)
+            self.assertNotIn(torch.bfloat16, supported_dtypes)
+
+            # These ops should support float32 and float64.
+            self.assertIn(torch.float32, supported_dtypes)
+            self.assertIn(torch.float64, supported_dtypes)
+
+    @requires_gpu()
+    @parametrize("op_name", OpDtypeSupport.supported_dtypes)
+    @parametrize("load_upcast_to_fp32", [False, True])
+    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
+    @config.patch("triton.use_block_ptr", True)
+    def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtype):
+        """
+        Test dtype aware codegen for some tl.math/libdevice calls.
+        Operands should be upcast to float32, and the output should be downcast to float16.
+        """
+
+        # Check if the op's output should be upcasted/downcasted.
+        supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
+        convert_output = OpDtypeSupport.convert_outputs[op_name]
+        self.assertNotIn(input_dtype, supported_dtypes)
+
+        # Retrieve the corresponding torch op.
+        torch_op_name = op_name.removeprefix("libdevice_")
+        op = getattr(torch, torch_op_name)
+
+        # Edge case: torch.round maps to libdevice.nearbyint.
+        triton_op_name_overrides = {
+            "round": "nearbyint",
+        }
+        override = triton_op_name_overrides.get(op_name)
+        triton_op_name = override if override is not None else torch_op_name
+
+        # Get the number of args for the op.
+        signatures = get_signature_for_torch_op(op)
+        num_args = len(signatures[0].parameters)
+
+        # Test codegen and check for casts.
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * num_args
+        tl_dtype_str = str(input_dtype).replace("torch", "tl")
+        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
+            compiled = torch._dynamo.optimize("inductor")(op)
+            code = run_and_get_triton_code(compiled, *inps)
+
+        # Search the code with a regex.
+        # Example code: libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
+        output_cast = rf"\.to\({tl_dtype_str}\)" if convert_output else ""
+        pattern = rf"{triton_op_name}\(.*\.to\(tl\.float32\)\){output_cast}"
+        cast_in_code = re.search(pattern, code, re.MULTILINE) is not None
+        self.assertNotEqual(cast_in_code, load_upcast_to_fp32)
+
+    @config.patch("triton.codegen_upcast_to_fp32", False)
+    def test_binary_math_mixed_precision(self):
+        """
+        Test a binary math operator where only one input needs to be upcast.
+        """
+        # Create inputs of different dtypes.
+        inputs = [
+            torch.randn(8, device=GPU_TYPE, dtype=dtype)
+            for dtype in (torch.float16, torch.float32)
+        ]
+
+        func = torch.hypot
+        compiled = torch.compile(backend="inductor")(func)
+        result, (code,) = run_and_get_code(compiled, *inputs)
+
+        # Check accuracy.
+        ref = func(*inputs)
+        self.assertTrue(torch.allclose(ref, result))
+
+        # Check for exactly one upcast.
+        num_upcasts = code.count(".to(tl.float32)")
+        self.assertEqual(num_upcasts, 1)
+
+        # There should be no downcast, since the input is promoted to float32.
+        self.assertNotIn(".to(tl.float16)", code)
+
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():
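For reference, a minimal sketch of how to observe the new behavior outside the test suite, mirroring `test_dtype_aware_codegen` above. This assumes a CUDA-capable device ("cuda" stands in for `GPU_TYPE`), and the exact temporary names in the generated kernel may differ.

```python
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

# Sketch: compile a float16 op with upcast-on-load disabled and inspect
# the generated Triton kernel for the per-op upcast/downcast pattern.
x = torch.rand((32, 32), device="cuda", dtype=torch.float16)

with config.patch("triton.codegen_upcast_to_fp32", False):
    compiled = torch.compile(torch.floor, backend="inductor")
    code = run_and_get_triton_code(compiled, x)

# Expected pattern (per the example comment in the test above), e.g.:
#   libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
print(".to(tl.float32)" in code and ".to(tl.float16)" in code)
```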

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 31 deletions
@@ -12491,37 +12491,6 @@ def f(x, mask):
         # it does not move the tensor constructor to cuda and keeps it on CPU.
         self.assertFalse("empty_strided_cuda(()" in code)

-    @requires_gpu()
-    @parametrize("upcast_to_fp32", [False, True])
-    @config.patch("triton.use_block_ptr", True)
-    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
-        @torch.compile
-        def func(a, b, c, d):
-            return a * b * c * d
-
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
-        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
-            func_opt = torch._dynamo.optimize("inductor")(func)
-            code = run_and_get_triton_code(func_opt, *inps)
-            fp32_cast_in_code = "to(tl.float32)" in code
-            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
-
-    @requires_gpu()
-    @parametrize("load_upcast_to_fp32", [False, True])
-    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
-    @config.patch("triton.use_block_ptr", True)
-    def test_dtype_aware_codegen(self, load_upcast_to_fp32, input_dtype):
-        @torch.compile
-        def func(a, b, c, d):
-            return torch.sqrt(a * b * c * d)
-
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * 4
-        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
-            func_opt = torch._dynamo.optimize("inductor")(func)
-            code = run_and_get_triton_code(func_opt, *inps)
-            libdevice_cast_in_code = "libdevice.sqrt(tmp3.to(tl.float32))" in code
-            self.assertNotEqual(libdevice_cast_in_code, load_upcast_to_fp32)
-
     @config.patch("triton.use_block_ptr", False)
     def test_evict_last_non_coalesced_loads(self):
         @torch.compile

0 commit comments
