@@ -1,21 +1,16 @@
 # Owner(s): ["module: inductor"]
 import importlib
 import os
-import re
 import sys

 import torch
 from torch._dynamo.utils import disable_cache_limit
 from torch._inductor import config
-from torch._inductor.codegen.triton import OpDtypeSupport
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
-from torch.fx.operator_schemas import get_signature_for_torch_op
+from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_utils import parametrize
-from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


 # Make the helper files in test/ importable
@@ -80,137 +75,6 @@ def run(op, args, kwargs):
             out_c = torch.compile(run)(op.get_op(), args, kwargs)
             self.assertEqual(out, out_c)

-    @requires_gpu()
-    @parametrize("upcast_to_fp32", [False, True])
-    @config.patch("triton.use_block_ptr", True)
-    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
-        @torch.compile
-        def func(a, b, c, d):
-            return a * b * c * d
-
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
-        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
-            func_opt = torch._dynamo.optimize("inductor")(func)
-            code = run_and_get_triton_code(func_opt, *inps)
-            fp32_cast_in_code = "to(tl.float32)" in code
-            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
-
-    def test_op_dtype_support(self):
-        """
-        Triton codegen upcasts values to float32 for certain ops.
-        Check that those ops have accurate dtype information.
-        """
-
-        for op_name in [
-            "rsqrt",
-            "sqrt",
-            "isnan",
-            "floor",
-            "ceil",
-            "tan",
-            "atan",
-            "atanh",
-            "sigmoid",
-            "log2",
-            "log10",
-            "cosh",
-            "sinh",
-            "acosh",
-            "asinh",
-            "asin",
-            "acos",
-            "asinh",
-            "erf",
-            "lgamma",
-            "sin",
-            "cos",
-            "exp",
-            "expm1",
-            "exp2",
-            "abs",
-            "hypot",
-            "nextafter",
-        ]:
-            # These ops do not support float16 and bfloat16.
-            supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
-            self.assertNotIn(torch.float16, supported_dtypes)
-            self.assertNotIn(torch.bfloat16, supported_dtypes)
-
-            # These ops should support float32 and float64.
-            self.assertIn(torch.float32, supported_dtypes)
-            self.assertIn(torch.float64, supported_dtypes)
-
-    @requires_gpu()
-    @parametrize("op_name", OpDtypeSupport.supported_dtypes)
-    @parametrize("load_upcast_to_fp32", [False, True])
-    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
-    @config.patch("triton.use_block_ptr", True)
-    def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtype):
-        """
-        Test dtype aware codegen for some tl.math/libdevice calls.
-        Operands should be upcast to float32, and the output should be downcast to float16.
-        """
-
-        # Check if the op's output should be upcasted/downcasted.
-        supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
-        convert_output = OpDtypeSupport.convert_outputs[op_name]
-        self.assertNotIn(input_dtype, supported_dtypes)
-
-        # Retrieve the corresponding torch op.
-        torch_op_name = op_name.removeprefix("libdevice_")
-        op = getattr(torch, torch_op_name)
-
-        # Edge case: torch.round maps to libdevice.nearbyint.
-        triton_op_name_overrides = {
-            "round": "nearbyint",
-        }
-        override = triton_op_name_overrides.get(op_name)
-        triton_op_name = override if override is not None else torch_op_name
-
-        # Get the number of args for the op.
-        signatures = get_signature_for_torch_op(op)
-        num_args = len(signatures[0].parameters)
-
-        # Test codegen and check for casts.
-        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * num_args
-        tl_dtype_str = str(input_dtype).replace("torch", "tl")
-        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
-            compiled = torch._dynamo.optimize("inductor")(op)
-            code = run_and_get_triton_code(compiled, *inps)
-
-        # Search the code with a regex.
-        # Example code: libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
-        output_cast = rf"\.to\({tl_dtype_str}\)" if convert_output else ""
-        pattern = rf"{triton_op_name}\(.*\.to\(tl\.float32\)\){output_cast}"
-        cast_in_code = re.search(pattern, code, re.MULTILINE) is not None
-        self.assertNotEqual(cast_in_code, load_upcast_to_fp32)
-
-    @config.patch("triton.codegen_upcast_to_fp32", False)
-    def test_binary_math_mixed_precision(self):
-        """
-        Test a binary math operator where only one input needs to be upcast.
-        """
-        # Create inputs of different dtypes.
-        inputs = [
-            torch.randn(8, device=GPU_TYPE, dtype=dtype)
-            for dtype in (torch.float16, torch.float32)
-        ]
-
-        func = torch.hypot
-        compiled = torch.compile(backend="inductor")(func)
-        result, (code,) = run_and_get_code(compiled, *inputs)
-
-        # Check accuracy.
-        ref = func(*inputs)
-        self.assertTrue(torch.allclose(ref, result))
-
-        # Check for exactly one upcast.
-        num_upcasts = code.count(".to(tl.float32)")
-        self.assertEqual(num_upcasts, 1)
-
-        # There should be no downcast, since the input is promoted to float32.
-        self.assertNotIn(".to(tl.float16)", code)
-
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():
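Note: the removed test_codegen_upcast_to_fp32 worked by compiling a small float16 kernel and checking the generated Triton source for explicit "to(tl.float32)" casts, toggled via the triton.codegen_upcast_to_fp32 config. Below is a minimal standalone sketch of that style of check, not the deleted test itself: the names mul4 and fp32_cast_emitted and the hard-coded "cuda" device are illustrative assumptions, and running it requires a Triton-capable GPU build of PyTorch. It reuses only helpers the deleted code itself used (config.patch, run_and_get_triton_code).

import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code


def mul4(a, b, c, d):
    # Simple elementwise chain; with float16 inputs, Inductor may upcast
    # loads to float32 depending on the config toggled below.
    return a * b * c * d


def fp32_cast_emitted(upcast: bool) -> bool:
    # Recompile from scratch so the config change is not hidden by caching.
    torch._dynamo.reset()
    inps = (torch.rand((32, 32), device="cuda", dtype=torch.float16),) * 4
    with config.patch("triton.codegen_upcast_to_fp32", upcast):
        compiled = torch.compile(mul4, backend="inductor")
        code = run_and_get_triton_code(compiled, *inps)
    return "to(tl.float32)" in code


if __name__ == "__main__":
    # Mirrors the assertion in the removed test: the float32 cast appears
    # exactly when triton.codegen_upcast_to_fp32 is enabled.
    assert fp32_cast_emitted(True)
    assert not fp32_cast_emitted(False)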