8000 Update on "[multigraph] use specializations in compile_and_call_fx_gr… · pytorch/pytorch@3fcd84b · GitHub
[go: up one dir, main page]

Skip to content

Commit 3fcd84b

Browse files
committed
Update on "[multigraph] use specializations in compile_and_call_fx_graph"
The goal of this multigraph work is to enable a compiled region that has a single dynamo trace but multiple backend specializations. This work was inspired by vLLM who does this in a somewhat hacky way where they use a custom backend to capture a dynamo graph and then manually invoke compile_fx multiple times to get specialized graphs. There's really two parts of this work: **The frontend changes:** 1) we introduce an optional kwarg `specialize_on` to mark_{dynamic,unbacked} that takes in a list of specializations. I debated other methods including specifying specializations via decorators, but ultimately decided this approach was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem including graph breaks, lazy initialization of variable trackers and symbolic variables, etc. **The backend changes (this PR):** 1) We capture the backend_specialization specified in the mark_{dynamic,unbacked} API into a SymbolicContext. See changes in `/_dynamo/variables/builder.py` 2) After we are done dynamo tracing, we invoke `call_user_compiler` N + 1 times for N specializations and 1 generic graph. Under the hood this will call compile_fx, which composes nicely with both Async Compile and AOTAutogradCache. We do this by using a context manager to patch in specialization specific axioms into the ShapeEnv before invoking the user compiler. 3) When we have specializations, we install a lazy specialized dispatch function that checks each specialization and dispatches to the first one that matches. NB: instead of doing all of the specialization compiled up front, we do the compiles lazily. The first time a specialization is invoked, we will do the compilation and save it in a cache so subsequent invocations are fast. If none of the specializations match, we dispatch to the generic graph. I decided to do this over returning N different GuardedCodes since 1) it doesn't pollute the dynamo cache (eg. if you have 8 specializations, you would hit the cache limit) 2) it naturally incorporates the hierarchical lattice structure of the guards since the specializations are always necessarily stricter than the generic region's guards. I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions: ![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73) cc ezyang SherlockNoMad EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
1 parent 88aecdc commit 3fcd84b

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

torch/_dynamo/repro/after_dynamo.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def add_paths(exc):
110110
# Check for either accuracy (level 4) or other type of failures.
111111
if config.repro_level == 4:
112112
# Check Accuracy
113-
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs, **kwargs)
113+
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
114114
if _accuracy_fails(gm, example_inputs, compiler_fn):
115115
log.warning(
116116
"Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
@@ -125,9 +125,7 @@ def add_paths(exc):
125125
raise exc
126126
else:
127127
try:
128-
compiled_gm = compiler_fn(
129-
copy.deepcopy(gm), example_inputs, **kwargs
130-
)
128+
compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
131129
run_fwd_maybe_bwd(compiled_gm, example_inputs)
132130
except Exception as exc:
133131
log.warning(
@@ -149,7 +147,7 @@ def add_paths(exc):
149147
add_paths(exc)
150148
raise
151149
else:
152-
compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
150+
compiled_gm = compiler_fn(gm, example_inputs)
153151

154152
return compiled_gm
155153

torch/_export/non_strict_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,11 @@ def fakify(
143143
constraint_sizes[i] = RelaxedUnspecConstraint(warn_only=False) # type: ignore[call-overload]
144144
else:
145145
dynamic_sizes.append(DimDynamic.STATIC)
146-
symbolic_context = StatelessSymbolicContext(
147-
dynamic_sizes=dynamic_sizes,
148-
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
146+
symbolic_context: StatelessSymbolicContext = ( # make mypy happy
147+
StatelessSymbolicContext(
148+
dynamic_sizes=dynamic_sizes,
149+
constraint_sizes=constraint_sizes, # type: ignore[arg-type]
150+
)
149151
)
150152
t_id = id(t)
151153
assert mode.shape_env is not None

0 commit comments

Comments
 (0)
0