8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fd73ae2 commit bd8d7b1Copy full SHA for bd8d7b1
test/dynamo/test_higher_order_ops.py
@@ -2283,7 +2283,8 @@ def body(x):
2283
2284
res = mod_for_compile(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2285
# There is graph break right when we enter body of map
2286
- self.assertEqual(len(backend.graphs), 0)
+ # Since we are tracing through the Python dispatch logic, it ends up 8 graphs.
2287
+ self.assertEqual(len(backend.graphs), 8)
2288
self.assertEqual(
2289
res, mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2290
)
@@ -2319,7 +2320,8 @@ def body(x):
2319
2320
eager = mod_for_eager(torch.Tensor([[6, 4, 5], [3, 4, 5], [6, 6, 6]]))
2321
2322
2323
+ # Since we are tracing through the Python dispatch logic, it ends up 9 graphs.
2324
+ self.assertEqual(len(backend.graphs), 9)
2325
self.assertEqual(res, eager)
2326
2327
def test_wrap_subgraph_name_is_valid(self):
torch/_ops.py
@@ -456,11 +456,6 @@ def check_overloaded(arg):
456
457
@abc.abstractmethod
458
def __call__(self, /, *args, **kwargs):
459
- # Dynamo already traces the body of HigherOrderOp beforehand when it
460
- # so no need to trace into it.
461
- from torch._dynamo import disable
462
-
463
- @disable
464
def wrapper():
465
flat_args = _to_flat_tuple(args, kwargs)
466
if torch.overrides.has_torch_function(flat_args):
0 commit comments