Commit 9035fd1 · pytorch/pytorch
[NJT] Fix inference mode for composite implicit ops without nested-specific kernel
ghstack-source-id: a72b64b
Pull Request resolved: #146633
1 parent e57cdb8 commit 9035fd1
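
For context, the scenario this commit fixes can be reproduced with a short snippet that mirrors the test added below (the device argument is omitted here; shapes and the dropout probability follow the new test and are otherwise illustrative):

import torch

# A jagged-layout NestedTensor, as in the new test_dropout_inference_mode test.
nt = torch.nested.nested_tensor(
    [torch.randn(11, 32, 128), torch.randn(11, 32, 128)],
    layout=torch.jagged,
)

# Without the dispatch-key fallback added in this commit, this call could end in
# NotImplementedError: inference mode disables autograd-key dispatch, which blocks
# decomposition, and dropout has no NestedTensor-specific composite kernel.
with torch.inference_mode():
    out = torch.nn.functional.dropout(nt, p=0.05)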

2 files changed: 27 additions, 4 deletions

test/test_nestedtensor.py

Lines changed: 16 additions & 0 deletions

@@ -7624,6 +7624,22 @@ def f(nt):
         for dynamic in [False, True, None]:
             self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
 
+    def test_dropout_inference_mode(self, device):
+        seq_len = 32
+        embed_dim = 128
+
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(11, seq_len, embed_dim, device=device),
+                torch.randn(11, seq_len, embed_dim, device=device),
+            ],
+            layout=torch.jagged,
+            device=device,
+        )
+
+        with torch.inference_mode():
+            torch.nn.functional.dropout(nt, p=0.05)
+
     @dtypes(torch.float32, torch.double, torch.half)
     def test_unbind_backward(self, device, dtype):
         nt = torch.nested.nested_tensor(

torch/nested/_internal/nested_tensor.py

Lines changed: 11 additions & 4 deletions

@@ -325,10 +325,17 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
 
         # Poor man's redispatch for composite ops. This becomes relevant under inference
         # mode, where disabling autograd key dispatch prevents decomposition.
-        dk = torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor
-        if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
-            with torch.overrides.enable_reentrant_dispatch():
-                return func._op_dk(dk, *args, **kwargs)
+        all_dks = (
+            # We want to handle both the cases where NestedTensor overrides the
+            # composite implicit autograd kernel, and the case where it doesn't.
+            # Prioritize calling into NestedTensor's kernel if it exists.
+            torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
+            torch._C.DispatchKey.CompositeImplicitAutograd,
+        )
+        for dk in all_dks:
+            if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func._op_dk(dk, *args, **kwargs)
 
         raise NotImplementedError(func)
 
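
To see why the second dispatch key matters, one can ask the dispatcher which composite kernels are registered for a given op. The sketch below is not part of the commit; it only queries the same APIs the patch uses, and the choice of aten::dropout as the example op is illustrative (it is a composite implicit op without a NestedTensor-specific kernel):

import torch

# Query the dispatcher in the same fallback order the patch walks through:
# first the NestedTensor-specific composite key, then the generic one.
op = torch.ops.aten.dropout.default
dks = (
    torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
    torch._C.DispatchKey.CompositeImplicitAutograd,
)
for dk in dks:
    registered = torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), dk)
    print(f"{dk}: {'kernel registered' if registered else 'no kernel'}")

Before this change, only the first key was consulted, so ops whose composite kernel is registered solely under CompositeImplicitAutograd fell through to the NotImplementedError branch under inference mode.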