8000 Update on "thread through specialization to compile_fx" · pytorch/pytorch@67d9211 · GitHub
[go: up one dir, main page]

Skip to content

Commit 67d9211

Browse files
committed
Update on "thread through specialization to compile_fx"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
1 parent f67e6e8 commit 67d9211

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

torch/_subclasses/meta_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ def meta_tensor(
882882
callback_: _MetaTensorCallback[_TensorT],
883883
source: Optional[Source],
884884
symbolic_context: Optional[SymbolicContext],
885+
specialization=specialization,
885886
) -> _TensorT:
886887
callback: _MetaTensorCallbackOptDevice = functools.partial(
887888
callback_, device=t.device
@@ -965,6 +966,7 @@ def sym_sizes_strides_storage_offset(
965966
[d in t.dynamo_dynamic_indices for d in range(t.ndim)],
966967
src,
967968
symbolic_context=symbolic_context,
969+
specialization=specialization,
968970
)
969971
else:
970972
return (t.size, t.stride, t.storage_offset)
@@ -1924,6 +1926,7 @@ def __call__(
19241926
callback_,
19251927
source,
19261928
symbolic_context,
1929+
specialization=specialization,
19271930
)
19281931

19291932
if type(t) is torch.nn.Parameter:

torch/fx/experimental/symbolic_shapes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,6 +4142,7 @@ def _create_symbolic_sizes_strides_storage_offset(
41424142
source: Source,
41434143
*,
41444144
symbolic_context: Optional[SymbolicContext] = None,
4145+
specialization = None,
41454146
) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]:
41464147
dim = len(ex_size)
41474148

@@ -4220,6 +4221,11 @@ def _create_symbolic_sizes_strides_storage_offset(
42204221
)
42214222
for i, (sym, hint) in enumerate(zip(size, ex_size))
42224223
]
4224+
4225+
for i, size in enumerate(sym_sizes):
4226+
if i in specialization.idxs:
4227+
expect_true(specialization.lambdas[i](size))
4228+
42234229
sym_stride = []
42244230
for i, stride_expr in enumerate(stride):
42254231
# NB: Don't duck size the stride; instead use the expression

0 commit comments

Comments
 (0)
0