Inductor respects exact strides on custom ops by default (#150511) · pytorch/pytorch@3d777ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 3d777ba

Browse files
zou3519 authored and pytorchmergebot committed
Inductor respects exact strides on custom ops by default (#150511)
If a tag is not specified on a custom operator, then inductor will assume that it needs exact strides.

Test Plan:
- tests + CI

Pull Request resolved: #150511
Approved by: https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: #148104
1 parent 2b37a72 commit 3d777ba

File tree

4 files changed

+4
-5
lines changed

4 files changed

+4
-5
lines changed

test/inductor/test_triton_kernels.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3451,7 +3451,6 @@ def impl2(x):
34513451

34523452
lib.define(
34533453
"add_op(Tensor x, Tensor y) -> Tensor",
3454-
tags=[torch._C.Tag.needs_exact_strides],
34553454
)
34563455

34573456
def impl(x, y):
@@ -3465,7 +3464,6 @@ def meta(x, y):
34653464

34663465
lib.define(
34673466
"add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()",
3468-
tags=[torch._C.Tag.needs_exact_strides],
34693467
)
34703468

34713469
def impl_out(x, y, out):

test/test_custom_ops.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3838,14 +3838,15 @@ def vmap(info, in_dims, w, x=2, *, y=3, z):
38383838
self.assertEqual(result, w * 2 * 3 * 42)
38393839

38403840
def test_layout_constraint_tags(self):
3841+
needs_exact_strides = torch._C.Tag.needs_exact_strides
38413842
needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order
38423843
flexible_layout = torch._C.Tag.flexible_layout
38433844
# (tags, the result of the tag inference)
38443845
tests = [
38453846
({needs_fixed_stride_order}, needs_fixed_stride_order),
38463847
({flexible_layout}, flexible_layout),
38473848
# If no tags are provided, then the following is the default
3848-
(set(), needs_fixed_stride_order),
3849+
(set(), needs_exact_strides),
38493850
# If multiple tags are provided, then we use the most constrained tag.
38503851
({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order),
38513852
]

torch/_functorch/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def remote_autograd_cache_default() -> Optional[bool]:
186186
# ProxyTensor tracing.
187187
custom_op_default_layout_constraint: Literal[
188188
"needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
189-
] = "needs_fixed_stride_order"
189+
] = "needs_exact_strides"
190190

191191

192192
# Run aot eager decomp partition with CrossRefFakeMode

torch/_library/custom_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -615,7 +615,7 @@ def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None:
615615

616616
lib.define(
617617
schema_str,
618-
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order, *tags],
618+
tags=[_C.Tag.pt2_compliant_tag, *tags],
619619
)
620620
self._opoverload = utils.lookup_op(self._qualname)
621621

0 commit comments

Comments
 (0)
0