[Inductor] Subgraph support dynamic input expressions by PaulZhang12 · Pull Request #153754 · pytorch/pytorch
Open · PaulZhang12 wants to merge 4 commits into base: gh/PaulZhang12/15/base
Changes from all commits
49 changes: 49 additions & 0 deletions test/inductor/test_max_autotune.py
@@ -7,6 +7,7 @@
import tempfile
import unittest
from typing import Callable, Optional
from unittest import mock

import torch
from torch import multiprocessing as mp, nn
@@ -993,6 +994,54 @@ def check_divisors(code):
bf16_red_setting
)

@skipIfXpu
@unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
@unittest.skipIf(
config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
)
@config.patch(
max_autotune=True,
max_autotune_gemm_backends="TRITON",
autotune_fallback_to_aten=False,
)
def test_max_autotune_decompose_k_dynamic_input(self):
def f(a, b):
a_in = torch.stack((a, a), dim=0)
return (a_in @ b).relu()

a = torch.randn(
32, 32768, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
b = torch.randn(
32768, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

torch._dynamo.reset()
torch._dynamo.maybe_mark_dynamic(a, 0)
compiled_func = torch.compile(f)

with mock.patch(
"torch._inductor.kernel.mm.use_decompose_k_choice"
) as decomp_mock:
decomp_mock.return_value = True

out, code = run_and_get_code(compiled_func, a, b)
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
"triton_.*_fused_0.run"
).check("decompose_k").check_regex("s[0-9]+ = primals_1").check_regex(
"2*s[0-9]+"
).check(
"primals_1 = 32"
).run(
code[0]
)
torch.testing.assert_close(
out,
f(a, b),
atol=1e-2,
rtol=1e-2,
)


class TestMaxAutotunePrecompile(TestCase):
def test_precompilation_threads(self):
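For context, the dynamic-shape pattern the new test exercises can be reduced to a small standalone sketch (not part of the PR; the shapes here are arbitrary). Marking a dimension with torch._dynamo.maybe_mark_dynamic makes Dynamo trace it as a size symbol such as s0, which is exactly what the decompose_k subgraph must now accept as an input expression:

# Standalone sketch: mark dim 0 of `a` as dynamic before compiling.
import torch

def f(a, b):
    a_in = torch.stack((a, a), dim=0)  # (2, M, K) batch; M is symbolic
    return (a_in @ b).relu()

a = torch.randn(32, 64)
b = torch.randn(64, 16)

torch._dynamo.maybe_mark_dynamic(a, 0)  # dim 0 becomes a SymInt (e.g. s0)
compiled = torch.compile(f)
out = compiled(a, b)                     # traced once with a symbolic dim 0
out = compiled(torch.randn(48, 64), b)   # typically reuses the same graph

The FileCheck assertions above then verify that the generated wrapper code threads the size symbol (s[0-9]+) and derived expressions such as 2*s[0-9]+ into the decompose_k subgraph, rather than baking in the concrete size 32.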
35 changes: 25 additions & 10 deletions torch/_inductor/codegen/subgraph.py
@@ -5,7 +5,12 @@
import torch
from torch._inductor import ir
from torch._inductor.codegen.common import KernelTemplate
-from torch._inductor.ir import Buffer, ir_node_to_tensor, Layout
+from torch._inductor.ir import (
+    add_symbolic_shapes_for_inputs_to_subgraph,
+    Buffer,
+    ir_node_to_tensor,
+    Layout,
+)
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.virtualized import V

@@ -59,12 +64,23 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
name=f"benchmark_{self.name}",
)

-        for ar, example_inp in zip(args, self.example_inputs):
-            # Sanity check that args are same layout as example inputs
-            if isinstance(ar, torch.Tensor):
-                assert isinstance(example_inp, torch.Tensor)
-                assert ar.shape == example_inp.shape
-                assert ar.stride() == example_inp.stride()
+        sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph(
+            self.example_inputs, bm_graph_lowering
+        )
+        sym_inputs = [
+            int(V.graph.sizevars.shape_env.size_hint(sym_var)) for sym_var in sym_inputs
+        ]
+
+        if len(sym_inputs) == 0:
+            # Only do it if there are no symbolic inputs, otherwise
+            # the dynamic dim will be realized to the same size as args
+            for ar, example_inp in zip(args, self.example_inputs):
+                # Sanity check that args are same layout as example inputs
+                if isinstance(ar, torch.Tensor):
+                    assert isinstance(example_inp, torch.Tensor)
+                    assert ar.shape == example_inp.shape
+                    assert ar.stride() == example_inp.stride()

with V.set_graph_handler(bm_graph_lowering):
# Don't bother autotuning on Triton here
@@ -73,13 +89,12 @@
max_autotune_gemm=False,
max_autotune_gemm_backends="ATEN",
):
-                bm_graph_lowering.run(*args)
+                bm_graph_lowering.run(*self.example_inputs)
mod = bm_graph_lowering.compile_to_module()
bm_func = mod.call

-                bm_func([*args])
-
-                return benchmarker.benchmark_gpu(lambda: bm_func([*args]))
+                bm_func([*sym_inputs, *args])
+                return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))

def hash_key(self) -> str:
return "-".join(
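The benchmarking change above means the compiled benchmark module now expects the resolved size symbols ahead of the tensor arguments. A toy illustration of that calling convention (assumed names, not PyTorch internals; the real code resolves hints via V.graph.sizevars.shape_env.size_hint):

import sympy

# Assume trace time recorded the hint s0 = 32 for the dynamic dimension.
s0 = sympy.Symbol("s0")
hints = {s0: 32}

def size_hint(expr):
    # Mirrors shape_env.size_hint: evaluate a symbolic expression to an int.
    return int(expr.subs(hints))

sym_inputs = [size_hint(s) for s in [s0]]  # [32]

def bm_func(args):
    # The compiled module's call(): size ints come first, then real tensors.
    s0_val, a, b = args
    return s0_val

bm_func([*sym_inputs, "a_tensor", "b_tensor"])  # mirrors bm_func([*sym_inputs, *args])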
35 changes: 33 additions & 2 deletions torch/_inductor/ir.py
@@ -490,6 +490,32 @@ def try_match_insignificant_strides(
return TensorBox(ReinterpretView(data=storage, layout=new_layout))


def add_symbolic_shapes_for_inputs_to_subgraph(
inputs: list[Any], subgraph: GraphLowering
) -> list[Expr]:
sym_vars: OrderedSet[Expr] = OrderedSet()
for inp in inputs:
if isinstance(inp, torch.Tensor):
for size in inp.size():
if isinstance(size, SymTypes):
sym_vars |= size.node.expr.free_symbols
for stride in inp.stride():
if isinstance(stride, SymTypes):
sym_vars |= stride.node.expr.free_symbols

sym_inputs = []
for sym_var in sym_vars:
assert sym_var in V.graph.graph_inputs.values()

for inp in V.graph.graph_inputs:
if V.graph.graph_inputs[inp] == sym_var:
subgraph.graph_inputs[inp] = sym_var
subgraph.graph_input_names.append(inp)
sym_inputs.append(sym_var)

return sym_inputs


class IRNode:
_current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()

@@ -6091,6 +6117,12 @@ def __init__(
self.subgraph = V.graph.make_subgraph(
self.gm, self.example_inputs, subgraph_name
)

sym_inputs = add_symbolic_shapes_for_inputs_to_subgraph(
self.example_inputs, self.subgraph
)
self.sym_inputs = [sym_var.name for sym_var in sym_inputs]

import torch._inductor.config as inductor_config

with V.set_graph_handler(self.subgraph):
@@ -6109,10 +6141,9 @@ def __init__(self, graph: GraphLowering):
self.name = graph.name

        outer_inputs = [t.codegen_reference() for t in self.inputs]
-
        wrapper.codegen_subgraph_with_flattened_outputs(
            CodegenGraph(self.subgraph),
-            outer_inputs,
+            [*self.sym_inputs, *outer_inputs],
            [self.name],
        )

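A simplified sketch of what the new add_symbolic_shapes_for_inputs_to_subgraph helper collects (an assumption-level rendering, not the exact internals): every free sympy symbol behind a symbolic size or stride of the subgraph's tensor inputs, so those symbols can be registered as extra graph inputs:

import torch

def free_shape_symbols(t: torch.Tensor) -> set:
    """Collect free sympy symbols from symbolic sizes/strides of `t`."""
    syms = set()
    for s in (*t.size(), *t.stride()):
        if isinstance(s, torch.SymInt):          # concrete ints are skipped
            syms |= s.node.expr.free_symbols     # e.g. {s0} for a size of 2*s0
    return syms

# A plain tensor has only concrete int sizes, so nothing is collected:
assert free_shape_symbols(torch.randn(4, 4)) == set()

The real helper additionally asserts that each symbol is already an input of the outer graph and mirrors it into subgraph.graph_inputs, which is why the wrapper codegen above can prepend self.sym_inputs to outer_inputs.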