make should_swap more dde friendly · pytorch/pytorch@e44d0e3 · GitHub
[go: up one dir, main page]

Skip to content

Commit e44d0e3

Browse files
committed
make should_swap more dde friendly
ghstack-source-id: 31831f8
Pull-Request: #162099
1 parent 8bde4fb commit e44d0e3

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

torch/_prims_common/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
542542
def compute_elementwise_output_logical_to_physical_perm(
543543
*tensors, _skip_checks=False
544544
) -> list[int]:
545-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
545+
from torch.fx.experimental.symbolic_shapes import (
546+
guard_or_false,
547+
guard_size_oblivious,
548+
)
546549

547550
if not _skip_checks and len(tensors) == 0:
548551
msg = "Can't compute elementwise output strides for zero tensors!"
@@ -595,12 +598,23 @@ def should_swap(idx_a, idx_b):
595598
for tensor in tensors:
596599
stride_a = tensor.stride()[idx_a]
597600
stride_b = tensor.stride()[idx_b]
598-
599601
if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
600602
stride_b == 0
601603
):
602604
continue
603605

606+
if guard_or_false(stride_a == stride_b):
607+
if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
608+
return 1
609+
610+
# when stride_a = 1, we want stride_a < stride_b to be TRUE
611+
# when stride_b = 1, we want stride_a < stride_b to be FALSE
612+
if guard_or_false(stride_a == 1):
613+
return -1
614+
615+
if guard_or_false(stride_b == 1):
616+
return 1
617+
604618
if guard_size_oblivious(stride_a < stride_b):
605619
return -1
606620

0 commit comments

Comments
 (0)
0