@@ -1,16 +1,21 @@
 # Owner(s): ["module: inductor"]
 import importlib
 import os
+import re
 import sys

 import torch
 from torch._dynamo.utils import disable_cache_limit
 from torch._inductor import config
+from torch._inductor.codegen.triton import OpDtypeSupport
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import run_and_get_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
+from torch.fx.operator_schemas import get_signature_for_torch_op
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_utils import parametrize
+from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


 # Make the helper files in test/ importable
@@ -75,6 +80,137 @@ def run(op, args, kwargs):
         out_c = torch.compile(run)(op.get_op(), args, kwargs)
         self.assertEqual(out, out_c)

+    @requires_gpu()
+    @parametrize("upcast_to_fp32", [False, True])
+    @config.patch("triton.use_block_ptr", True)
+    def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+        @torch.compile
+        def func(a, b, c, d):
+            return a * b * c * d
+
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4
+        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+            func_opt = torch._dynamo.optimize("inductor")(func)
+            code = run_and_get_triton_code(func_opt, *inps)
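+            # With codegen_upcast_to_fp32 enabled, Inductor upcasts the float16 loads to
+            # float32 in the generated Triton kernel, so a "to(tl.float32)" cast should
+            # appear in the code exactly when the flag is set.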
+            fp32_cast_in_code = "to(tl.float32)" in code
+            self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
+    def test_op_dtype_support(self):
+        """
+        Triton codegen upcasts values to float32 for certain ops.
+        Check that those ops have accurate dtype information.
+        """
+
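+        # OpDtypeSupport.supported_dtypes maps an op name to the dtypes its Triton
+        # implementation accepts; ops that lack float16/bfloat16 support have their
+        # inputs upcast to float32 during codegen.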
+        for op_name in [
+            "rsqrt",
+            "sqrt",
+            "isnan",
+            "floor",
+            "ceil",
+            "tan",
+            "atan",
+            "atanh",
+            "sigmoid",
+            "log2",
+            "log10",
+            "cosh",
+            "sinh",
+            "acosh",
+            "asinh",
+            "asin",
+            "acos",
+            "erf",
+            "lgamma",
+            "sin",
+            "cos",
+            "exp",
+            "expm1",
+            "exp2",
+            "abs",
+            "hypot",
+            "nextafter",
+        ]:
+            # These ops do not support float16 and bfloat16.
+            supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
+            self.assertNotIn(torch.float16, supported_dtypes)
+            self.assertNotIn(torch.bfloat16, supported_dtypes)
+
+            # These ops should support float32 and float64.
+            self.assertIn(torch.float32, supported_dtypes)
+            self.assertIn(torch.float64, supported_dtypes)
+
+    @requires_gpu()
+    @parametrize("op_name", OpDtypeSupport.supported_dtypes)
+    @parametrize("load_upcast_to_fp32", [False, True])
+    @parametrize("input_dtype", [torch.float16, torch.bfloat16])
+    @config.patch("triton.use_block_ptr", True)
+    def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtype):
+        """
+        Test dtype-aware codegen for some tl.math/libdevice calls.
+        Operands should be upcast to float32, and the output should be downcast back
+        to the input dtype when required.
+        """
+
+        # Check whether the op supports the input dtype and whether its output must be downcast.
+        supported_dtypes = OpDtypeSupport.supported_dtypes[op_name]
+        convert_output = OpDtypeSupport.convert_outputs[op_name]
+        self.assertNotIn(input_dtype, supported_dtypes)
+
+        # Retrieve the corresponding torch op.
+        torch_op_name = op_name.removeprefix("libdevice_")
+        op = getattr(torch, torch_op_name)
+
+        # Edge case: torch.round maps to libdevice.nearbyint.
+        triton_op_name_overrides = {
+            "round": "nearbyint",
+        }
+        override = triton_op_name_overrides.get(op_name)
+        triton_op_name = override if override is not None else torch_op_name
+
+        # Get the number of args for the op.
+        signatures = get_signature_for_torch_op(op)
+        num_args = len(signatures[0].parameters)
+
+        # Test codegen and check for casts.
+        inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * num_args
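+        # str(torch.float16) is "torch.float16", so this yields "tl.float16" / "tl.bfloat16".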
+        tl_dtype_str = str(input_dtype).replace("torch", "tl")
+        with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32):
+            compiled = torch._dynamo.optimize("inductor")(op)
+            code = run_and_get_triton_code(compiled, *inps)
+
+            # Search the code with a regex.
+            # Example code: libdevice.floor(tmp3.to(tl.float32)).to(tl.float16)
+            output_cast = rf"\.to\({tl_dtype_str}\)" if convert_output else ""
+            pattern = rf"{triton_op_name}\(.*\.to\(tl\.float32\)\){output_cast}"
+            cast_in_code = re.search(pattern, code, re.MULTILINE) is not None
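+            # When loads are already upcast to float32 (load_upcast_to_fp32=True), the op
+            # receives float32 operands and no cast appears at the call site; otherwise
+            # dtype-aware codegen must insert the casts, so the pattern matches exactly
+            # when the flag is off.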
+            self.assertNotEqual(cast_in_code, load_upcast_to_fp32)
+
+    @config.patch("triton.codegen_upcast_to_fp32", False)
+    def test_binary_math_mixed_precision(self):
+        """
+        Test a binary math operator where only one input needs to be upcast.
+        """
+        # Create inputs of different dtypes.
+        inputs = [
+            torch.randn(8, device=GPU_TYPE, dtype=dtype)
+            for dtype in (torch.float16, torch.float32)
+        ]
+
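+        # torch.hypot promotes (float16, float32) inputs to a float32 result, so only the
+        # float16 operand should be upcast in the generated kernel.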
+        func = torch.hypot
+        compiled = torch.compile(backend="inductor")(func)
+        result, (code,) = run_and_get_code(compiled, *inputs)
+
+        # Check accuracy.
+        ref = func(*inputs)
+        self.assertTrue(torch.allclose(ref, result))
+
+        # Check for exactly one upcast.
+        num_upcasts = code.count(".to(tl.float32)")
+        self.assertEqual(num_upcasts, 1)
+
+        # There should be no downcast, since the input is promoted to float32.
+        self.assertNotIn(".to(tl.float16)", code)
+
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():