ROCm: Enable overload tests from test_matmul_cuda (#161540) · pytorch/pytorch@a8d6943 · GitHub
Commit a8d6943

jagadish-amd authored and pytorchmergebot committed

ROCm: Enable overload tests from test_matmul_cuda (#161540)
This patch enables hipblaslt backend tests for test_mm_bmm_dtype_overload and test_addmm_baddmm_dtype_overload. The tests were disabled as part of #150812. Rocblas backend tests are not enabled yet (WIP).

Test commands:

PYTORCH_TEST_WITH_ROCM=1 pytest test/test_matmul_cuda.py -k 'test_mm_bmm_dtype_overload' -v
PYTORCH_TEST_WITH_ROCM=1 pytest test/test_matmul_cuda.py -k 'test_addmm_baddmm_dtype_overload' -v

Pull Request resolved: #161540
Approved by: https://github.com/jeffdaily
1 parent d11720e commit a8d6943
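
The commit message above gives the pytest commands used for verification. As a standalone, hypothetical sketch (not part of this commit), the snippet below only shows how to select the Lt backend these tests target — "cublaslt", which maps to hipBLASLt on ROCm builds and cuBLASLt on CUDA builds — and sanity-checks a half-precision matmul against a float32 reference. The mixed input/output dtype overloads themselves (e.g. Half inputs with float output, as in the CUDABlas.cpp hunks below) are exercised by the tests named in the commit message.

# Minimal, standalone sketch; assumes a ROCm or CUDA build of PyTorch with a
# visible GPU. This is NOT the test file itself, just an illustration of the
# backend selection the re-enabled tests run against.
import torch

# "cublaslt" maps to hipBLASLt on ROCm builds and to cuBLASLt on CUDA builds.
torch.backends.cuda.preferred_blas_library("cublaslt")

a = torch.randn(64, 32, device="cuda", dtype=torch.float16)
b = torch.randn(32, 64, device="cuda", dtype=torch.float16)

out = torch.mm(a, b)                    # fp16 GEMM through the Lt backend
ref = torch.mm(a.float(), b.float())    # fp32 reference for comparison

torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)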

File tree

2 files changed: +6 -16 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 2 additions & 12 deletions
@@ -996,19 +996,14 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
 
 template <>
 void bgemm<at::Half, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
-#ifdef USE_ROCM
-  TORCH_CHECK(false, "bgemm input type at::Half and output type float is not supported for ROCm");
-#endif
   // TODO: Support tuning for Half inputs and FP32 output
   bgemm_internal<at::Half, float>(CUDABLAS_BGEMM_ARGS(at::Half));
 }
 
 
 template <>
 void bgemm<at::BFloat16, float>(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)) {
-#ifdef USE_ROCM
-  TORCH_CHECK(false, "bgemm input type at::BFloat16 and output type float is not supported for ROCm");
-#else
+#ifndef USE_ROCM
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
 
   if (prop->major < 8)
@@ -1513,19 +1508,14 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 
 template <>
 void gemm<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half, float)) {
-#ifdef USE_ROCM
-  TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
-#endif
   // TODO: Support Tuning for fp16-fp32 gemm
   gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGS(at::Half));
 }
 
 
 template <>
 void gemm<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::BFloat16, float)) {
-#ifdef USE_ROCM
-  TORCH_CHECK(false, "gemm input type at::BFloat16 and output type float is not supported for ROCm");
-#else
+#ifndef USE_ROCM
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
 
   if (prop->major < 8)

test/test_matmul_cuda.py

Lines changed: 4 additions & 4 deletions
@@ -612,13 +612,13 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune)
 
 
     @onlyCUDA
-    @skipIfRocm
     @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
     @parametrize("M", [1, 32, 64])
     @parametrize("N", [1, 32, 64])
     @parametrize("K", [1, 32, 64])
     @parametrize("batch_size", [None, 1, 16])
-    @parametrize("backend", ["cublas", "cublaslt"])
+    # TODO: enable rocblas path on ROCm
+    @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"])
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
@@ -667,13 +667,13 @@ def create_inputs(B=None):
 
 
     @onlyCUDA
-    @skipIfRocm
     @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
     @parametrize("M", [1, 32, 64])
     @parametrize("N", [1, 32, 64])
     @parametrize("K", [1, 32, 64])
     @parametrize("batch_size", [None, 1, 32])
-    @parametrize("backend", ["cublas", "cublaslt"])
+    # TODO: enable rocblas path on ROCm
+    @parametrize("backend", ["cublaslt"] if torch.version.hip else ["cublas", "cublaslt"])
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
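
For context on the parametrize change above, here is a tiny illustrative sketch (the variable names are mine, not from the test file) of how the gating expression behaves on ROCm versus CUDA builds:

import torch

# torch.version.hip is a version string on ROCm builds and None on CUDA builds,
# so on ROCm only the hipBLASLt-backed "cublaslt" path is exercised; the
# rocBLAS-backed "cublas" path remains a TODO, matching the diff above.
on_rocm = torch.version.hip is not None
backends_under_test = ["cublaslt"] if on_rocm else ["cublas", "cublaslt"]
print(f"ROCm build: {on_rocm}; backends under test: {backends_under_test}")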

Comments (0)