1 file changed: +16 -2 lines changed
@@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
542542def compute_elementwise_output_logical_to_physical_perm (
543543 * tensors , _skip_checks = False
544544) -> list [int ]:
545- from torch .fx .experimental .symbolic_shapes import guard_size_oblivious
545+ from torch .fx .experimental .symbolic_shapes import (
546+ guard_or_false ,
547+ guard_size_oblivious ,
548+ )
546549
547550 if not _skip_checks and len (tensors ) == 0 :
548551 msg = "Can't compute elementwise output strides for zero tensors!"
@@ -595,12 +598,23 @@ def should_swap(idx_a, idx_b):
595598 for tensor in tensors :
596599 stride_a = tensor .stride ()[idx_a ]
597600 stride_b = tensor .stride ()[idx_b ]
598-
599601 if guard_size_oblivious (stride_a == 0 ) or guard_size_oblivious (
600602 stride_b == 0
601603 ):
602604 continue
603605
606+ if guard_or_false (stride_a == stride_b ):
607+ if guard_size_oblivious (shape [idx_a ] > shape [idx_b ]):
608+ return 1
609+
610+ # When stride_a is known to be 1, treat stride_a < stride_b as true (return -1).
611+ # When stride_b is known to be 1, treat stride_a < stride_b as false (return 1).
612+ if guard_or_false (stride_a == 1 ):
613+ return - 1
614+
615+ if guard_or_false (stride_b == 1 ):
616+ return 1
617+
604618 if guard_size_oblivious (stride_a < stride_b ):
605619 return - 1
606620
You can’t perform that action at this time.
0 commit comments