add float32,bf16 for xpu UTs · pytorch/pytorch@7e69c40
Commit 7e69c40

add float32,bf16 for xpu UTs

1 parent ba427ee · commit 7e69c40

2 files changed: +2 -2 lines

test/inductor/test_flex_attention.py (1 addition, 1 deletion)

@@ -117,7 +117,7 @@ def create_block_mask_test(score_mod, query, key):
     test_dtypes_fast = [torch.float16]
 elif HAS_XPU:
     test_device = "xpu"
-    test_dtypes = [torch.float16]
+    test_dtypes = [torch.float32, torch.bfloat16, torch.float16]
     test_dtypes_fast = [torch.float16]
 else:
     test_device = "cpu"

test/inductor/test_flex_decoding.py (1 addition, 1 deletion)

@@ -64,7 +64,7 @@ def create_block_mask_test(score_mod, query, key):
     test_dtypes_fast = [torch.float16]
 elif HAS_XPU:
     test_device = "xpu"
-    test_dtypes = [torch.float16]
+    test_dtypes = [torch.float32, torch.bfloat16, torch.float16]
     test_dtypes_fast = [torch.float16]

 test_page_sizes = [64, 128, 256]
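
For context, the test_dtypes list edited above is what the flex-attention and flex-decoding unit tests iterate over when choosing which dtypes to run on the selected device. Below is a minimal standalone sketch (not the actual test files) of how such a device-conditional dtype list can drive a check; the HAS_CUDA / HAS_XPU detection and the check_matmul_dtype helper are assumptions introduced here for illustration, not the helpers the test suite actually imports.

import torch

# Simplified device/dtype selection mirroring the structure of the changed
# test setup; the real tests use their own platform-detection utilities.
HAS_CUDA = torch.cuda.is_available()
HAS_XPU = hasattr(torch, "xpu") and torch.xpu.is_available()

if HAS_CUDA:
    test_device = "cuda"
    test_dtypes = [torch.float32, torch.bfloat16, torch.float16]
elif HAS_XPU:
    test_device = "xpu"
    # After this commit, the XPU path exercises all three dtypes as well.
    test_dtypes = [torch.float32, torch.bfloat16, torch.float16]
else:
    test_device = "cpu"
    test_dtypes = [torch.float32]


def check_matmul_dtype(dtype: torch.dtype) -> None:
    # Toy check: a matmul on the selected device preserves the dtype,
    # i.e. the dtype is actually usable there.
    a = torch.randn(8, 8, device=test_device, dtype=dtype)
    b = torch.randn(8, 8, device=test_device, dtype=dtype)
    assert (a @ b).dtype == dtype


for dt in test_dtypes:
    check_matmul_dtype(dt)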
