[multigraph] use backend specializations in compile_and_call_fx_graph #152601
```diff
@@ -1 +1 @@
-14256e6040d9e14698a877924456cdd92bfcd01d
+8eeef7f5b5363e9f35576184659226cc082311d6
```
@@ -10464,6 +10464,52 @@ def f(x):

```python
        self.assertEqual(out_ref.stride(), out_test.stride())
        self.assertEqual(x_ref, x_test)

    @requires_gpu()
    @skip_if_not_triton
    @unittest.skipIf(
        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
    )
    def test_inductor_multiple_specializations(self):
        from triton.testing import do_bench

        @torch.compile(
            options={
                "max_autotune": True,
                "max_autotune_gemm_backends": "TRITON",
            },
            dynamic=False,
        )
        def inductor_matmul(a, b):
            torch._check(a.shape[0] == b.shape[1])
            return (m, torch.mm(a, b))

        m = 16
        k = 1280
        dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
```
Collaborator: Hi, the function is decorated with …
```python
        dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
        torch._dynamo.decorators.mark_dynamic(
            dynamic_a,
            0,
        )
        torch._dynamo.decorators.mark_dynamic(
            dynamic_specialized_a,
            0,
            backend_specializations=[
                (16, lambda x0: x0 == 16),
            ],
```
Comment on lines +10498 to +10500

Contributor: Do we have an API flow for when you want to specify conditions on multiple vars? E.g. you don't necessarily want to specialize on …

Contributor (Author): Not at the moment. vLLM actually only has one symbolic variable (https://www.anyscale.com/blog/continuous-batching-llm-inference), so we don't need to worry about that for our first customer. That being said, I'm happy to bikeshed what a better multi-var API may look like during composability.
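For illustration, one purely hypothetical shape a multi-variable API could take (nothing below exists in torch today; `mark_dynamic_group`, its arguments, and the tuple-of-symbols check_fn are all invented):

```python
# Hypothetical sketch only: a joint specialization over several dynamic dims.
import torch

def mark_dynamic_group(tensors_and_dims, specializations):
    """Invented helper: mark each (tensor, dim) dynamic and record one
    check_fn that receives one symbol per (tensor, dim) pair."""
    for t, d in tensors_and_dims:
        torch._dynamo.decorators.mark_dynamic(t, d)
    # A real version would hand `specializations` to the backend; this sketch
    # only shows the intended call shape.
    return specializations

a = torch.randn(16, 1280)
b = torch.randn(1280, 16)
mark_dynamic_group(
    [(a, 0), (b, 1)],
    specializations=[
        ((16, 16), lambda s0, s1: s0 == 16 and s1 == 16),
    ],
)
```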
```python
        )
        torch._dynamo.decorators.mark_dynamic(
            b,
            1,
        )
        dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
        torch._dynamo.reset()
        dynamic_specialized = do_bench(
            lambda: inductor_matmul(dynamic_specialized_a, b)
        )
```
Comment on lines +10508 to +10510

Contributor: We should check the output code.
```python
        self.assertGreaterEqual(dynamic, dynamic_specialized)

    @requires_gpu()
    def test_stride_preservation_with_stride_modifying_fx_pass(self):
        def f(x):
```
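Following up on the suggestion above to check the output code: in inductor tests this kind of assertion is usually written with `torch._inductor.utils.run_and_get_code`. A rough sketch, assuming it were dropped into `test_inductor_multiple_specializations` (the string asserted on is an assumption, not taken from this PR):

```python
# Rough sketch only: capture and inspect the generated code for the
# specialized compile, assuming the test's local names are in scope.
from torch._inductor.utils import run_and_get_code

_, codes = run_and_get_code(inductor_matmul, dynamic_specialized_a, b)
# The exact pattern to assert on is an assumption; the point is that the
# generated output code is available for direct inspection.
self.assertTrue(any("triton" in code for code in codes))
```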
```diff
@@ -44,6 +44,7 @@
 import torch.nn
 import torch.utils._pytree as pytree
 from torch import fx
+from torch._C._dynamo import guards
 from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
 from torch._guards import (
     CompileContext,
@@ -157,6 +158,8 @@
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 
+RootGuardManager = guards.RootGuardManager
+
 
 @dataclass(frozen=True)
 class VariableTrackerCacheKey:
```
```diff
@@ -1496,8 +1499,34 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
             self.tracing_context.fake_mode = backend_fake_mode
 
+        specialized_compiles = []
         with self.restore_global_state():
             compiled_fn = self.call_user_compiler(gm)
+            sources = [a.source for a in self.graphargs]
+            for specialization in old_fake_mode.shape_env.backend_specializations:
+                source_index = sources.index(specialization.source)
+                check_fn_source = inspect.getsource(specialization.check_fn).strip()
+                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
+                    specialization.check_fn,
+                    [check_fn_source],
+                )
+
+                log.debug(
+                    "Compiling backend specialized graph with specialization=%s",
+                    check_fn_source,
+                )
+
+                specialized_compiles.append(
+                    (
+                        functools.partial(
+                            lambda idx, args, check_fn=check_fn: check_fn(
+                                args[idx]
+                            ),
+                            source_index,
+                        ),
+                        self.call_user_compiler(gm, specialization=specialization),
```
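A side note on the lambda above: the `check_fn=check_fn` default argument pins the current loop value, since Python closures otherwise bind late and every lambda built in the loop would see the last `check_fn`. A standalone illustration (not PR code):

```python
# Standalone illustration of the late-binding pitfall the default argument avoids.
checks = [lambda v: v == 1, lambda v: v == 2]

late_bound = [lambda x: check(x) for check in checks]
early_bound = [lambda x, check=check: check(x) for check in checks]

print([fn(1) for fn in late_bound])   # [False, False]: both closures see checks[-1]
print([fn(1) for fn in early_bound])  # [True, False]: each keeps its own check
```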
Comment on lines +1506 to +1527

Contributor: This calls the backend compiler with the (tensor_args,) for this graph and the specialization argument, right? I'm not sure this is the right design. The (tensor_args,) don't have the same shape as the specialization -- will that be a problem? An alternative design is that there is some lazy dispatching layer right after Dynamo but before AOTAutograd. Let's say the user calls the following for the first time: … Then this traces out a graph from Dynamo with dynamic shapes. Then, on future calls to torch.compile: … One way to implement this is: …

Contributor: The benefit of the alternative lazy design is that the backend doesn't need to work hard to figure out how to do the specialization: it's almost like calling regular torch.compile again, except it is able to skip Dynamo. One side effect is that we don't have to impose constraints on the strides (this PR needs to do that because it needs to figure out how to create a FakeTensor, right?).

Contributor: I think this makes sense. cc @anijain2305 for thoughts as well.

Contributor: There are a few details that we need to think about … Maybe we have the bytecode that calls the …

Contributor: @anijain2305 thoughts on #153449?

Contributor: What is the issue with the current implementation? It's not bad. It gives the hierarchical feel, which kind of makes sense in this case.

Contributor: Consider the following code:

```python
x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)

x = torch.randn(1)
torch.compile(f)(x)

x = torch.randn(2)
torch.compile(f)(x)
```

On the first torch.compile call, we will attempt to compile all of the backend specializations. That torch.compile call only has one set of sample inputs (of shape [3]). The problems I'm worried about are: … The lazier design (#153449) solves this by (a) deferring compilation of shape [1] and shape [2] until we actually see inputs of those shapes and (b) if the strides change then it's a recompile.
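A rough sketch of the lazy-dispatch idea described in this thread; every name below is invented for illustration, and this is neither this PR's code nor what #153449 implements:

```python
# Invented sketch of a lazy dispatch layer sitting after Dynamo: each
# specialization is compiled only when an input first matches its check_fn.
class LazySpecializationDispatcher:
    def __init__(self, dynamo_gm, generic_compiled, backend_compile, check_fns):
        self.dynamo_gm = dynamo_gm                # graph traced once with dynamic shapes
        self.generic_compiled = generic_compiled  # the dynamic-shape compilation
        self.backend_compile = backend_compile    # e.g. an AOTAutograd/Inductor entry point
        # [check_fn, compiled_or_None]: compiled lazily on the first matching call
        self.entries = [[check_fn, None] for check_fn in check_fns]

    def __call__(self, *args):
        for entry in self.entries:
            check_fn, compiled = entry
            if check_fn(*args):
                if compiled is None:
                    # Real args are available here, so fake tensors for this
                    # specialization can take their actual shapes and strides.
                    entry[1] = compiled = self.backend_compile(self.dynamo_gm, args)
                return compiled(*args)
        return self.generic_compiled(*args)
```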
```diff
+                    )
+                )
+
         from torch.fx._lazy_graph_module import _LazyGraphModule
```
```diff
@@ -1528,7 +1557,18 @@ def compile_and_call_fx_graph(self, tx, rv, root):
 
         counters["stats"]["unique_graphs"] += 1
         # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+        if specialized_compiles:
+
+            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
+            def specialized_dispatch(*args, **kwargs):
+                for check_fn, specialized_compiled_fn in specialized_compiles:
+                    if check_fn(args):
+                        return specialized_compiled_fn(*args, **kwargs)
+                return compiled_fn(*args, **kwargs)
+
+            self.install_global_unsafe(name, specialized_dispatch)
+        else:
+            self.install_global_unsafe(name, compiled_fn)
 
         cg = PyCodegen(tx)
         cg.make_call_generated_code(name)
```
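For readers tracing through `specialized_dispatch`: each stored `check_fn` was built with `functools.partial`, so at runtime it receives the flat argument tuple and indexes out the single graph input it guards. A tiny standalone illustration with a stand-in check (the real code wraps the user check in `guards.LAMBDA_GUARD`):

```python
import functools

import torch

# Stand-in for the wrapped specialization check; here it inspects the chosen
# argument directly rather than going through guards.LAMBDA_GUARD.
def _inner_check(arg):
    return arg.shape[0] == 16

source_index = 0  # position of the guarded input among the flat graph args
check_fn = functools.partial(lambda idx, args: _inner_check(args[idx]), source_index)

args = (torch.randn(16, 8), torch.randn(8, 16))
print(check_fn(args))  # True: the first graph arg has size 16 in dim 0
```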
```diff
@@ -1542,16 +1582,16 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]
 
-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, **kwargs)
 
-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1581,7 +1621,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, self.example_inputs(), **kwargs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
```
Review comment: intentional?