pytorch/pytorch · Commit f36cccb
Revert "[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)"
This reverts commit 80ca6dd. Reverted #140864 on behalf of https://github.com/atalman because it was failing internally (see the comment on #140864).
1 parent 1fa27f6 commit f36cccb
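
The reverted change made Inductor's Triton codegen dtype aware for libdevice/tl.math calls: ops that lack native float16/bfloat16 support get their operands upcast to float32 at the call site, and the result is cast back to the original dtype where appropriate (the removed test below cites the pattern libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)). The following is a toy Python sketch of that wrapping idea, not Inductor's actual implementation; emit_libdevice_call, supports_low_precision, and convert_output are hypothetical names used only for illustration.

# Toy sketch of dtype-aware call wrapping, for illustration only.
# emit_libdevice_call, supports_low_precision, and convert_output are
# hypothetical names, not Inductor APIs.
def emit_libdevice_call(
    op_name: str,
    operand: str,
    input_dtype: str,
    supports_low_precision: bool,
    convert_output: bool,
) -> str:
    # Ops that handle fp16/bf16 natively are emitted unchanged.
    if supports_low_precision or input_dtype not in ("tl.float16", "tl.bfloat16"):
        return f"libdevice.{op_name}({operand})"
    # Otherwise upcast the operand and, if requested, cast the result back.
    call = f"libdevice.{op_name}({operand}.to(tl.float32))"
    return f"{call}.to({input_dtype})" if convert_output else call


# Reproduces the example cited in the removed test:
# libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
print(emit_libdevice_call("floor", "tmp3", "tl.float16", False, True))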

File tree

3 files changed: +40 -263 lines

test/inductor/test_op_dtype_prop.py

Lines changed: 1 addition & 137 deletions

@@ -1,21 +1,16 @@
 # Owner(s): ["module: inductor"]
 import importlib
 import os
-import re
 import sys

 import torch
 from torch._dynamo.utils import disable_cache_limit
 from torch._inductor import config
-from torch._inductor.codegen.triton import OpDtypeSupport
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
-from torch.fx.operator_schemas import get_signature_for_torch_op
+from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_utils import parametrize
-from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


 # Make the helper files in test/ importable
@@ -80,137 +75,6 @@ def run(op, args, kwargs):
         out_c = torch.compile(run)(op.get_op(), args, kwargs)
         self.assertEqual(out, out_c)

-    @requires_gpu()
-    @parametrize("upcast_to_fp32", [False, True])
-    @config.patch("triton.use_block_ptr", True)
-    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
-        @torch.compile
-        def func(a, b, c, d):
-            return a * b * c * d
-
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
-        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
-            func_opt = torch._dynamo.optimize("inductor")(func)
-            code = run_and_get_triton_code(func_opt, *inps)
-            fp32_cast_in_code = "to(tl.float32)" in code
-            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
-
-    def test_op_dtype_support(self):
-        """
-        Triton codegen upcasts values to float32 for certain ops.
-        Check that those ops have accurate dtype information.
-        """
-
-        for op_name in [
-            "rsqrt",
-            "sqrt",
-            "isnan",
-            "floor",
-            "ceil",
-            "tan",
-            "atan",
-            "atanh",
-            "sigmoid",
-            "log2",
-            "log10",
-            "cosh",
-            "sinh",
-            "acosh",
-            "asinh",
-            "asin",
-            "acos",
-            "asinh",
-            "erf",
-            "lgamma",
-            "sin",
-            "cos",
-            "exp",
-            "expm1",
-            "exp2",
-            "abs",
-            "hypot",
-            "nextafter",
-        ]:
-            # These ops do not support float16 and bfloat16.
-            supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
-            self.assertNotIn(torch.float16, supported_dtypes)
-            self.assertNotIn(torch.bfloat16, supported_dtypes)
-
-            # These ops should support float32 and float64.
-            self.assertIn(torch.float32, supported_dtypes)
-            self.assertIn(torch.float64, supported_dtypes)
-
-    @requires_gpu()
-    @parametrize("op_name", OpDtypeSupport.supported_dtypes)
-    @parametrize("load_upcast_to_fp32", [False, True])
-    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
-    @config.patch("triton.use_block_ptr", True)
-    def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtype):
-        """
-        Test dtype aware codegen for some tl.math/libdevice calls.
-        Operands should be upcast to float32, and the output should be downcast to float16.
-        """
-
-        # Check if the op's output should be upcasted/downcasted.
-        supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
-        convert_output = OpDtypeSupport.convert_outputs[op_name]
-        self.assertNotIn(input_dtype, supported_dtypes)
-
-        # Retrieve the corresponding torch op.
-        torch_op_name = op_name.removeprefix("libdevice_")
-        op = getattr(torch, torch_op_name)
-
-        # Edge case: torch.round maps to libdevice.nearbyint.
-        triton_op_name_overrides = {
-            "round": "nearbyint",
-        }
-        override = triton_op_name_overrides.get(op_name)
-        triton_op_name = override if override is not None else torch_op_name
-
-        # Get the number of args for the op.
-        signatures = get_signature_for_torch_op(op)
-        num_args = len(signatures[0].parameters)
-
-        # Test codegen and check for casts.
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * num_args
-        tl_dtype_str = str(input_dtype).replace("torch", "tl")
-        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
-            compiled = torch._dynamo.optimize("inductor")(op)
-            code = run_and_get_triton_code(compiled, *inps)
-
-        # Search the code with a regex.
-        # Example code: libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
-        output_cast = rf"\.to\({tl_dtype_str}\)" if convert_output else ""
-        pattern = rf"{triton_op_name}\(.*\.to\(tl\.float32\)\){output_cast}"
-        cast_in_code = re.search(pattern, code, re.MULTILINE) is not None
-        self.assertNotEqual(cast_in_code, load_upcast_to_fp32)
-
-    @config.patch("triton.codegen_upcast_to_fp32", False)
-    def test_binary_math_mixed_precision(self):
-        """
-        Test a binary math operator where only one input needs to be upcast.
-        """
-        # Create inputs of different dtypes.
-        inputs = [
-            torch.randn(8, device=GPU_TYPE, dtype=dtype)
-            for dtype in (torch.float16, torch.float32)
-        ]
-
-        func = torch.hypot
-        compiled = torch.compile(backend="inductor")(func)
-        result, (code,) = run_and_get_code(compiled, *inputs)
-
-        # Check accuracy.
-        ref = func(*inputs)
-        self.assertTrue(torch.allclose(ref, result))
-
-        # Check for exactly one upcast.
-        num_upcasts = code.count(".to(tl.float32)")
-        self.assertEqual(num_upcasts, 1)
-
-        # There should be no downcast, since the input is promoted to float32.
-        self.assertNotIn(".to(tl.float16)", code)
-
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():
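
The removed test_dtype_aware_codegen verified the cast wrapping with a regex over the generated kernel. The sketch below isolates that matching logic; the code string is a hypothetical one-line sample made up for illustration, while the pattern construction mirrors the removed test above.

import re

# Hypothetical sample of generated Triton code, for illustration only.
code = "tmp4 = libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)"

triton_op_name = "floor"
tl_dtype_str = "tl.float16"  # derived from the input dtype in the real test
convert_output = True        # floor's result is cast back to the input dtype

output_cast = rf"\.to\({tl_dtype_str}\)" if convert_output else ""
pattern = rf"{triton_op_name}\(.*\.to\(tl\.float32\)\){output_cast}"
print(re.search(pattern, code, re.MULTILINE) is not None)  # True for this sample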

test/inductor/test_torchinductor.py

Lines changed: 31 additions & 0 deletions

@@ -12498,6 +12498,37 @@ def f(x, mask):
         # it does not move the tensor constructor to cuda and keeps it on CPU.
         self.assertFalse("empty_strided_cuda(()" in code)

+    @requires_gpu()
+    @parametrize("upcast_to_fp32", [False, True])
+    @config.patch("triton.use_block_ptr", True)
+    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+        @torch.compile
+        def func(a, b, c, d):
+            return a * b * c * d
+
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
+        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+            func_opt = torch._dynamo.optimize("inductor")(func)
+            code = run_and_get_triton_code(func_opt, *inps)
+            fp32_cast_in_code = "to(tl.float32)" in code
+            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
+    @requires_gpu()
+    @parametrize("load_upcast_to_fp32", [False, True])
+    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
+    @config.patch("triton.use_block_ptr", True)
+    def test_dtype_aware_codegen(self, load_upcast_to_fp32, input_dtype):
+        @torch.compile
+        def func(a, b, c, d):
+            return torch.sqrt(a * b * c * d)
+
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * 4
+        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
+            func_opt = torch._dynamo.optimize("inductor")(func)
+            code = run_and_get_triton_code(func_opt, *inps)
+            libdevice_cast_in_code = "libdevice.sqrt(tmp3.to(tl.float32))" in code
+            self.assertNotEqual(libdevice_cast_in_code, load_upcast_to_fp32)
+
     @config.patch("triton.use_block_ptr", False)
     def test_evict_last_non_coalesced_loads(self):
         @torch.compile
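
The restored test_dtype_aware_codegen drops the regex and the OpDtypeSupport lookup and checks a single op with a plain substring test. Below is a hedged sketch of the same check run outside the test harness, assuming a CUDA device and a Triton-enabled build; the substring and the expected relationship mirror the added test, and only the helpers that appear in the diff above are used.

import torch
import torch._dynamo
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code


def func(a, b, c, d):
    return torch.sqrt(a * b * c * d)


inps = (torch.rand((32, 32), device="cuda", dtype=torch.float16),) * 4
for load_upcast_to_fp32 in (False, True):
    torch._dynamo.reset()  # force a fresh compile for each config setting
    with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
        compiled = torch.compile(func, backend="inductor")
        code = run_and_get_triton_code(compiled, *inps)
    # The restored test asserts these differ: the per-call cast should appear
    # exactly when loads are not upcast globally.
    print(load_upcast_to_fp32, "libdevice.sqrt(tmp3.to(tl.float32))" in code)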

0 commit comments