@@ -1470,6 +1470,8 @@ class DimDynamic(Enum):
1470
1470
SIZE_LIKE_UNBACKED = 3
1471
1471
# Infer the strides from stride. If size is static, strides will be static as well.
1472
1472
INFER_STRIDE = 4
1473
+ # Like SIZE_LIKE_UNBACKED, but there's a hint
1474
+ OBLIVIOUS_SIZE = 5
1473
1475
1474
1476
1475
1477
# NB: These constraints affect both clients and backends: given some
@@ -3118,6 +3120,10 @@ def _init(
3118
3120
# Like var_to_val, but only set when propagate_real_tensors is on.
3119
3121
# Used as last resort to avoid GuardOnDataDependent error
3120
3122
self .unbacked_var_to_val : Dict [sympy .Symbol , sympy .Integer ] = {}
3123
+ # Like above, but used exclusively for OBLIVIOUS_SIZE. These
3124
+ # potentially could be put together but I am not sure, writing out
3125
+ # the logic individually before abstracting.
3126
+ self .oblivious_var_to_val : Dict [sympy .Symbol , sympy .Integer ] = {}
3121
3127
# Maps symbolic ints to their min/max range. These ranges
3122
3128
# are conservative: the int MUST fall in the range, but the
3123
3129
# range may contain ints which may not actually appear in
@@ -4080,12 +4086,20 @@ def create_symboolnode(self, sym: sympy.Expr) -> SymBool:
4080
4086
return SymBool (SymNode (sym , self , bool , None ))
4081
4087
4082
4088
def _log_create_unbacked_symbol (
4083
- self , prefix : str , symbol : sympy .Symbol , vr : ValueRanges
4089
+ self ,
4090
+ prefix : str ,
4091
+ symbol : sympy .Symbol ,
4092
+ vr : ValueRanges ,
4093
+ source : Optional [Source ] = None ,
4084
4094
) -> None :
4085
4095
is_debug = config .extended_debug_create_symbol is not None and str (
4086
4096
symbol
4087
4097
) in config .extended_debug_create_symbol .split ("," )
4088
- sloc , maybe_extra_debug = self ._get_stack_summary (is_debug )
4098
+ sloc : Union [str , SLoc ]
4099
+ if source is None :
4100
+ sloc , maybe_extra_debug = self ._get_stack_summary (is_debug )
4101
+ else :
4102
+ sloc , maybe_extra_debug = source .name (), ""
4089
4103
log .info (
4090
4104
"%s %s [%s, %s] %s%s" ,
4091
4105
prefix ,
@@ -4131,7 +4145,7 @@ def create_unbacked_symfloat(self) -> SymFloat:
4131
4145
return SymFloat (SymNode (symbol , self , float , None , fx_node = fx_node ))
4132
4146
4133
4147
@record_shapeenv_event ()
4134
- def create_unbacked_symint (self ) -> SymInt :
4148
+ def create_unbacked_symint (self , source : Optional [ Source ] = None ) -> SymInt :
4135
4149
"""Create a symbolic integer without a hint value"""
4136
4150
symbol : sympy .Symbol = make_symbol (
4137
4151
SymT .UNBACKED_INT , next (self .unbacked_symint_counter ), integer = True
@@ -4148,7 +4162,7 @@ def create_unbacked_symint(self) -> SymInt:
4148
4162
# Create a new FX placeholder and Z3 variable for 'symbol'.
4149
4163
fx_node = self ._create_fx_placeholder_and_z3var (symbol , int )
4150
4164
4151
- self ._log_create_unbacked_symbol ("create_unbacked_symint" , symbol , vr )
4165
+ self ._log_create_unbacked_symbol ("create_unbacked_symint" , symbol , vr , source )
4152
4166
4153
4167
return SymInt (SymNode (symbol , self , int , None , fx_node = fx_node ))
4154
4168
@@ -4261,14 +4275,15 @@ def create_symbol(
4261
4275
source_name
4262
4276
]
4263
4277
4264
- if dynamic_dim is DimDynamic .SIZE_LIKE_UNBACKED :
4265
- out = self .create_unbacked_symint ().node .expr
4278
+ if dynamic_dim in ( DimDynamic .SIZE_LIKE_UNBACKED , DimDynamic . OBLIVIOUS_SIZE ) :
4279
+ out = self .create_unbacked_symint (source ).node .expr
4266
4280
self ._constrain_range_for_size (out )
4267
- # TODO: maybe put the hint somewhere
4268
4281
if isinstance (symbolic_context , StatefulSymbolicContext ) and source_name :
4269
4282
symbolic_context .shape_env_to_source_to_symbol_cache [id (self )][
4270
4283
source_name
4271
4284
] = out
4285
+ if dynamic_dim is DimDynamic .OBLIVIOUS_SIZE :
4286
+ self .oblivious_var_to_val [out ] = val
4272
4287
return out
4273
4288
4274
4289
if do_not_specialize_zero_one :
@@ -5635,6 +5650,34 @@ def size_hint(
5635
5650
if allow_none :
5636
5651
return None
5637
5652
5653
+ if self .oblivious_var_to_val :
5654
+ # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
5655
+ correct_hint = result_expr .xreplace (self .oblivious_var_to_val )
5656
+ counterfactual_hint = result_expr .xreplace (
5657
+ {k : max (v , 2 ) for k , v in self .oblivious_var_to_val .items ()}
5658
+ )
5659
+ if (
5660
+ not correct_hint .free_symbols
5661
+ and not counterfactual_hint .free_symbols
5662
+ ):
5663
+ if correct_hint == counterfactual_hint :
5664
+ log .info ("oblivious_size hit %s -> %s" , expr , correct_hint )
5665
+ return correct_hint
5666
+ else :
5667
+ log .info (
5668
+ "oblivious_size counterfactual failed %s -> %s != %s" ,
5669
+ expr ,
5670
+ correct_hint ,
5671
+ counterfactual_hint ,
5672
+ )
5673
+ else :
5674
+ log .info (
5675
+ "oblivious_size miss %s -> %s (counterfactual: %s)" ,
5676
+ expr ,
5677
+ correct_hint ,
5678
+ counterfactual_hint ,
5679
+ )
5680
+
5638
5681
if self .unbacked_var_to_val :
5639
5682
unsound_expr = result_expr .xreplace (self .unbacked_var_to_val )
5640
5683
if not unsound_expr .free_symbols :
@@ -6388,9 +6431,39 @@ def compute_concrete_val() -> sympy.Basic:
6388
6431
expr , size_oblivious = True
6389
6432
)
6390
6433
6434
+ ok = False
6435
+
6391
6436
# Last ditch
6392
6437
if (
6393
- self .unbacked_var_to_val
6438
+ self .oblivious_var_to_val
6439
+ and not (
6440
+ correct_hint := orig_expr .xreplace (
6441
+ self .oblivious_var_to_val
6442
+ )
6443
+ ).free_symbols
6444
+ and not (
6445
+ counterfactual_hint := orig_expr .xreplace (
6446
+ {
6447
+ k : max (2 , v )
6448
+ for k , v in self .oblivious_var_to_val .items ()
6449
+ }
6450
+ )
6451
+ ).free_symbols
6452
+ and correct_hint == counterfactual_hint
6453
+ ):
6454
+ # TODO: better logging
6455
+ log .info (
6456
+ "oblivious_size %s -> %s (passed counterfactual)" ,
6457
+ orig_expr ,
6458
+ correct_hint ,
6459
+ )
6460
+ concrete_val = correct_hint
6461
+ # NB: do NOT transmute into runtime assert
6462
+ ok = True
6463
+
6464
+ if (
6465
+ not ok
6466
+ and self .unbacked_var_to_val
6394
6467
and not (
6395
6468
unsound_result := orig_expr .xreplace (
6396
6469
self .unbacked_var_to_val
@@ -6414,7 +6487,9 @@ def compute_concrete_val() -> sympy.Basic:
6414
6487
)
6415
6488
transmute_into_runtime_assert = True
6416
6489
concrete_val = unsound_result
6417
- else :
6490
+ ok = True
6491
+
6492
+ if not ok :
6418
6493
raise self ._make_data_dependent_error (
6419
6494
expr .xreplace (self .var_to_val ),
6420
6495
expr ,
0 commit comments