8000 fix gemm_add_and_bias · pytorch/pytorch@49bd078 · GitHub
[go: up one dir, main page]

Skip to content

Commit 49bd078

Browse files
committed
fix gemm_add_and_bias
1 parent 4e94a87 commit 49bd078

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

aten/src/ATen/cuda/tunable/GemmCublasLt.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,15 @@ cudaDataType_t GetBiasTypeFromParams(const GemmParams<T>* params) {
264264

265265
template <typename T>
266266
cudaDataType_t GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
267-
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
267+
if (std::is_same_v<T, double>) {
268+
return CUDA_R_64F;
269+
} else if (std::is_same_v<T, float>) {
270+
return CUDA_R_32F;
271+
} else if (std::is_same_v<T, Half>) {
272+
return CUDA_R_16F;
273+
} else if (std::is_same_v<T, BFloat16>) {
274+
return CUDA_R_16BF;
275+
}
268276
}
269277

270278
template <typename T>
@@ -677,7 +685,7 @@ auto GetCublasLtGemmTypeStringAndOps(const GemmParams<T>* params) {
677685
}
678686

679687
template <typename T, BlasOp ALayout, BlasOp BLayout>
680-
auto GetCublasLtGemmAndBiasTypeStringAndOps(const GemmParams<T>* params) {
688+
auto GetCublasLtGemmAndBiasTypeStringAndOps(const GemmAndBiasParams<T>* params) {
681689
return GetCublasLtTypeStringAndOps<T, ALayout, BLayout, GemmAndBiasParams<T>>(params);
682690
}
683691

0 commit comments

Comments (0)
0