File tree Expand file tree Collapse file tree 1 file changed +22
-3
lines changed Expand file tree Collapse file tree 1 file changed +22
-3
lines changed Original file line number Diff line number Diff line change @@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
542542def compute_elementwise_output_logical_to_physical_perm (
543543 * tensors , _skip_checks = False
544544) -> list [int ]:
545- from torch .fx .experimental .symbolic_shapes import guard_size_oblivious
545+ from torch .fx .experimental .symbolic_shapes import (
546+ guard_or_false ,
547+ guard_size_oblivious ,
548+ )
546549
547550 if not _skip_checks and len (tensors ) == 0 :
548551 msg = "Can't compute elementwise output strides for zero tensors!"
@@ -601,10 +604,26 @@ def should_swap(idx_a, idx_b):
601604 ):
602605 continue
603606
604- if guard_size_oblivious (stride_a < stride_b ):
607+ # imitates if stride_a < stride_b : return -1 but
608+ # when stride_a = 1, we want stride_a < stride_b to be TRUE
609+ # when stride_b = 1, we want stride_a < stride_b to be FALSE
610+ if guard_or_false (stride_a == 1 ):
605611 return - 1
606612
607- if guard_size_oblivious (stride_a > stride_b ):
613+ if not guard_or_false (stride_b == 1 ) and guard_size_oblivious (
614+ stride_a < stride_b
615+ ):
616+ return - 1
617+
618+ # imitates if stride_a > stride_b : return 1 but
619+ # when stride_b = 1, we want stride_a > stride_b to be TRUE
620+ # when stride_a = 1, we want stride_a > stride_b to be FALSE
621+ if guard_or_false (stride_b == 1 ):
622+ return 1
623+
624+ if not guard_or_false (stride_a == 1 ) and guard_size_oblivious (
625+ stride_a > stride_b
626+ ):
608627 return 1
609628
610629 # stride_a == stride_b
You can’t perform that action at this time.
0 commit comments