Revert "[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt set… · pytorch/pytorch@40339c1 · GitHub

Commit 40339c1

Revert "[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)"
This reverts commit 3bde364. Reverted #153655 on behalf of https://github.com/huydhn due to "Sorry for reverting your change, but it seems to fail a test in trunk" ([comment](#153655 (comment)))
1 parent 9b2a45a commit 40339c1
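
For context, the change being reverted wrapped the backend-parametrized tests in a small save/restore helper so that switching torch.backends.cuda.preferred_blas_library inside one test could not leak into tests that run later. Below is a minimal standalone sketch of that pattern, based on the helper removed in the diff further down; the usage snippet at the bottom (sizes, dtype, and the "cublaslt" choice) is illustrative only and assumes a CUDA-enabled PyTorch build.

import contextlib

import torch


@contextlib.contextmanager
def blas_library_context(backend):
    # Remember the currently preferred BLAS backend, switch to the requested
    # one for the duration of the block, and always restore the old value.
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)


# Illustrative usage: prefer cuBLASLt for one matmul without leaking the setting.
if torch.cuda.is_available():
    with blas_library_context("cublaslt"):
        a = torch.randn(32, 32, device="cuda", dtype=torch.float16)
        b = torch.randn(32, 32, device="cuda", dtype=torch.float16)
        c = a @ b

The revert removes this helper again, so the affected tests go back to calling torch.backends.cuda.preferred_blas_library(backend) directly, which changes the process-wide preference for the remainder of the test run.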

File tree

1 file changed: +97 −117 lines


test/test_matmul_cuda.py

Lines changed: 97 additions & 117 deletions
@@ -62,15 +62,6 @@
 assert torch.get_default_dtype() is torch.float32
 
 
-@contextlib.contextmanager
-def blas_library_context(backend):
-    prev_backend = torch.backends.cuda.preferred_blas_library()
-    torch.backends.cuda.preferred_blas_library(backend)
-    try:
-        yield
-    finally:
-        torch.backends.cuda.preferred_blas_library(prev_backend)
-
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -150,10 +141,8 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool =
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, False)
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, False)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -162,10 +151,8 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, True)
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, True)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -174,10 +161,8 @@ def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, bac
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, False, True)
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, False, True)
 
     @onlyCUDA
     @skipIfRocm
@@ -471,48 +456,49 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        with blas_library_context(backend):
-            def create_inputs(B=None):
-                if B is None:
-                    a = torch.randn(M, K, device=device, dtype=dtype)
-                    b = torch.randn(K, N, device=device, dtype=dtype)
-                else:
-                    a = torch.randn(B, M, K, device=device, dtype=dtype)
-                    b = torch.randn(B, K, N, device=device, dtype=dtype)
-                return a, b
+        torch.backends.cuda.preferred_blas_library(backend)
 
-            a, b = create_inputs(batch_size)
+        def create_inputs(B=None):
+            if B is None:
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+            else:
+                a = torch.randn(B, M, K, device=device, dtype=dtype)
+                b = torch.randn(B, K, N, device=device, dtype=dtype)
+            return a, b
+
+        a, b = create_inputs(batch_size)
 
-            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
 
-            output_dtypes = [torch.float32]
+        output_dtypes = [torch.float32]
 
-            if input_dtype != torch.float32:
-                output_dtypes.append(input_dtype)
+        if input_dtype != torch.float32:
+            output_dtypes.append(input_dtype)
 
-            for output_dtype in output_dtypes:
-                # Catch edge case of incompat with bfloat16 and major version < 8
-                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                    if output_dtype == torch.bfloat16:
-                        continue
+        for output_dtype in output_dtypes:
+            # Catch edge case of incompat with bfloat16 and major version < 8
+            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                if output_dtype == torch.bfloat16:
+                    continue
 
-                    if batch_size:
-                        with self.assertRaises(RuntimeError):
-                            torch.bmm(a, b, out_dtype=output_dtype)
-                    else:
-                        with self.assertRaises(RuntimeError):
-                            torch.mm(a, b, out_dtype=output_dtype)
+                if batch_size:
+                    with self.assertRaises(RuntimeError):
+                        torch.bmm(a, b, out_dtype=output_dtype)
                 else:
-                    if batch_size:
-                        out = torch.bmm(a, b, out_dtype=output_dtype)
-                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                    else:
-                        out = torch.mm(a, b, out_dtype=output_dtype)
-                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+                    with self.assertRaises(RuntimeError):
+                        torch.mm(a, b, out_dtype=output_dtype)
+            else:
+                if batch_size:
+                    out = torch.bmm(a, b, out_dtype=output_dtype)
+                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                else:
+                    out = torch.mm(a, b, out_dtype=output_dtype)
+                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
 
-                    self.assertEqual(out.dtype, output_dtype)
+                self.assertEqual(out.dtype, output_dtype)
 
-                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -526,56 +512,51 @@ def create_inputs(B=None):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        with blas_library_context(backend):
-            def create_inputs(B=None):
-                if B is None:
-                    a = torch.randn(M, K, device=device, dtype=dtype)
-                    b = torch.randn(K, N, device=device, dtype=dtype)
-                    c = torch.randn(M, N, device=device, dtype=dtype)
-                else:
-                    a = torch.randn(B, M, K, device=device, dtype=dtype)
-                    b = torch.randn(B, K, N, device=device, dtype=dtype)
-                    c = torch.randn(B, M, N, device=device, dtype=dtype)
+        torch.backends.cuda.preferred_blas_library(backend)
 
-                return a, b, c
+        def create_inputs(B=None):
+            if B is None:
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+            else:
+                a = torch.randn(B, M, K, device=device, dtype=dtype)
+                b = torch.randn(B, K, N, device=device, dtype=dtype)
+                c = torch.randn(B, M, N, device=device, dtype=dtype)
 
-            a, b, c = create_inputs(batch_size)
+            return a, b, c
 
-            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+        a, b, c = create_inputs(batch_size)
 
-            output_dtypes = [torch.float32]
+        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
 
-            if input_dtype != torch.float32:
-                output_dtypes.append(input_dtype)
+        output_dtypes = [torch.float32]
 
-            for output_dtype in output_dtypes:
-                # Catch edge case of incompat with bfloat16 and major version < 8
-                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                    if output_dtype == torch.bfloat16:
-                        continue
+        if input_dtype != torch.float32:
+            output_dtypes.append(input_dtype)
 
-                    if batch_size:
-                        with self.assertRaises(RuntimeError):
-                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    else:
-                        with self.assertRaises(RuntimeError):
-                            torch.addmm(c, a, b, out_dtype=output_dtype)
+        for output_dtype in output_dtypes:
+            # Catch edge case of incompat with bfloat16 and major version < 8
+            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                if output_dtype == torch.bfloat16:
+                    continue
+
+                if batch_size:
+                    with self.assertRaises(RuntimeError):
+                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                else:
+                    with self.assertRaises(RuntimeError):
+                        torch.addmm(c, a, b, out_dtype=output_dtype)
+            else:
+                if batch_size:
+                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
                 else:
-                    if batch_size:
-                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                        if output_dtype == torch.float32:
-                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
-                        else:
-                            baseline = torch.baddbmm(c, a, b)
-                    else:
-                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                        if output_dtype == torch.float32:
-                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
-                        else:
-                            baseline = torch.addmm(c, a, b)
-
-                    self.assertEqual(out.dtype, output_dtype)
-                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
+
+                self.assertEqual(out.dtype, output_dtype)
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -586,36 +567,35 @@ def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        with blas_library_context(backend):
-            torch.backends.cuda.preferred_blas_library(backend)
+        torch.backends.cuda.preferred_blas_library(backend)
 
-            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+        torch.backends.cuda.matmul.allow_fp16_accumulation = True
 
-            def create_inputs():
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-                return a, b, c
+        def create_inputs():
+            a = torch.randn(M, K, device=device, dtype=dtype)
+            b = torch.randn(K, N, device=device, dtype=dtype)
+            c = torch.randn(M, N, device=device, dtype=dtype)
+            return a, b, c
 
-            def expand(tensor):
-                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+        def expand(tensor):
+            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
 
-            a, b, c = create_inputs()
+        a, b, c = create_inputs()
 
-            with self.assertRaises(Exception):
-                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
 
-            with self.assertRaises(Exception):
-                torch.addmm(c, a, b, out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.addmm(c, a, b, out_dtype=torch.float32)
 
-            with self.assertRaises(Exception):
-                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
 
-            with self.assertRaises(Exception):
-                torch.mm(a, b, out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.mm(a, b, out_dtype=torch.float32)
 
-            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
 
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
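
With the revert applied, the dtype-overload and fp16-accumulation tests above set the preferred backend directly and never restore it, which is the cross-test leakage the reverted PR tried to avoid. When debugging ordering-dependent failures like the one cited in the revert message, it can help to query and reset the process-wide setting by hand; a small sketch (the no-argument query form is the same call the removed helper used, and the returned value can be passed back in to restore it):

import torch

# Query the current process-wide preference without changing it.
current = torch.backends.cuda.preferred_blas_library()
print(f"preferred BLAS backend: {current}")

# Switch the preference explicitly, then put the queried value back.
torch.backends.cuda.preferred_blas_library("cublaslt")
torch.backends.cuda.preferred_blas_library(current)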
