fix merged uts · pytorch/pytorch@377eef5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 377eef5

Browse files
committed
fix merged uts
1 parent a12cc6f commit 377eef5

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

test/inductor/test_flex_attention.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2973,6 +2973,7 @@ def head_bias(score, b, h, q_idx, kv_idx):
29732973
bias_sdpa_ref,
29742974
implicit_bias_sdpa_gold,
29752975
bias_sdpa_gold,
2976+
device=device,
29762977
)
29772978

29782979
@supported_platform
@@ -3008,6 +3009,7 @@ def rel_pos_1d(score, b, h, q_idx, kv_idx):
30083009
bias_sdpa_ref,
30093010
implicit_bias_sdpa_gold,
30103011
bias_sdpa_gold,
3012+
device=device,
30113013
)
30123014

30133015
# 2-dimensional bias:
@@ -3037,6 +3039,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
30373039
bias_sdpa_ref,
30383040
implicit_bias_sdpa_gold,
30393041
bias_sdpa_gold,
3042+
device=device,
30403043
)
30413044

30423045
# 2-dimensional bias + index multiple
@@ -3066,6 +3069,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
30663069
bias_sdpa_ref,
30673070
implicit_bias_sdpa_gold,
30683071
bias_sdpa_gold,
3072+
device=device,
30693073
)
30703074

30713075
# 2-dimensional bias + transposed:
@@ -3095,6 +3099,7 @@ def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx):
30953099
bias_sdpa_ref,
30963100
implicit_bias_sdpa_gold,
30973101
bias_sdpa_gold,
3102+
device=device,
30983103
)
30993104

31003105
# 3-dimens E174 ional bias + transposed
@@ -3126,11 +3131,11 @@ def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx):
31263131
bias_sdpa_ref,
31273132
implicit_bias_sdpa_gold,
31283133
bias_sdpa_gold,
3134+
device=device,
31293135
)
31303136

31313137
def _test_learnable_bias_inner(
31323138
self,
3133-
device,
31343139
B,
31353140
H,
31363141
S,
@@ -3141,6 +3146,7 @@ def _test_learnable_bias_inner(
31413146
bias_sdpa_ref,
31423147
implicit_bias_sdpa_gold,
31433148
bias_sdpa_gold,
3149+
device: str = "cuda",
31443150
):
31453151
make_tensor = functools.partial(
31463152
torch.ones,
@@ -3850,7 +3856,7 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
38503856
full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
38513857
return full
38523858
""".replace( # noqa: B950
3853-
"GPU_TYPE", device
3859+
"GPU_TYPE", torch.device(device).type
38543860
)
38553861

38563862
self.assertExpectedInline(

0 commit comments

Comments (0)