8000 make should_swap more dde friendly · pytorch/pytorch@e91500a · GitHub
[go: up one dir, main page]

Skip to content

Commit e91500a

Browse files
committed
make should_swap more dde friendly
ghstack-source-id: a44923d Pull-Request: #162099
1 parent 2cb20c3 commit e91500a

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

torch/_prims_common/__init__.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,10 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool:
542542
def compute_elementwise_output_logical_to_physical_perm(
543543
*tensors, _skip_checks=False
544544
) -> list[int]:
545-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
545+
from torch.fx.experimental.symbolic_shapes import (
546+
guard_or_false,
547+
guard_size_oblivious,
548+
)
546549

547550
if not _skip_checks and len(tensors) == 0:
548551
msg = "Can't compute elementwise output strides for zero tensors!"
@@ -601,10 +604,26 @@ def should_swap(idx_a, idx_b):
601604
):
602605
continue
603606

604-
if guard_size_oblivious(stride_a < stride_b):
607+
# imitates if stride_a < stride_b : return -1 but
608+
# when stride_a = 1, we want stride_a < stride_b to be TRUE
609+
# when stride_b = 1, we want stride_a < stride_b to be FALSE
610+
if guard_or_false(stride_a == 1):
605611
return -1
606612

607-
if guard_size_oblivious(stride_a > stride_b):
613+
if not guard_or_false(stride_b == 1) and guard_size_oblivious(
614+
stride_a < stride_b
615+
):
616+
return -1
617+
618+
# imitates if stride_a > stride_b : return 1 but
619+
# when stride_b = 1, we want stride_a > stride_b to be TRUE
620+
# when stride_a = 1, we want stride_a > stride_b to be FALSE
621+
if guard_or_false(stride_b == 1):
622+
return 1
623+
624+
if not guard_or_false(stride_a == 1) and guard_size_oblivious(
625+
stride_a > stride_b
626+
):
608627
return 1
609628

610629
# stride_a == stride_b

0 commit comments

Comments
 (0)
0