8000 enable flex decoding case · pytorch/pytorch@90b6f0b · GitHub
[go: up one dir, main page]

Skip to content

Commit 90b6f0b

Browse files
committed
enable flex decoding case
1 parent b83e482 commit 90b6f0b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

test/inductor/test_flex_decoding.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,7 +1016,13 @@ def mask_mod(b, h, q, kv):
10161016
return kv >= q + offset_tensor
10171017

10181018
block_mask = create_block_mask(mask_mod, B, 1, Q_S, KV_S, device=device)
1019-
self.run_test(Q_S=Q_S, KV_S=KV_S, block_mask=block_mask, score_mod=score_mod)
1019+
self.run_test(
1020+
Q_S=Q_S,
1021+
KV_S=KV_S,
1022+
block_mask=block_mask,
1023+
score_mod=score_mod,
1024+
device=device,
1025+
)
10201026

10211027
@supported_platform
10221028
@common_utils.parametrize("dtype", test_dtypes_fast)

0 commit comments

Comments
 (0)
0