Add automatic_dynamic_shapes_mark_as == "oblivious" (#141444) · pytorch/pytorch@8630096 · GitHub

Commit 8630096

ezyang authored and pytorchmergebot committed
Add automatic_dynamic_shapes_mark_as == "oblivious" (#141444)
Fixes #137100

Should also add a mark_oblivious API for manual control.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #141444
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #141415
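For orientation, a minimal sketch of how the new mode is switched on, mirroring the test added in this commit. The wrapper function run and the use of torch.compile's default backend (instead of the test's CompileCounter) are illustrative assumptions, not part of the change.

import torch

# Sketch: mark dims picked up by automatic dynamic shapes as "oblivious"
# instead of plain dynamic/unbacked (config name taken from this commit).
@torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
def run():
    def f(x):
        # Branching on a size would normally guard and can specialize on 0/1;
        # with an oblivious size, the recorded hint plus a >= 2 counterfactual
        # check can resolve the branch instead.
        if x.size(0) < 10:
            return x * 1
        return x + 10

    opt_f = torch.compile(f, fullgraph=True)
    for i in [3, 2, 1, 0]:  # sizes 1 and 0 exercise the 0/1 specialization issue
        opt_f(torch.zeros(i))


run()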
1 parent e53696b commit 8630096

3 files changed: +119 -9 lines changed

test/dynamo/test_recompiles.py

Lines changed: 33 additions & 0 deletions
@@ -393,6 +393,39 @@ def f(x):
 
         self.assertEqual(counter.frame_count, 2)  # not three or four!
 
+    @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
+    def test_automatic_dynamic_shapes_mark_as_oblivious(self):
+        counter = torch._dynamo.testing.CompileCounter()
+
+        def f(x):
+            if x.size(0) < 10:
+                return x * 1
+            else:
+                return x + 10
+
+        opt_f = torch.compile(backend=counter, fullgraph=True)(f)
+
+        for i in [3, 2, 1, 0]:
+            self.assertEqual(f(torch.zeros(i)), opt_f(torch.zeros(i)))
+
+        self.assertEqual(counter.frame_count, 2)  # not three or four!
+
+    @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
+    def test_automatic_dynamic_shapes_mark_as_oblivious_fail_counterfactual(self):
+        counter = torch._dynamo.testing.CompileCounter()
+
+        def f(x):
+            if x.size(0) < 2:
+                return x * 1
+            else:
+                return x + 10
+
+        opt_f = torch.compile(backend=counter, fullgraph=True)(f)
+
+        opt_f(torch.randn(1))
+        with self.assertRaises(torch._dynamo.exc.UserError):
+            opt_f(torch.randn(0))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/builder.py

Lines changed: 2 additions & 0 deletions
@@ -2524,6 +2524,8 @@ def get_automatic_dynamic_shapes_mark_as():
         return DimDynamic.DYNAMIC
     elif config.automatic_dynamic_shapes_mark_as == "unbacked":
         return DimDynamic.SIZE_LIKE_UNBACKED
+    elif config.automatic_dynamic_shapes_mark_as == "oblivious":
+        return DimDynamic.OBLIVIOUS_SIZE
     else:
        raise ValueError(
            f"invalid automatic_dynamic_shapes_mark_as = {config.automatic_dynamic_shapes_mark_as}"

torch/fx/experimental/symbolic_shapes.py

Lines changed: 84 additions & 9 deletions
@@ -1470,6 +1470,8 @@ class DimDynamic(Enum):
     SIZE_LIKE_UNBACKED = 3
     # Infer the strides from stride. If size is static, strides will be static as well.
     INFER_STRIDE = 4
+    # Like SIZE_LIKE_UNBACKED, but there's a hint
+    OBLIVIOUS_SIZE = 5
 
 
 # NB: These constraints affect both clients and backends: given some
@@ -3118,6 +3120,10 @@ def _init(
         # Like var_to_val, but only set when propagate_real_tensors is on.
         # Used as last resort to avoid GuardOnDataDependent error
         self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
+        # Like above, but used exclusively for OBLIVIOUS_SIZE. These
+        # potentially could be put together but I am not sure, writing out
+        # the logic individually before abstracting.
+        self.oblivious_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
         # Maps symbolic ints to their min/max range. These ranges
         # are conservative: the int MUST fall in the range, but the
         # range may contain ints which may not actually appear in
@@ -4080,12 +4086,20 @@ def create_symboolnode(self, sym: sympy.Expr) -> SymBool:
         return SymBool(SymNode(sym, self, bool, None))
 
     def _log_create_unbacked_symbol(
-        self, prefix: str, symbol: sympy.Symbol, vr: ValueRanges
+        self,
+        prefix: str,
+        symbol: sympy.Symbol,
+        vr: ValueRanges,
+        source: Optional[Source] = None,
     ) -> None:
         is_debug = config.extended_debug_create_symbol is not None and str(
             symbol
         ) in config.extended_debug_create_symbol.split(",")
-        sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
+        sloc: Union[str, SLoc]
+        if source is None:
+            sloc, maybe_extra_debug = self._get_stack_summary(is_debug)
+        else:
+            sloc, maybe_extra_debug = source.name(), ""
         log.info(
             "%s %s [%s, %s] %s%s",
             prefix,
@@ -4131,7 +4145,7 @@ def create_unbacked_symfloat(self) -> SymFloat:
         return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node))
 
     @record_shapeenv_event()
-    def create_unbacked_symint(self) -> SymInt:
+    def create_unbacked_symint(self, source: Optional[Source] = None) -> SymInt:
         """Create a symbolic integer without a hint value"""
         symbol: sympy.Symbol = make_symbol(
             SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True
@@ -4148,7 +4162,7 @@ def create_unbacked_symint(self) -> SymInt:
         # Create a new FX placeholder and Z3 variable for 'symbol'.
         fx_node = self._create_fx_placeholder_and_z3var(symbol, int)
 
-        self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr)
+        self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr, source)
 
         return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node))
 
@@ -4261,14 +4275,15 @@ def create_symbol(
                 source_name
             ]
 
-        if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED:
-            out = self.create_unbacked_symint().node.expr
+        if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE):
+            out = self.create_unbacked_symint(source).node.expr
             self._constrain_range_for_size(out)
-            # TODO: maybe put the hint somewhere
             if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
                 symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][
                     source_name
                 ] = out
+            if dynamic_dim is DimDynamic.OBLIVIOUS_SIZE:
+                self.oblivious_var_to_val[out] = val
             return out
 
         if do_not_specialize_zero_one:
@@ -5635,6 +5650,34 @@ def size_hint(
         if allow_none:
             return None
 
+        if self.oblivious_var_to_val:
+            # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+            correct_hint = result_expr.xreplace(self.oblivious_var_to_val)
+            counterfactual_hint = result_expr.xreplace(
+                {k: max(v, 2) for k, v in self.oblivious_var_to_val.items()}
+            )
+            if (
+                not correct_hint.free_symbols
+                and not counterfactual_hint.free_symbols
+            ):
+                if correct_hint == counterfactual_hint:
+                    log.info("oblivious_size hit %s -> %s", expr, correct_hint)
+                    return correct_hint
+                else:
+                    log.info(
+                        "oblivious_size counterfactual failed %s -> %s != %s",
+                        expr,
+                        correct_hint,
+                        counterfactual_hint,
+                    )
+            else:
+                log.info(
+                    "oblivious_size miss %s -> %s (counterfactual: %s)",
+                    expr,
+                    correct_hint,
+                    counterfactual_hint,
+                )
+
         if self.unbacked_var_to_val:
             unsound_expr = result_expr.xreplace(self.unbacked_var_to_val)
             if not unsound_expr.free_symbols:
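For intuition, here is a standalone sketch of the counterfactual check size_hint now performs: substitute the recorded hint, substitute again with every oblivious hint clamped up to at least 2, and only trust the result when both substitutions are fully concrete and agree. This is plain sympy with a made-up symbol u0 and hint value, not the ShapeEnv API.

import sympy

u0 = sympy.Symbol("u0", integer=True, nonnegative=True)
oblivious_var_to_val = {u0: sympy.Integer(1)}  # hint observed at trace time

expr = u0 < 10  # the size-dependent question we want a concrete answer for

correct_hint = expr.xreplace(oblivious_var_to_val)
counterfactual_hint = expr.xreplace(
    {k: sympy.Integer(max(int(v), 2)) for k, v in oblivious_var_to_val.items()}
)

# Both substitutions are concrete and agree (1 < 10 and 2 < 10 are both True),
# so the hint is safe to use without baking in the 0/1-ness of the size.
if (
    not correct_hint.free_symbols
    and not counterfactual_hint.free_symbols
    and correct_hint == counterfactual_hint
):
    print("oblivious_size hit:", correct_hint)
else:
    print("counterfactual failed or missed; fall back to other strategies")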
@@ -6388,9 +6431,39 @@ def compute_concrete_val() -> sympy.Basic:
                             expr, size_oblivious=True
                         )
 
+                    ok = False
+
                     # Last ditch
                     if (
-                        self.unbacked_var_to_val
+                        self.oblivious_var_to_val
+                        and not (
+                            correct_hint := orig_expr.xreplace(
+                                self.oblivious_var_to_val
+                            )
+                        ).free_symbols
+                        and not (
+                            counterfactual_hint := orig_expr.xreplace(
+                                {
+                                    k: max(2, v)
+                                    for k, v in self.oblivious_var_to_val.items()
+                                }
+                            )
+                        ).free_symbols
+                        and correct_hint == counterfactual_hint
+                    ):
+                        # TODO: better logging
+                        log.info(
+                            "oblivious_size %s -> %s (passed counterfactual)",
+                            orig_expr,
+                            correct_hint,
+                        )
+                        concrete_val = correct_hint
+                        # NB: do NOT transmute into runtime assert
+                        ok = True
+
+                    if (
+                        not ok
+                        and self.unbacked_var_to_val
                         and not (
                             unsound_result := orig_expr.xreplace(
                                 self.unbacked_var_to_val
@@ -6414,7 +6487,9 @@ def compute_concrete_val() -> sympy.Basic:
                         )
                         transmute_into_runtime_assert = True
                         concrete_val = unsound_result
-                    else:
+                        ok = True
+
+                    if not ok:
                         raise self._make_data_dependent_error(
                             expr.xreplace(self.var_to_val),
