@@ -2973,6 +2973,7 @@ def head_bias(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
     @supported_platform
@@ -3008,6 +3009,7 @@ def rel_pos_1d(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
         # 2-dimensional bias:
@@ -3037,6 +3039,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
         # 2-dimensional bias + index multiple
@@ -3066,6 +3069,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
         # 2-dimensional bias + transposed:
@@ -3095,6 +3099,7 @@ def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
         # 3-dimensional bias + transposed
@@ -3126,11 +3131,11 @@ def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx):
             bias_sdpa_ref,
             implicit_bias_sdpa_gold,
             bias_sdpa_gold,
+            device=device,
         )
 
     def _test_learnable_bias_inner(
         self,
-        device,
         B,
         H,
         S,
@@ -3141,6 +3146,7 @@ def _test_learnable_bias_inner(
         bias_sdpa_ref,
         implicit_bias_sdpa_gold,
         bias_sdpa_gold,
+        device: str = "cuda",
     ):
         make_tensor = functools.partial(
             torch.ones,
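
The hunk above moves `device` from the front of `_test_learnable_bias_inner`'s parameter list to the end, as a keyword parameter defaulting to `"cuda"`, which is why every call site earlier in the diff now passes `device=device` after the bias tensors. A minimal, self-contained sketch of that pattern (the function names and parameters here are illustrative, not the real test helper):

# Sketch of the refactor: `device` becomes a trailing keyword with a default,
# so callers that care about the device opt in with `device=device`.
def _inner_old(device, B, H, S):
    return f"old: device={device}, B={B}, H={H}, S={S}"

def _inner_new(B, H, S, device: str = "cuda"):
    return f"new: device={device}, B={B}, H={H}, S={S}"

if __name__ == "__main__":
    print(_inner_old("cuda", 2, 4, 128))          # old: device passed positionally, first
    print(_inner_new(2, 4, 128))                  # new: falls back to the "cuda" default
    print(_inner_new(2, 4, 128, device="cpu"))    # new: explicit keyword override
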
@@ -3850,7 +3856,7 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
         full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
         return full
 """.replace(  # noqa: B950
-            "GPU_TYPE", device
+            "GPU_TYPE", torch.device(device).type
         )
 
         self.assertExpectedInline(
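
The final hunk substitutes `torch.device(device).type` rather than the raw `device` value when filling in the `GPU_TYPE` placeholder of the expected graph text: the graph prints only the device type (e.g. `device(type='cuda', index=0)`), so a device string that carries an index must be normalized first. A minimal sketch of that normalization, assuming `device` may be a string such as "cuda:0" or an existing torch.device:

import torch

expected = "device = device(type='GPU_TYPE', index=0)"
for device in ("cuda:0", "cuda", torch.device("cpu")):
    # torch.device(...).type drops any index, yielding just "cuda" or "cpu".
    print(expected.replace("GPU_TYPE", torch.device(device).type))
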