enable flex attention ut · pytorch/pytorch@b21033c · GitHub

Commit b21033c

enable flex attention ut
1 parent 90b6f0b commit b21033c

File tree

1 file changed (+1, -1)

test/inductor/test_flex_attention.py

Lines changed: 1 addition & 1 deletion
@@ -3988,7 +3988,7 @@ def causal_mask(b, h, q, kv):
         self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
 
         offset = torch.arange(8, device=device)
-        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
+        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
         self.assertEqual(block_mask.sparsity(), 29.1015625)
         self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
         self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
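For context on the one-line change: create_block_mask comes from torch.nn.attention.flex_attention, and passing device= builds the BlockMask on the device the test actually runs on instead of the default. A minimal standalone sketch of the same call, assuming a CUDA device (in the test itself, device is supplied by the test harness):

import torch
from torch.nn.attention.flex_attention import create_block_mask

device = "cuda"  # assumption: the test receives `device` from its harness

def causal_mask(b, h, q, kv):
    # keep only key positions at or before the query position
    return q >= kv

# B=8, H=1, Q_LEN=KV_LEN=2048; device= mirrors the committed change
block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
print(block_mask.sparsity())  # the test asserts 29.1015625 for this mask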

0 commit comments