8000 [Dynamo][Trace PyDispatcher] Remove disable from HigherOrderOperator.… · pytorch/pytorch@bd8d7b1 · GitHub
[go: up one dir, main page]

Skip to content

Commit bd8d7b1

Browse files
yanboliang authored and pytorchmergebot committed
[Dynamo][Trace PyDispatcher] Remove disable from HigherOrderOperator.__call__ (#146270)
Pull Request resolved: #146270 Approved by: https://github.com/zou3519
1 parent fd73ae2 commit bd8d7b1

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

test/dynamo/test_higher_order_ops.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2283,7 +2283,8 @@ def body(x):
22832283

22842284
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
22852285
# There is graph break right when we enter body of map
2286-
self.assertEqual(len(backend.graphs), 0)
2286+
# Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
2287+
self.assertEqual(len(backend.graphs), 8)
22872288
self.assertEqual(
22882289
res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
22892290
)
@@ -2319,7 +2320,8 @@ def body(x):
23192320
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
23202321
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
23212322

2322-
self.assertEqual(len(backend.graphs), 0)
2323+
# Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
2324+
self.assertEqual(len(backend.graphs), 9)
23232325
self.assertEqual(res, eager)
23242326

23252327
def test_wrap_subgraph_name_is_valid(self):

torch/_ops.py

-5
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,6 @@ def check_overloaded(arg):
456456

457457
@abc.abstractmethod
458458
def __call__(self, /, *args, **kwargs):
459-
# Dynamo already traces the body of HigherOrderOp beforehand when it
460-
# so no need to trace into it.
461-
from torch._dynamo import disable
462-
463-
@disable
464459
def wrapper():
465460
flat_args = _to_flat_tuple(args, kwargs)
466461
if torch.overrides.has_torch_function(flat_args):

0 commit comments

Comments (0)