[multigraph] use specializations in compile_and_call_fx_graph #153449
Changes from all commits: b91ff71, 30d47f1, 14e69f8, 88aecdc, 3fcd84b, 214277f, bf1809a, fa363b2, 96c5fae, bf7c26a, 51c25ee, 0dd757f, d83b329, b0cef62, 312d49d, d0eba02, 8082547, be5ef29
@@ -43,7 +43,8 @@
 import torch.distributed as dist
 import torch.nn
 import torch.utils._pytree as pytree
-from torch import fx
+from torch import fx, Tensor
+from torch._C._dynamo import guards
 from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
 from torch._guards import (
     CompileContext,
@@ -61,6 +62,7 @@
     guard_scalar,
     is_symbolic,
     ShapeEnv,
+    Specialization,
 )
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -158,6 +160,8 @@
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")

+RootGuardManager = guards.RootGuardManager
+

 @dataclass(frozen=True)
 class VariableTrackerCacheKey:
@@ -1675,7 +1679,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             self.tracing_context.fake_mode = backend_fake_mode

         with self.restore_global_state():
-            compiled_fn = self.call_user_compiler(gm)
+            compiled_fn = self.call_user_compiler(gm, self.example_inputs())

         from torch.fx._lazy_graph_module import _LazyGraphModule

@@ -1705,8 +1709,62 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             )

         counters["stats"]["unique_graphs"] += 1
-        # This is safe because we pre-process name to be unique
-        self.install_global_unsafe(name, compiled_fn)
+        if specializations := old_fake_mode.shape_env.specializations:
+            specialization_guards = []
+            specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
+            sources = [a.source for a in self.graphargs]
+            for specialization in specializations:
+                source_index = sources.index(specialization.source)
+                check_fn_source = inspect.getsource(specialization.check_fn).strip()
+                check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
+                    specialization.check_fn,
+                    [check_fn_source],
+                )
+
+                log.debug(
+                    "Compiling backend specialized graph with specialization=%s",
+                    check_fn_source,
+                )
+
+                specialization_guards.append(
+                    (
+                        functools.partial(
+                            lambda idx, args, check_fn=check_fn: check_fn(
+                                args[idx]
+                            ),
+                            source_index,
+                        ),
+                        specialization,
+                    )
+                )
+
+            @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph")
+            def specialized_dispatch(*args, **kwargs):
+                for check_fn, specialization in specialization_guards:
+                    if check_fn(args):
+                        if specialization in specialization_cache:
+                            return specialization_cache[specialization](
+                                *args, **kwargs
+                            )
+
+                        with self.shape_env.patch_source_specialization(
+                            specialization.source, specialization.check_fn
+                        ):
+                            # Modify gm so AOTAutogradCache key changes per specialization
+                            gm.meta["specialization"] = specialization
+                            example_inputs: list[Tensor] = list(args)
+                            specialization_cache[specialization] = (
+                                self.call_user_compiler(gm, example_inputs)
+                            )
+
+                        return specialization_cache[specialization](*args, **kwargs)
+                return compiled_fn(*args, **kwargs)
+
+            # This is safe because we pre-process name to be unique
+            self.install_global_unsafe(name, specialized_dispatch)
+        else:
+            # This is safe because we pre-process name to be unique
+            self.install_global_unsafe(name, compiled_fn)

         assert self.root_tx is not None
         cg = PyCodegen(self.root_tx)

Review thread on the `shape_env.patch_source_specialization` call (marked resolved by bobrenjc93):

- Is the shape_env shared between Specializations? (It looks like yes.) If so, will this cause issues? For example, if we have s0 == s1 and a specialization that says s0 == 2, does this cause s1 to always be 2 in subsequent Specializations?
- Is it possible to get into a situation where: […] My naive take is that each Specialization would want its own ShapeEnv.
- It's hard for me to imagine a case where we will add additional guards/replacements on a subsequent, more specialized graph than the first generic graph. The tricky thing with ShapeEnv is that it's not serializable, because guards aren't serializable, so it's unclear how we could maintain multiple parallel ShapeEnvs. I just updated the PR to freeze the ShapeEnv after patching in the axioms/replacements, so in the worst case we will fail the compile instead of silent incorrectness.
- Offline discussion says: this may be a problem; we probably do want to clone the ShapeEnv. We will figure out how to resolve this in a followup.
- Yea, I was imagining separate shape envs as well. Maybe @ezyang can advise.
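To make the new control flow easier to follow, here is a minimal, self-contained sketch of the dispatch pattern the hunk above installs; the names (`SpecializationDispatcher`, `guard`, `compile_fn`) are illustrative stand-ins rather than Dynamo APIs. Each call checks the specialization guards against the runtime arguments, lazily compiles and caches a specialized callable on the first hit, and otherwise falls back to the generic `compiled_fn`.

```python
from typing import Any, Callable

import torch


class SpecializationDispatcher:
    """Illustrative stand-in for the `specialized_dispatch` closure above."""

    def __init__(
        self,
        generic_fn: Callable[..., Any],
        specializations: list[tuple[Callable[[tuple], bool], Callable[..., Callable[..., Any]]]],
    ) -> None:
        self.generic_fn = generic_fn
        # Each entry pairs a guard over the runtime args with a compile
        # function that produces a specialized callable on first use.
        self.specializations = specializations
        self.cache: dict[int, Callable[..., Any]] = {}

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        for idx, (guard, compile_fn) in enumerate(self.specializations):
            if guard(args):
                if idx not in self.cache:
                    # Lazily compile and memoize, mirroring specialization_cache.
                    self.cache[idx] = compile_fn(*args)
                return self.cache[idx](*args, **kwargs)
        # No specialization guard matched: use the generic compilation.
        return self.generic_fn(*args, **kwargs)


# Toy usage: pretend inputs whose second dim is 8 deserve a specialized kernel.
generic = lambda x: x * 2
specialized = (
    lambda args: args[0].shape[1] == 8,  # guard over the runtime args
    lambda x: (lambda y: y.mul(2)),      # "compile" the specialized callable
)
dispatch = SpecializationDispatcher(generic, [specialized])
dispatch(torch.randn(4, 8))  # first hit compiles, later hits reuse the cache
dispatch(torch.randn(4, 5))  # falls through to the generic path
```

As the review thread notes, the real implementation additionally re-enters `call_user_compiler` under `patch_source_specialization` so the backend sees the specialized shape facts; this sketch only models the guard-check/cache/fallback structure.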
@@ -1721,7 +1779,9 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]

-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
@@ -1730,9 +1790,11 @@ def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             waitcounter_name_override="compile_aot_autograd",
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, example_inputs)

-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(
+        self, gm: fx.GraphModule, example_inputs: list[Tensor]
+    ) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1743,10 +1805,11 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
                 placeholders.append(node)
         increment_op_count(tot)
         for pl in placeholders:
-            arg = pl.meta["grapharg"]
-            # TODO: Why isn't this stored in meta :think:
-            # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
-            pl._dynamo_source = arg.source
+            if not hasattr(pl, "_dynamo_source"):
+                arg = pl.meta["grapharg"]
+                # TODO: Why isn't this stored in meta :think:
+                # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
+                pl._dynamo_source = arg.source

         # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
         gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
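The new `hasattr` guard matters because, with the dispatch above, `_call_user_compiler` can now run more than once for the same GraphModule (once for the generic graph and again per specialization), so the placeholder annotation should only be written on the first pass. A tiny sketch of that write-once pattern, using hypothetical names:

```python
class Placeholder:
    """Hypothetical stand-in for an fx placeholder node."""


def attach_source_once(pl: Placeholder, source: str) -> None:
    # Annotate only on the first compilation; per-specialization recompiles
    # of the same graph leave the existing annotation untouched.
    if not hasattr(pl, "_dynamo_source"):
        pl._dynamo_source = source


pl = Placeholder()
attach_source_once(pl, "L['x']")
attach_source_once(pl, "ignored")  # second call is a no-op
assert pl._dynamo_source == "L['x']"
```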
@@ -1762,7 +1825,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, example_inputs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
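For context on the final change, custom `torch.compile` backends already use the `(gm, example_inputs)` calling convention, so passing `example_inputs` through means a backend now receives whichever inputs triggered the (possibly specialized) compile instead of always `self.example_inputs()`. A minimal sketch of a backend observing those inputs (the backend name and print are illustrative):

```python
from typing import Callable

import torch


def inspect_backend(
    gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> Callable:
    # Backends receive the traced graph plus the example inputs Dynamo passed;
    # after this PR, specialized recompiles forward the concrete runtime args.
    print(f"compiling a graph with {len(example_inputs)} example inputs")
    return gm.forward  # just run the captured graph eagerly


@torch.compile(backend=inspect_backend)
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


f(torch.randn(8))
```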