Reland: [inductor] Simplify grid handling (#148305) · pytorch/pytorch@b040dc3 · GitHub
Commit b040dc3
jansel authored and pytorchmergebot committed
Reland: [inductor] Simplify grid handling (#148305)
Summary: Relands D69965761 / #147583

Before this PR, calling a triton kernel would look like:

```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```

where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:

```py
kernel.run(a, b, xnumel, stream=stream0)
```

instead now the grid computation is included in the kernel launcher, with something like:

```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`. It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid. This unification allows this PR to be a net deletion of code.

Differential Revision: D70471332

Pull Request resolved: #148305
Approved by: https://github.com/shunting314, https://github.com/eellison
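As a side note, here is a minimal sketch of the arithmetic being specialized (the standalone helper below is illustrative, not code from this PR): the per-config grid expression is a ceiling division of `xnumel` by that config's `XBLOCK`, which the generated launcher can emit as a shift when `XBLOCK` is a power of two.

```py
# Illustrative sketch only -- not code from this PR. It shows the arithmetic
# that the generated launcher specializes per triton.Config.
def grid_1d(xnumel: int, xblock: int = 1024) -> tuple[int, int, int]:
    # Ceiling division: one program per XBLOCK-sized chunk of the iteration space.
    return ((xnumel + xblock - 1) // xblock, 1, 1)

# For XBLOCK = 1024 this matches the shift form shown above.
assert grid_1d(5000) == ((5000 + 1023) >> 10, 1, 1) == (5, 1, 1)
```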
1 parent 626a5e2 commit b040dc3


42 files changed (+986, -1312 lines)

test/inductor/test_aot_inductor.py (+3, -4)

```diff
@@ -4086,10 +4086,9 @@ def forward(self, x):
         # input u0 was defined as int32_t initially, verify for every kernel var args downstream,
         # it gets explicitly declared using its data types in the cpp wrapper codegen code.
         expected_scalar_args = [
-            "int64_t var_1 = u0;",
-            "int64_t var_4 = u0;",
-            "int64_t var_7 = u0;",
-            "int64_t var_12 = u0;",
+            "buf3, u0",
+            "buf4, u0",
+            "buf3, buf4, buf2, u0",
         ]
         # check the new behavior of codegen is expected
         result, code = run_and_get_cpp_code(
```

test/inductor/test_cpp_wrapper_hipify.py (-16)

```diff
@@ -54,22 +54,6 @@ def test_hipify_aoti_driver_header(self) -> None:
             } \\
         } while (0);
 
-        namespace {
-
-        struct Grid {
-            Grid(uint32_t x, uint32_t y, uint32_t z)
-              : grid_x(x), grid_y(y), grid_z(z) {}
-            uint32_t grid_x;
-            uint32_t grid_y;
-            uint32_t grid_z;
-
-            bool is_non_zero() {
-                return grid_x > 0 && grid_y > 0 && grid_z > 0;
-            }
-        };
-
-        } // anonymous namespace
-
         static inline hipFunction_t loadKernel(
             std::string filePath,
             const std::string &funcName,
```

test/inductor/test_cuda_repro.py (+4, -3)

```diff
@@ -550,7 +550,7 @@ def test_autotune_inplace_kernel(self):
         """
         from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
         from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType
-        from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid
+        from torch._inductor.runtime.triton_heuristics import CachingAutotuner
         from torch._inductor.utils import triton_version_uses_attrs_dict
 
         def autotune(configs, meta):
@@ -570,6 +570,7 @@ def decorator(fn):
                 reset_to_zero_arg_names=[],
                 optimize_mem=True,
                 heuristic_type=HeuristicType.POINTWISE,
+                inductor_meta={"grid_type": "Grid1D"},
             )
 
         return decorator
@@ -609,8 +610,8 @@ def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
         inout2 = inout1.clone()
 
         stream0 = get_cuda_stream(0)
-        kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
-        kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)
+        kernel.run(inout1, in0, xnumel, stream=stream0)
+        kernel.run(inout2, in0, xnumel, stream=stream0)
 
         assert same(
             inout1, inout2, tol=0.001, equal_nan=True
```

test/inductor/test_kernel_benchmark.py (+22, -33)

```diff
@@ -58,11 +58,16 @@ def get_compiled_module(self):
     def verify_compiled_kernels(self, GB_count=1):
         compiled_module = self.get_compiled_module()
         # now run the compiled module in subprocess and check its output
-        bench_out = subprocess.check_output(
-            f"{sys.executable} {compiled_module.__file__} -kc".split(),
-            stderr=subprocess.STDOUT,
-            env={**os.environ, "PYTHONPATH": self.python_path},
-        ).decode()
+        try:
+            bench_out = subprocess.check_output(
+                f"{sys.executable} {compiled_module.__file__} -kc".split(),
+                stderr=subprocess.STDOUT,
+                env={**os.environ, "PYTHONPATH": self.python_path},
+            ).decode()
+        except subprocess.CalledProcessError as e:
+            print("Failed when running output code", e)
+            print(e.output.decode())
+            raise e
 
         # make sure we have the bandwidth information in the output
         FileCheck().check_count(
@@ -111,11 +116,16 @@ def verify_remove_inductor_deps(self, compiled_module):
 
     def check_bandwidth(self, compiled_module, num_gb):
         # now run the compiled module in subprocess and check its output
-        bench_out = subprocess.check_output(
-            f"{sys.executable} {compiled_module.__file__} -k".split(),
-            stderr=subprocess.STDOUT,
-            env={**os.environ, "PYTHONPATH": self.python_path},
-        ).decode()
+        try:
+            bench_out = subprocess.check_output(
+                f"{sys.executable} {compiled_module.__file__} -k".split(),
+                stderr=subprocess.STDOUT,
+                env={**os.environ, "PYTHONPATH": self.python_path},
+            ).decode()
+        except subprocess.CalledProcessError as e:
+            print("Failed when running output code", e)
+            print(e.output.decode())
+            raise e
 
         # make sure we have the bandwidth information in the output
         FileCheck().check_count(
@@ -154,7 +164,7 @@ def f(a, b):
         self.verify_compiled_kernels()
 
     @config.patch(
-        max_autotune=True, max_autotune_gemm_backends="TRITON", force_shape_pad=True
+        max_autotune=True, max_autotune_gemm_backends="TRITON", shape_padding=False
     )
     @fresh_inductor_cache()
     def test_mm_triton_kernel_benchmark(self):
@@ -173,28 +183,7 @@ def f(a, b):
 
         f(a, b)
 
-        GB_count = 3
-        # pad_mm is not enabled on XPU, so there is only one kernel.
-        if GPU_TYPE == "xpu":
-            GB_count = 1
-        self.verify_compiled_kernels(GB_count=GB_count)
-
-        # make sure we correctly generate the grid info
-        compiled_module = self.get_compiled_module()
-        with open(compiled_module.__file__) as f:
-            source_code = f.read()
-        lines = source_code.split("\n")
-        meta = [l for l in lines if "meta0 = {" in l]
-        scope = {}
-        from torch._inductor.kernel.mm_common import mm_grid
-
-        exec(meta[0], scope)
-        grid = mm_grid(M, N, scope["meta0"])
-        FileCheck().check_count(
-            f"grid={grid}",
-            2,
-            exactly=1,
-        ).run(source_code)
+        self.verify_compiled_kernels(GB_count=1)
 
     def test_matmul_bandwidth_computation(self):
         """
```

test/inductor/test_max_autotune.py (+1, -1)

```diff
@@ -57,7 +57,7 @@ def _get_func_call() -> str:
 
 
 def _get_kernel_launch() -> str:
-    return "launchKernel(" if config.cpp_wrapper else ".run("
+    return "call_triton_" if config.cpp_wrapper else ".run("
 
 
 def benchmark_choice(choice, args, out, expected_out, timings):
```

test/inductor/test_profiler.py (-3)

```diff
@@ -263,9 +263,6 @@ def check_triton_event(e) -> None:
             self.assertEqual(args["kernel_backend"], "triton", msg=f"event = {e}")
 
             self.assertTrue("stream" in args, msg=f"event = {e}")
-            self.assertTrue("grid" in args, msg=f"event = {e}")
-            self.assertTrue(args["grid"].startswith("grid"), msg=f"event = {e}")
-
             self.assertTrue("kernel_file" in args, msg=f"event = {e}")
             kernel_file = args["kernel_file"]
             self.assertTrue(os.path.isfile(kernel_file), msg=f"event = {e}")
```

test/inductor/test_select_algorithm.py (-1)

```diff
@@ -354,7 +354,6 @@ def test_TritonTemplateCaller_str(self):
             module_path=module_path,
             module_cache_key=None,
             kernel_name=None,
-            grid=None,
             extra_args=None,
             num_stages=None,
             num_warps=None,
```

test/inductor/test_torchinductor.py (+1, -1)

```diff
@@ -4001,7 +4001,7 @@ def foo(m, inp):
         _, code = run_and_get_code(foo, grouped_conv, input_tensor)
         # no to channels last permuting before kernel
         if config.cpp_wrapper:
-            FileCheck().check_not("launchKernel(triton").check("_convolution(").run(
+            FileCheck().check_not(" call_triton").check("_convolution(").run(
                 code[0]
             )
         else:
```

test/inductor/test_triton_kernels.py (+4, -2)

```diff
@@ -3635,8 +3635,10 @@ def grid(META):
         output = "\n".join(record.getMessage() for record in log.records)
         # correct grid example values updated per block size
         FileCheck().check("Compile-time auto-tuning block:").check(
-            "grid_wrapper_for_op_zeros_0"
-        ).check_next("return (256").check_next("return (64").run(output)
+            "PrecomputedGrid"
+        ).check("(31 + _launcher_s0) // 32").check("(127 + _launcher_s0) // 128").run(
+            output
+        )
 
         # Triton 3.2.0 adds the required flags to the Autotuner object for this test
         # PR: https://github.com/triton-lang/triton/pull/5092
```
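For dynamic shapes, the strings these updated checks look for are the same ceiling division applied to a symbolic size; a hedged sketch of how such precomputed grid expressions could be formed (the helper below is illustrative, not the inductor API):

```py
# Illustrative only: each autotuning candidate gets its own ceil-division
# expression over the symbolic size, matching the strings checked above.
def precomputed_grid_exprs(sym: str, block_sizes=(32, 128)) -> list[str]:
    return [f"({block - 1} + {sym}) // {block}" for block in block_sizes]

print(precomputed_grid_exprs("_launcher_s0"))
# ['(31 + _launcher_s0) // 32', '(127 + _launcher_s0) // 128']
```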

torch/_inductor/autotune_process.py (+2, -7)

```diff
@@ -639,7 +639,6 @@ def __init__(
         extra_args: Iterable[Any],
         module_path: str,  # the path of the module defining the triton kernel
         module_cache_key: str,
-        grid: list[int],
         num_stages: int,
         num_warps: int,
         matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
@@ -650,7 +649,6 @@ def __init__(
         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
         self.module_path = module_path
         self.module_cache_key = module_cache_key
-        self.grid = grid
         self.num_stages = num_stages
         self.num_warps = num_warps
         self.matrix_instr_nonkdim = matrix_instr_nonkdim
@@ -704,16 +702,15 @@ def run_with_workspace():
             )
 
             # Handle zero initialization if needed
-            if workspace_arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL:
+            if workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED:
                 workspace_tensor.zero_()
 
             # Run the kernel with workspace
             run_method(
                 *input_tensors,
                 output_tensor,
-                *extra_args,
                 workspace_tensor,
-                grid=self.grid,
+                *extra_args,
                 **warmup_arg,
                 stream=stream,
                 benchmark_run=True,
@@ -729,7 +726,6 @@ def run_with_workspace():
                 *input_tensors,
                 output_tensor,
                 *extra_args,
-                grid=self.grid,
                 **warmup_arg,
                 stream=stream,
             )
@@ -739,7 +735,6 @@ def run_with_workspace():
                 *input_tensors,
                 output_tensor,
                 *extra_args,
-                grid=self.grid,
                 **warmup_arg,
                 stream=stream,
                 benchmark_run=True,
```

torch/_inductor/codecache.py (+2, -2)

```diff
@@ -1364,7 +1364,7 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
 
 @clear_on_fresh_inductor_cache
 class CudaKernelParamCache:
-    cache: dict[str, dict[str, str]] = {}
+    cache: dict[str, dict[str, Any]] = {}
     cache_clear = staticmethod(cache.clear)
 
     @classmethod
@@ -1382,7 +1382,7 @@ def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> Non
         cls.cache[key] = params
 
     @classmethod
-    def get(cls, key: str) -> Optional[dict[str, str]]:
+    def get(cls, key: str) -> Optional[dict[str, Any]]:
         return cls.cache.get(key, None)
 
     @classmethod
```

torch/_inductor/codegen/common.py (+1, -1)

```diff
@@ -1550,7 +1550,7 @@ def cpp_argdefs(self) -> tuple[list[str], list[str], list[str]]:
 
     def python_argdefs(
         self,
-    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[torch.dtype]]:
+    ) -> tuple[list[ArgName], list[str], list[KernelArgType], list[Any]]:
         arg_defs: list[ArgName] = []
         call_args: list[str] = []
         arg_types: list[torch.dtype] = []
```

torch/_inductor/codegen/cpp.py (+1, -1)

```diff
@@ -5209,7 +5209,7 @@ def codegen_group(self, name=None) -> str:
     def call_kernel(self, wrapper, kernel_name):
         _, call_args, arg_types = self.args.cpp_argdefs()
         wrapper.generate_kernel_call(
-            kernel_name, call_args, gpu=False, triton=False, arg_types=arg_types
+            kernel_name, call_args, triton=False, arg_types=arg_types
         )
 
 
```
torch/_inductor/codegen/cpp_template_kernel.py (+1, -3)

```diff
@@ -118,9 +118,7 @@ def hook():
     def call_kernel(self, name: str, node: ir.CppTemplateBuffer):
         wrapper = V.graph.wrapper_code
         _, call_args, arg_types = self.args.cpp_argdefs()
-        wrapper.generate_kernel_call(
-            name, call_args, triton=False, gpu=False, arg_types=arg_types
-        )
+        wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types)
 
     def dtype(self, node: ir.Buffer) -> str:
         return DTYPE_TO_CPP[node.get_dtype()]
```

0 commit comments
