Update on "[NJT] Fix inference mode for composite implicit ops without nested-specific kernel" · pytorch/pytorch@9e27eb7 · GitHub

Commit 9e27eb7

Update on "[NJT] Fix inference mode for composite implicit ops without nested-specific kernel"
[ghstack-poisoned]
2 parents 3425f31 + 53f8436 commit 9e27eb7

File tree

1 file changed: +8 -4 lines changed


test/test_nestedtensor.py

Lines changed: 8 additions & 4 deletions
@@ -7628,10 +7628,14 @@ def test_dropout_inference_mode(self, device):
         seq_len = 32
         embed_dim = 128
 
-        nt = torch.nested.nested_tensor([
-            torch.randn(11, seq_len, embed_dim, device=device),
-            torch.randn(11, seq_len, embed_dim, device=device)
-        ], layout=torch.jagged, device=device)
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(11, seq_len, embed_dim, device=device),
+                torch.randn(11, seq_len, embed_dim, device=device),
+            ],
+            layout=torch.jagged,
+            device=device,
+        )
 
         with torch.inference_mode():
             torch.nn.functional.dropout(nt, p=0.05)
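
For reference, a minimal standalone sketch of the scenario this test exercises: constructing a jagged-layout nested tensor (NJT) and running dropout, a composite implicit op with no nested-specific kernel, under inference mode. The hard-coded "cpu" device is an assumption for illustration only; the actual test is parameterized over devices.

import torch

device = "cpu"  # assumption for illustration; the real test receives `device` as a parameter
seq_len = 32
embed_dim = 128

# Build a jagged-layout nested tensor from two constituent tensors.
nt = torch.nested.nested_tensor(
    [
        torch.randn(11, seq_len, embed_dim, device=device),
        torch.randn(11, seq_len, embed_dim, device=device),
    ],
    layout=torch.jagged,
    device=device,
)

with torch.inference_mode():
    # Per the commit title, this path previously could fail under inference mode
    # for composite implicit ops lacking a nested-specific kernel; the test simply
    # checks that the call completes without error.
    out = torch.nn.functional.dropout(nt, p=0.05)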

0 commit comments