8000 [multigraph] use specializations in compile_and_call_fx_graph by bobrenjc93 · Pull Request #153449 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[multigraph] use specializations in compile_and_call_fx_graph #153449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: gh/bobrenjc93/343/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10464,6 +10464,50 @@ def f(x):
self.assertEqual(out_ref.stride(), out_test.stride())
self.assertEqual(x_ref, x_test)

@requires_gpu()
@skip_if_not_triton
@unittest.skipIf(
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
)
def test_inductor_multiple_specializations(self):
from triton.testing import do_bench

@torch.compile(
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
dynamic=False,
)
def inductor_matmul(a, b):
torch._check(a.shape[0] == b.shape[1])
return (m, torch.mm(a, b))

m = 16
k = 1280
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
torch._dynamo.decorators.mark_dynamic(
dynamic_a,
0,
)
torch._dynamo.decorators.mark_dynamic(
dynamic_specialized_a,
0,
specialize_on=[lambda x0: x0 == 16],
)
torch._dynamo.decorators.mark_dynamic(
b,
1,
)
dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
torch._dynamo.reset()
dynamic_specialized = do_bench(
lambda: inductor_matmul(dynamic_specialized_a, b)
)
self.assertGreaterEqual(dynamic, dynamic_specialized)

@requires_gpu()
def test_stride_preservation_with_stride_modifying_fx_pass(self):
def f(x):
Expand Down
4 changes: 2 additions & 2 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
)
self.config[attr_name] = val

def __call__(self, model_, inputs_):
def __call__(self, model_, inputs_, **kwargs):
from torch._inductor.compile_fx import compile_fx

return compile_fx(model_, inputs_, config_patches=self.config)
return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)

def get_compiler_config(self):
from torch._inductor.compile_fx import get_patched_config_dict
Expand Down
84 changes: 73 additions & 11 deletions torch/_dynamo/output_graph.py
< 9E7A /table>
4 changes: 4 additions & 0 deletions torch/_dynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch import fx, Tensor
from torch._C._dynamo import guards
from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
from torch._guards import (
CompileContext,
Expand All @@ -61,6 +62,7 @@
guard_scalar,
is_symbolic,
ShapeEnv,
Specialization,
)
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch.multiprocessing.reductions import StorageWeakRef
Expand Down Expand Up @@ -157,6 +159,8 @@
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

RootGuardManager = guards.RootGuardManager


@dataclass(frozen=True)
class VariableTrackerCacheKey:
Expand Down Expand Up @@ -1497,7 +1501,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
self.tracing_context.fake_mode = backend_fake_mode

with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm)
compiled_fn = self.call_user_compiler(gm, self.example_inputs())

from torch.fx._lazy_graph_module import _LazyGraphModule

Expand Down Expand Up @@ -1528,7 +1532,60 @@ def compile_and_call_fx_graph(self, tx, rv, root):

counters["stats"]["unique_graphs"] += 1
# This is safe because we pre-process name to be unique
self.install_global_unsafe(name, compiled_fn)
if specializations := old_fake_mode.shape_env.specializations:
specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
sources = [a.source for a in self.graphargs]
for specialization in specializations:
source_index = sources.index(specialization.source)
check_fn_source = inspect.getsource(specialization.check_fn).strip()
check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
specialization.check_fn,
[check_fn_source],
)

log.debug(
"Compiling backend specialized graph with specialization=%s",
check_fn_source,
)

specialization_guards.append(
(
functools.partial(
lambda idx, args, check_fn=check_fn: check_fn(
args[idx]
),
source_index,
),
specialization,
)
)

@torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
def specialized_dispatch(*args, **kwargs):
for check_fn, specialization in specialization_guards:
if check_fn(args):
if specialization in specialization_cache:
return specialization_cache[specialization](
*args, **kwargs
)

with self.shape_env.patch_source_specialization(
specialization.source, specialization.check_fn
):
# Modify gm so AOTAutogradCache key changes per specialization
gm.meta["specialization"] = specialization
example_inputs: list[Tensor] = list(args)
specialization_cache[specialization] = (
self.call_user_compiler(gm, example_inputs)
)

return specialization_cache[specialization](*args, **kwargs)
return compiled_fn(*args, **kwargs)

self.install_global_unsafe(name, specialized_dispatch)
else:
self.install_global_unsafe(name, compiled_fn)

cg = PyCodegen(tx)
cg.make_call_generated_code(name)
Expand All @@ -1542,16 +1599,20 @@ def placeholders(self) -> list[fx.Node]:
def graphargs(self) -> list[GraphArg]:
return [node.meta["grapharg"] for node in self.placeholders]

def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def call_user_compiler(
self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
with dynamo_timed(
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
log_pt2_compile_event=True,
dynamo_compile_column_u 6D4E s="aot_autograd_cumulative_compile_time_us",
):
return self._call_user_compiler(gm)
return self._call_user_compiler(gm, example_inputs)

def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
def _call_user_compiler(
self, gm: fx.GraphModule, example_inputs: list[Tensor]
) -> CompiledFn:
assert self.compiler_fn is not None
tot = 0
placeholders = []
Expand All @@ -1562,10 +1623,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
placeholders.append(node)
increment_op_count(tot)
for pl in placeholders:
arg = pl.meta["grapharg"]
# TODO: Why isn't this stored in meta :think:
# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
pl._dynamo_source = arg.source
if not hasattr(pl, "_dynamo_source"):
arg = pl.meta["grapharg"]
# TODO: Why isn't this stored in meta :think:
# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
pl._dynamo_source = arg.source

# NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
Expand All @@ -1581,7 +1643,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
compiler_fn = self.compiler_fn
if config.verify_correctness:
compiler_fn = WrapperBackend(compiler_fn)
compiled_fn = compiler_fn(gm, self.example_inputs())
compiled_fn = compiler_fn(gm, example_inputs)
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except (TensorifyScalarRestartAnalysis, ShortenTraceback):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3053,6 +3053,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides = []
constraint_sizes = []
constraint_strides = []
specialize_on = []
for i in range(e.dim()):
# NB: mark dynamic has precedence over static
marked_strict_unbacked = i in getattr(
Expand All @@ -3063,6 +3064,8 @@ def update_dim2constraint(dim, constraint_range, name):
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
marked_static = i in getattr(e, "_dynamo_static_indices", set())

specialize_on.append(getattr(e, "_specialize_on", {}).get(i, []))

# Reflect the user directive in the frame_state
# For dynamic, apply None always

Expand Down Expand Up @@ -3182,6 +3185,7 @@ def update_dim2constraint(dim, constraint_range, name):
dynamic_strides=dynamic_strides,
constraint_sizes=constraint_sizes,
constraint_strides=constraint_strides,
specialize_on=specialize_on,
view_base_context=view_base_context,
tensor_source=source,
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache,
Expand Down
8 changes: 5 additions & 3 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ def fakify(
constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload]
else:
dynamic_sizes.append(DimDynamic.STATIC)
symbolic_context = StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
symbolic_context: StatelessSymbolicContext = ( # make mypy happy
StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
)
)
t_id = id(t)
assert mode.shape_env is not None
Expand Down
Loading
Loading
0