thread through specialization to compile_fx · pytorch/pytorch@e6ecccb · GitHub

Commit e6ecccb

thread through specialization to compile_fx
ghstack-source-id: 5577884 Pull Request resolved: #152650
1 parent 9914ca3 commit e6ecccb

File tree: 5 files changed, +25 −11 lines

torch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2359,10 +2359,10 @@ def apply_options(self, options: _Optional[dict[str, _Any]]):
                 )
             self.config[attr_name] = val
 
-    def __call__(self, model_, inputs_):
+    def __call__(self, model_, inputs_, **kwargs):
         from torch._inductor.compile_fx import compile_fx
 
-        return compile_fx(model_, inputs_, config_patches=self.config)
+        return compile_fx(model_, inputs_, config_patches=self.config, **kwargs)
 
     def get_compiler_config(self):
         from torch._inductor.compile_fx import get_patched_config_dict
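
A minimal sketch of the pattern this hunk introduces in the inductor backend wrapper in torch/__init__.py (the class and helper below are illustrative stand-ins, not the real torch APIs): __call__ now accepts **kwargs and forwards them, so a `specialization` keyword passed by Dynamo reaches compile_fx without the wrapper needing to know about it.

def fake_compile_fx(model_, inputs_, config_patches=None, **kwargs):
    # Stand-in for torch._inductor.compile_fx.compile_fx.
    return {"config_patches": config_patches, **kwargs}


class InductorWrapperSketch:
    def __init__(self, config):
        self.config = config

    def __call__(self, model_, inputs_, **kwargs):
        # Forward whatever extra keywords the caller supplied.
        return fake_compile_fx(model_, inputs_, config_patches=self.config, **kwargs)


wrapper = InductorWrapperSketch({"max_autotune": True})
print(wrapper(object(), [], specialization="spec"))
# -> {'config_patches': {'max_autotune': True}, 'specialization': 'spec'}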

torch/_dynamo/output_graph.py

Lines changed: 5 additions & 6 deletions
@@ -1499,8 +1499,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
             compiled_fns = []
             with self.restore_global_state():
                 for specialization in backend_specializations:
-                    modified_gm = specialize(gm, specialization)
-                    compiled_fns.append(self.call_user_compiler(modified_gm))
+                    compiled_fns.append(self.call_user_compiler(modified_gm, specialization))
 
             from torch.fx._lazy_graph_module import _LazyGraphModule
 
@@ -1545,16 +1544,16 @@ def placeholders(self) -> list[fx.Node]:
     def graphargs(self) -> list[GraphArg]:
         return [node.meta["grapharg"] for node in self.placeholders]
 
-    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         with dynamo_timed(
             "OutputGraph.call_user_compiler",
             phase_name="backend_compile",
             log_pt2_compile_event=True,
             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
         ):
-            return self._call_user_compiler(gm)
+            return self._call_user_compiler(gm, **kwargs)
 
-    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+    def _call_user_compiler(self, gm: fx.GraphModule, **kwargs) -> CompiledFn:
         assert self.compiler_fn is not None
         tot = 0
         placeholders = []
@@ -1584,7 +1583,7 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
             compiler_fn = self.compiler_fn
             if config.verify_correctness:
                 compiler_fn = WrapperBackend(compiler_fn)
-            compiled_fn = compiler_fn(gm, self.example_inputs())
+            compiled_fn = compiler_fn(gm, self.example_inputs(), **kwargs)
             _step_logger()(logging.INFO, f"done compiler function {name}")
             assert callable(compiled_fn), "compiler_fn did not return callable"
         except (TensorifyScalarRestartAnalysis, ShortenTraceback):
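
For context, a hedged sketch of the Dynamo-side plumbing these hunks set up (class and backend names below are illustrative, not the real OutputGraph API): call_user_compiler now accepts **kwargs and threads them through _call_user_compiler into the user backend, so the specialization loop in compile_and_call_fx_graph can hand each specialization to the compiler through the same entry point. The sketch passes the specialization as a keyword, since the new signature only collects extras via **kwargs.

from typing import Any, Callable


class OutputGraphSketch:
    def __init__(self, compiler_fn: Callable[..., Any], example_inputs: list[Any]):
        self.compiler_fn = compiler_fn
        self._example_inputs = example_inputs

    def example_inputs(self) -> list[Any]:
        return list(self._example_inputs)

    def call_user_compiler(self, gm: Any, **kwargs: Any) -> Any:
        # The real method also wraps this call in dynamo_timed bookkeeping.
        return self._call_user_compiler(gm, **kwargs)

    def _call_user_compiler(self, gm: Any, **kwargs: Any) -> Any:
        # Extra keywords (e.g. specialization=...) reach the backend unchanged.
        return self.compiler_fn(gm, self.example_inputs(), **kwargs)


def toy_backend(gm, example_inputs, specialization=None):
    # Stand-in for a real backend such as the inductor wrapper above.
    return lambda: f"ran {gm} specialized for {specialization}"


graph = OutputGraphSketch(toy_backend, example_inputs=[0])
compiled_fns = [
    graph.call_user_compiler("gm", specialization=spec)
    for spec in ("dim0=1", "dim0=8")
]
print(compiled_fns[1]())  # -> "ran gm specialized for dim0=8"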

torch/_functorch/aot_autograd.py

Lines changed: 4 additions & 1 deletion
@@ -489,6 +489,7 @@ def process_inputs(
     fake_mode: FakeTensorMode,
     shape_env: Optional[ShapeEnv],
     ignore_shape_env: bool = False,
+    specialization=None,
 ) -> FakifiedFlatArgs:
     with fake_mode:
 
@@ -547,6 +548,7 @@ def convert(idx, x):
                 symbolic_context=symbolic_context,
                 source=source,
                 trace=trace,
+                specialization=specialization
             )
             return result
 
@@ -1083,6 +1085,7 @@ def aot_module_simplified(
     cudagraphs: Optional[BoxedBool] = None,
     boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
     ignore_shape_env: bool = False,
+    specializations=None,
 ) -> nn.Module:
     """
     This is the simplified or low overhead version of aot_module. For frontends
@@ -1154,7 +1157,7 @@ def aot_module_simplified(
     )
     fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
     fake_flat_args = process_inputs(
-        full_args, aot_config, fake_mode, shape_env, ignore_shape_env
+        full_args, aot_config, fake_mode, shape_env, ignore_shape_env, specializations=specializations
     )
 
     def dispatch_and_compile():
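
A hedged sketch of the AOTAutograd threading in this file (function names below are stand-ins, not the real API): the optional specialization travels from aot_module_simplified through process_inputs to the per-input fakification call. Note the diff spells the new keyword `specializations` at the call site and `specialization` in the process_inputs signature; the sketch uses a single spelling per layer to keep it runnable.

from typing import Any, Optional


def fakify_input(x: Any, *, specialization: Optional[Any] = None) -> Any:
    # Stand-in for fake_mode.from_tensor(x, ..., specialization=...).
    return ("fake", x, specialization)


def process_inputs_sketch(
    flat_args: list[Any], specialization: Optional[Any] = None
) -> list[Any]:
    # Mirrors process_inputs: every flat argument is fakified with the same
    # specialization riding along.
    return [fakify_input(a, specialization=specialization) for a in flat_args]


def aot_module_simplified_sketch(
    flat_args: list[Any], specializations: Optional[Any] = None
) -> list[Any]:
    # Mirrors aot_module_simplified: accept the specialization(s) and hand
    # them to input processing.
    return process_inputs_sketch(flat_args, specialization=specializations)


print(aot_module_simplified_sketch([1, 2], specializations={"idxs": [0], "hints": [4]}))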

torch/_subclasses/fake_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -354,6 +354,7 @@ def from_real_tensor(
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         trace: bool = True,
+        specialization=None,
     ) -> FakeTensor:
         # see note [Tensor Fakification and Symbol Caching]
         if not symbolic_context and not source and shape_env:
@@ -408,6 +409,7 @@ def mk_fake_tensor(
                 source=source,
                 symbolic_context=symbolic_context,
                 trace=trace,
+                specialization=specialization
             )
             if out is NotImplemented:
                 raise UnsupportedFakeTensorException("meta converter nyi")
@@ -2864,6 +2866,7 @@ def from_tensor(
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         trace: bool = True,
+        specialization=None,
     ) -> FakeTensor:
         shape_env: Optional[ShapeEnv] = self.shape_env
         if static_shapes is None:
@@ -2880,6 +2883,7 @@ def from_tensor(
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
+            specialization=specialization,
         )
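
The fake_tensor.py hunks are pure pass-through. A compact sketch of the two layers involved (class names are assumed stand-ins; the method names from_tensor and from_real_tensor come from the diff): both gain a specialization=None parameter and forward it to the meta converter without inspecting it.

from typing import Any, Optional


class MetaConverterSketch:
    # Stand-in for the meta converter invoked in meta_utils.py (see below).
    def __call__(self, t: Any, *, specialization: Optional[Any] = None) -> Any:
        return ("meta", t, specialization)


class FakeTensorConverterSketch:
    # Stand-in for the converter's from_real_tensor.
    def __init__(self) -> None:
        self.meta_converter = MetaConverterSketch()

    def from_real_tensor(self, t: Any, *, specialization: Optional[Any] = None) -> Any:
        return self.meta_converter(t, specialization=specialization)


class FakeTensorModeSketch:
    # Stand-in for FakeTensorMode.from_tensor.
    def __init__(self) -> None:
        self.converter = FakeTensorConverterSketch()

    def from_tensor(self, t: Any, *, specialization: Optional[Any] = None) -> Any:
        return self.converter.from_real_tensor(t, specialization=specialization)


print(FakeTensorModeSketch().from_tensor("x", specialization="spec"))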

torch/_subclasses/meta_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,15 @@ def describe_storage(
276276
return r
277277

278278
def describe_tensor(
279-
self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
279+
self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False, specialization=None
280280
) -> MetaTensorDesc:
281+
if specialization:
282+
t = t.to('meta')
283+
shape = list(t.shape)
284+
for i, hint in zip(specialization.idxs, specialization.hints):
285+
shape[i] = hint
286+
t = torch.ones(shape, dtype=t.dtype, device='meta')
287+
281288
is_leaf = safe_is_leaf(t)
282289
is_view = t._is_view()
283290
is_sparse = t.is_sparse
@@ -1844,6 +1851,7 @@ def __call__(
18441851
# when source is not None. Because we refakify after Dynamo is done,
18451852
# we don't want to dump info again from AOTAutograd, it is redundant.
18461853
trace: bool = True,
1854+
specialization = None,
18471855
) -> _TensorT:
18481856
callback_: _MetaTensorCallback[_TensorT]
18491857
if callback is None:
@@ -1886,7 +1894,7 @@ def __call__(
18861894

18871895
# Describe the tensor. NB: do NOT disable ambient modes, we may need
18881896
# to query them when figuring out what to put in here
1889-
t_desc = self.describer.describe_tensor(t, trace=trace)
1897+
t_desc = self.describer.describe_tensor(t, trace=trace, specialization=specialization)
18901898

18911899
if trace:
18921900
assert source is not None
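
The describe_tensor hunk is where the specialization is finally consumed. A hedged sketch of that logic in isolation (the Specialization container below is illustrative; only the idxs/hints attribute names come from the diff): when a specialization is present, the tensor handed to the describer is replaced by a meta tensor whose specialized dimensions are overwritten with the concrete hints.

from dataclasses import dataclass

import torch


@dataclass
class Specialization:
    # Illustrative container; the diff only shows that the object exposes
    # .idxs (which dimensions to pin) and .hints (the sizes to pin them to).
    idxs: list[int]
    hints: list[int]


def specialize_for_description(t: torch.Tensor, specialization=None) -> torch.Tensor:
    if not specialization:
        return t
    t = t.to("meta")  # keep metadata only, drop real storage
    shape = list(t.shape)
    for i, hint in zip(specialization.idxs, specialization.hints):
        shape[i] = hint  # replace the dynamic dim with its concrete hint
    return torch.ones(shape, dtype=t.dtype, device="meta")


x = torch.empty(7, 3)
print(specialize_for_description(x, Specialization(idxs=[0], hints=[128])).shape)
# -> torch.Size([128, 3])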
