thread through specialization to compile_fx by bobrenjc93 · Pull Request #152650 · pytorch/pytorch · GitHub

thread through specialization to compile_fx #152650


Closed · wants to merge 3 commits
Changes from all commits
4 changes: 2 additions & 2 deletions torch/__init__.py
@@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
)
self.config[attr_name] = val

-def __call__(self, model_, inputs_):
+def __call__(self, model_, inputs_, **kwargs):
from torch._inductor.compile_fx import compile_fx

-return compile_fx(model_, inputs_, config_patches=self.config)
+return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)

def get_compiler_config(self):
from torch._inductor.compile_fx import get_patched_config_dict
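Note: the `torch/__init__.py` hunk above only forwards extra keyword arguments from the inductor backend wrapper into `compile_fx`, so a new option such as a specialization list can reach the compiler without touching every intermediate signature. Below is a minimal, hedged sketch of that forwarding pattern; `BackendWrapper` and `my_compile_fx` are illustrative stand-ins, not PyTorch APIs.

```python
# Sketch only: a backend-wrapper object that forwards arbitrary kwargs
# to the underlying compile entry point (stand-in for compile_fx).
def my_compile_fx(model, inputs, config_patches=None, **kwargs):
    print("config:", config_patches, "extra kwargs:", kwargs)
    return model  # a real backend would return a compiled callable


class BackendWrapper:
    def __init__(self, config=None):
        self.config = dict(config or {})

    def __call__(self, model_, inputs_, **kwargs):
        # Any caller-supplied kwargs (e.g. specializations=[...]) pass straight through.
        return my_compile_fx(model_, inputs_, config_patches=self.config, **kwargs)


compiled = BackendWrapper({"some_option": True})(object(), [], specializations=[])
```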
13 changes: 6 additions & 7 deletions torch/_dynamo/output_graph.py
@@ -1498,9 +1498,8 @@ def compile_and_call_fx_graph(self, tx, rv, root):

compiled_fns = []
with self.restore_global_state():
-for specialization in backend_specializations:
-    modified_gm = specialize(gm, specialization)
-    compiled_fns.append(self.call_user_compiler(modified_gm))
+for specialization in old_fake_mode.shape_env.backend_specializations:
+    compiled_fns.append(self.call_user_compiler(gm, specialization=specialization))

from torch.fx._lazy_graph_module import _LazyGraphModule

@@ -1545,16 +1544,16 @@ def placeholders(self) -> list[fx.Node]:
def graphargs(self) -> list[GraphArg]:
return [node.meta["grapharg"] for node in self.placeholders]

-def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+def call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
with dynamo_timed(
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
):
-return self._call_user_compiler(gm)
+return self._call_user_compiler(gm, **kwargs)

-def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+def _call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
assert self.compiler_fn is not None
tot = 0
placeholders = []
@@ -1584,7 +1583,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
compiler_fn = self.compiler_fn
if config.verify_correctness:
compiler_fn = WrapperBackend(compiler_fn)
-compiled_fn = compiler_fn(gm, self.example_inputs())
+compiled_fn = compiler_fn(gm, self.example_inputs(), **kwargs)
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except (TensorifyScalarRestartAnalysis, ShortenTraceback):
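The `output_graph.py` change compiles one artifact per entry in `shape_env.backend_specializations`, handing each specialization to the user compiler instead of pre-specializing the graph module. Here is a self-contained sketch of that one-artifact-per-specialization idea; `Spec`, `compile_many`, and `select` are illustrative names, not Dynamo APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Spec:
    idx: int   # which input dimension is being specialized
    hint: int  # the concrete size the backend should assume


def compile_many(gm: Any, example_inputs: list, specs: list[Spec],
                 user_compiler: Callable) -> dict:
    """Compile a generic artifact plus one artifact per specialization."""
    artifacts = {None: user_compiler(gm, example_inputs)}
    for spec in specs:
        # The PR threads `specialization` down so fakification/inductor see
        # the hint; here we only tag the compile call with it.
        artifacts[spec] = user_compiler(gm, example_inputs, specialization=spec)
    return artifacts


def select(artifacts: dict, specs: list[Spec], runtime_sizes: tuple) -> Callable:
    """At call time, prefer a specialized artifact whose hint matches the real size."""
    for spec in specs:
        if runtime_sizes[spec.idx] == spec.hint:
            return artifacts[spec]
    return artifacts[None]
```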
5 changes: 4 additions & 1 deletion torch/_functorch/aot_autograd.py
@@ -489,6 +489,7 @@ def process_inputs(
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
ignore_shape_env: bool = False,
+specialization=None,
) -> FakifiedFlatArgs:
with fake_mode:

@@ -547,6 +548,7 @@ def convert(idx, x):
symbolic_context=symbolic_context,
source=source,
trace=trace,
+specialization=specialization,
)
return result

@@ -1083,6 +1085,7 @@ def aot_module_simplified(
cudagraphs: Optional[BoxedBool] = None,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
ignore_shape_env: bool = False,
+specializations=None,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
@@ -1154,7 +1157,7 @@ def aot_module_simplified(
)
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
fake_flat_args = process_inputs(
-full_args, aot_config, fake_mode, shape_env, ignore_shape_env
+full_args, aot_config, fake_mode, shape_env, ignore_shape_env, specializations=specializations
)

def dispatch_and_compile():
4 changes: 4 additions & 0 deletions torch/_subclasses/fake_tensor.py
@@ -354,6 +354,7 @@ def from_real_tensor(
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
trace: bool = True,
+specialization=None,
) -> FakeTensor:
# see note [Tensor Fakification and Symbol Caching]
if not symbolic_context and not source and shape_env:
@@ -408,6 +409,7 @@ def mk_fake_tensor(
source=source,
symbolic_context=symbolic_context,
trace=trace,
+specialization=specialization,
)
if out is NotImplemented:
raise UnsupportedFakeTensorException("meta converter nyi")
@@ -2864,6 +2866,7 @@ def from_tensor(
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
trace: bool = True,
+specialization=None,
) -> FakeTensor:
shape_env: Optional[ShapeEnv] = self.shape_env
if static_shapes is None:
@@ -2880,6 +2883,7 @@ def from_tensor(
source=source,
symbolic_context=symbolic_context,
trace=trace,
+specialization=specialization,
)


15 changes: 13 additions & 2 deletions torch/_subclasses/meta_utils.py
@@ -276,8 +276,15 @@ def describe_storage(
return r

def describe_tensor(
-self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
+self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False, specialization=None
) -> MetaTensorDesc:
+if specialization:
+    t = t.to('meta')
+    shape = list(t.shape)
+    for i, hint in zip(specialization.idxs, specialization.hints):
+        shape[i] = hint
+    t = torch.ones(shape, dtype=t.dtype, device='meta')

is_leaf = safe_is_leaf(t)
is_view = t._is_view()
is_sparse = t.is_sparse
@@ -875,6 +882,7 @@ def meta_tensor(
callback_: _MetaTensorCallback[_TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
+specialization=None,
) -> _TensorT:
callback: _MetaTensorCallbackOptDevice = functools.partial(
callback_, device=t.device
@@ -958,6 +966,7 @@ def sym_sizes_strides_storage_offset(
[d in t.dynamo_dynamic_indices for d in range(t.ndim)],
src,
symbolic_context=symbolic_context,
+specialization=specialization,
)
else:
return (t.size, t.stride, t.storage_offset)
@@ -1844,6 +1853,7 @@ def __call__(
# when source is not None. Because we refakify after Dynamo is done,
# we don't want to dump info again from AOTAutograd, it is redundant.
trace: bool = True,
+specialization=None,
) -> _TensorT:
callback_: _MetaTensorCallback[_TensorT]
if callback is None:
@@ -1886,7 +1896,7 @@ def __call__(

# Describe the tensor. NB: do NOT disable ambient modes, we may need
# to query them when figuring out what to put in here
-t_desc = self.describer.describe_tensor(t, trace=trace)
+t_desc = self.describer.describe_tensor(t, trace=trace, specialization=specialization)

if trace:
assert source is not None
@@ -1916,6 +1926,7 @@ def __call__(
callback_,
source,
symbolic_context,
+specialization=specialization,
)

if type(t) is torch.nn.Parameter:
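In `describe_tensor`, a specialization pins the chosen dimensions to their integer hints before the tensor is described, by materializing a fresh meta tensor of the pinned shape. A hedged, standalone sketch of that substitution follows; the `Specialization` dataclass and `pin_dims_to_hints` helper below are illustrative and only model the `idxs`/`hints` fields the hunk relies on.

```python
from dataclasses import dataclass

import torch


@dataclass
class Specialization:
    idxs: tuple[int, ...]   # dimensions to pin
    hints: tuple[int, ...]  # concrete sizes to pin them to


def pin_dims_to_hints(t: torch.Tensor, spec: Specialization) -> torch.Tensor:
    # Mirrors the describe_tensor hunk: copy the shape, overwrite the
    # specialized dimensions with their hints, and build a meta tensor so
    # the rest of the describer sees the pinned sizes.
    shape = list(t.shape)
    for i, hint in zip(spec.idxs, spec.hints):
        shape[i] = hint
    return torch.ones(shape, dtype=t.dtype, device="meta")


x = torch.randn(4, 16)
pinned = pin_dims_to_hints(x, Specialization(idxs=(0,), hints=(1024,)))
assert pinned.is_meta and tuple(pinned.shape) == (1024, 16)
```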
19 changes: 19 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
@@ -931,6 +931,12 @@ def find_symbol_binding_fx_nodes(
return r


+@dataclass
+class BackendSpecialization:
+    symbol: sympy.Symbol
+    hint: int
+    specialization: Callable

# Analogous to ConvertIntSource
@dataclass(frozen=True)
class ConvertIntKey:
@@ -3561,6 +3567,8 @@ def _init(

self.trace_asserts = trace_asserts

+self.backend_specializations = []

from torch.fx.experimental.validator import translation_validation_enabled

self._translation_validation_enabled = translation_validation_enabled()
@@ -4044,6 +4052,11 @@ def _produce_dyn_sizes_from_int_tuple(
do_not_specialize_zero_one=config.backed_size_oblivious,
symbolic_context=symbolic_context,
)
+for specialization in symbolic_context.backend_specializations:
+    self.backend_specializations.append(BackendSpecialization(
+        sym,
+        *specialization,
+    ))
if (
config.backed_size_oblivious
and isinstance(sym, sympy.Symbol) # could be static
@@ -4142,6 +4155,7 @@ def _create_symbolic_sizes_strides_storage_offset(
source: Source,
*,
symbolic_context: Optional[SymbolicContext] = None,
+specialization=None,
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
dim = len(ex_size)

@@ -4220,6 +4234,11 @@ def _create_symbolic_sizes_strides_storage_offset(
)
for i, (sym, hint) in enumerate(zip(size, ex_size))
]

+for i, size in enumerate(sym_sizes):
+    if specialization is not None and i in specialization.idxs:
+        expect_true(specialization.lambdas[i](size))

sym_stride = []
for i, stride_expr in enumerate(stride):
# NB: Don't duck size the stride; instead use the expression
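Finally, `symbolic_shapes.py` introduces the `BackendSpecialization` record (a size symbol, its integer hint, and a specialization callable) that the `ShapeEnv` accumulates and later guards with `expect_true` when symbolic sizes are created. Below is a minimal sketch of building and checking such a record with plain sympy; `record_specialization` and `check_specializations` are illustrative helpers, not part of the PR.

```python
from dataclasses import dataclass
from typing import Callable

import sympy


@dataclass
class BackendSpecialization:
    symbol: sympy.Symbol
    hint: int
    specialization: Callable  # predicate over a size expression


def record_specialization(store: list, symbol: sympy.Symbol, hint: int,
                          predicate: Callable) -> None:
    # Analogue of ShapeEnv.backend_specializations.append(BackendSpecialization(...)).
    store.append(BackendSpecialization(symbol, hint, predicate))


def check_specializations(store: list, sizes: dict) -> bool:
    # Analogue of the expect_true(...) guard: every recorded predicate must hold
    # for the sizes the specialized artifact will be compiled against.
    return all(bool(spec.specialization(sizes[spec.symbol])) for spec in store)


specs: list[BackendSpecialization] = []
s0 = sympy.Symbol("s0", positive=True, integer=True)
record_specialization(specs, s0, 1024, lambda s: sympy.Eq(s, 1024))
assert check_specializations(specs, {s0: 1024})
```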