8000 [NJT] Fix inference mode for composite implicit ops without nested-specific kernel by soulitzer · Pull Request #146633 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[NJT] Fix inference mode for composite implicit ops without nested-specific kernel #146633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[NJT] Fix inference mode for composite implicit ops without nested-specific kernel

[ghstack-poisoned]
  • Loading branch information
soulitzer committed Feb 6, 2025
commit 313f65f429db67d74b1d226c1df920b2afb4e4d5
12 changes: 12 additions & 0 deletions test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7624,6 +7624,18 @@ def f(nt):
for dynamic in [False, True, None]:
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

def test_dropout_inference_mode(self, device):
    """Regression test: F.dropout on a jagged-layout NestedTensor must not
    raise under torch.inference_mode().

    dropout has no NestedTensor-specific kernel and relies on its
    CompositeImplicitAutograd decomposition; under inference mode the
    autograd dispatch key is disabled, which previously prevented that
    decomposition from running (see the __torch_dispatch__ redispatch
    change in torch/nested/_internal/nested_tensor.py in this PR).
    The test passes if the call completes without error — the (random)
    output value is deliberately not asserted on.
    """
    seq_len = 32
    embed_dim = 128

    # Jagged-layout nested tensor; two components here happen to share the
    # same leading dim (11), but the jagged layout is what exercises the
    # NestedTensor dispatch path.
    nt = torch.nested.nested_tensor([
        torch.randn(11, seq_len, embed_dim, device=device),
        torch.randn(11, seq_len, embed_dim, device=device)
    ], layout=torch.jagged, device=device)

    with torch.inference_mode():
        torch.nn.functional.dropout(nt, p=0.05)

@dtypes(torch.float32, torch.double, torch.half)
def test_unbind_backward(self, device, dtype):
nt = torch.nested.nested_tensor(
Expand Down
15 changes: 11 additions & 4 deletions torch/nested/_internal/nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,17 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

# Poor man's redispatch for composite ops. This becomes relevant under inference
# mode, where disabling autograd key dispatch prevents decomposition.
dk = torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor
if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
with torch.overrides.enable_reentrant_dispatch():
return func._op_dk(dk, *args, **kwargs)
all_dks = (
# We want to handle both the cases where NestedTensor overrides the
# composite implicit autograd kernel, and the case where it doesn't.
# Prioritize calling into NestedTensor's kernel if it exists.
torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
torch._C.DispatchKey.CompositeImplicitAutograd,
)
for dk in all_dks:
if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
with torch.overrides.enable_reentrant_dispatch():
return func._op_dk(dk, *args, **kwargs)

raise NotImplementedError(func)

Expand Down
Loading
0