@@ -552,7 +552,7 @@ def mark_unbacked(t, index, strict=False):
552
552
553
553
554
554
@forbid_in_graph
555
- def mark_dynamic (t , index , * , min = None , max = None ):
555
+ def mark_dynamic (t , index , * , min = None , max = None , backend_specializations = None ):
556
556
"""
557
557
Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
558
558
@@ -587,14 +587,16 @@ def mark_dynamic(t, index, *, min=None, max=None):
587
587
if not hasattr (t , "_dynamo_dynamic_indices" ):
588
588
t ._dynamo_dynamic_indices = set ()
589
589
t ._dynamo_dynamic_range = set ()
590
+ t ._backend_specializations = {}
590
591
# TODO(voz): Should we bounds check?
591
592
t ._dynamo_dynamic_indices .add (index )
592
593
t ._dynamo_dynamic_range .add (_DimRange (index , min , max ))
594
+ t ._backend_specializations [index ] = backend_specializations
593
595
return
594
596
595
597
assert isinstance (index , (list , tuple ))
596
598
for i in index :
597
- mark_dynamic (t , i , min = min , max = max )
599
+ mark_dynamic (t , i , min = min , max = max , backend_specializations = backend_specializations )
598
600
599
601
600
602
@forbid_in_graph
0 commit comments