[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests · pytorch/pytorch@3bde364
Commit 3bde364

eqy authored and pytorchmergebot committed
[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)
Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run together vs. standalone. Tests that should exercise both backends should explicitly parametrize this setting.

Pull Request resolved: #153655
Approved by: https://github.com/ngimel
1 parent 084c4aa commit 3bde364
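The fix adds a small context manager that saves and restores the preferred BLAS backend, so a test that pins cuBLAS or cuBLASLt cannot leak that choice into later tests. Below is a minimal sketch of the pattern, mirroring the blas_library_context helper added in this diff; the standalone loop at the end is only illustrative and is not part of the test suite.

import contextlib
import torch

@contextlib.contextmanager
def blas_library_context(backend):
    # Remember whichever backend is currently preferred so it can be restored.
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        # Restore the previous preference even if the wrapped code raises.
        torch.backends.cuda.preferred_blas_library(prev_backend)

# Illustrative use: exercise both backends without polluting global state.
for backend in ("cublas", "cublaslt"):
    with blas_library_context(backend):
        a = torch.randn(128, 128, device="cuda", dtype=torch.float16)
        b = torch.randn(128, 128, device="cuda", dtype=torch.float16)
        torch.mm(a, b)  # runs with the selected backend preferred

In the test file itself, the same effect is achieved by parametrizing "backend" over ["cublas", "cublaslt"] and wrapping each test body in blas_library_context(backend), as the diff below shows.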

File tree

1 file changed (+117, -97)

test/test_matmul_cuda.py

Lines changed: 117 additions & 97 deletions
@@ -62,6 +62,15 @@
 assert torch.get_default_dtype() is torch.float32
 
 
+@contextlib.contextmanager
+def blas_library_context(backend):
+    prev_backend = torch.backends.cuda.preferred_blas_library()
+    torch.backends.cuda.preferred_blas_library(backend)
+    try:
+        yield
+    finally:
+        torch.backends.cuda.preferred_blas_library(prev_backend)
+
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -141,8 +150,10 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool =
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -151,8 +162,10 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, True)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -161,8 +174,10 @@ def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False, True)
 
     @onlyCUDA
     @skipIfRocm
@@ -456,49 +471,48 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-            return a, b
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                return a, b
 
-        a, b = create_inputs(batch_size)
+            a, b = create_inputs(batch_size)
 
-        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
 
-        output_dtypes = [torch.float32]
+            output_dtypes = [torch.float32]
 
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
 
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
 
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.bmm(a, b, out_dtype=output_dtype)
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.bmm(a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.mm(a, b, out_dtype=output_dtype)
                 else:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.bmm(a, b, out_dtype=output_dtype)
-                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                else:
-                    out = torch.mm(a, b, out_dtype=output_dtype)
-                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+                    if batch_size:
+                        out = torch.bmm(a, b, out_dtype=output_dtype)
+                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    else:
+                        out = torch.mm(a, b, out_dtype=output_dtype)
+                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
 
-                self.assertEqual(out.dtype, output_dtype)
+                    self.assertEqual(out.dtype, output_dtype)
 
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -512,51 +526,56 @@ def create_inputs(B=None):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-                c = torch.randn(B, M, N, device=device, dtype=dtype)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                    c = torch.randn(M, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                    c = torch.randn(B, M, N, device=device, dtype=dtype)
 
-            return a, b, c
+                return a, b, c
 
-        a, b, c = create_inputs(batch_size)
+            a, b, c = create_inputs(batch_size)
 
-        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
 
-        output_dtypes = [torch.float32]
+            output_dtypes = [torch.float32]
 
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
 
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
 
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.addmm(c, a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.addmm(c, a, b, out_dtype=output_dtype)
                 else:
-                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
-
-                self.assertEqual(out.dtype, output_dtype)
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    if batch_size:
+                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.baddbmm(c, a, b)
+                    else:
+                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.addmm(c, a, b)
+
+                    self.assertEqual(out.dtype, output_dtype)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -567,35 +586,36 @@ def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        torch.backends.cuda.preferred_blas_library(backend)
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
 
-        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-        torch.backends.cuda.matmul.allow_fp16_accumulation = True
+            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
 
-        def create_inputs():
-            a = torch.randn(M, K, device=device, dtype=dtype)
-            b = torch.randn(K, N, device=device, dtype=dtype)
-            c = torch.randn(M, N, device=device, dtype=dtype)
-            return a, b, c
+            def create_inputs():
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+                return a, b, c
 
-        def expand(tensor):
-            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+            def expand(tensor):
+                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
 
-        a, b, c = create_inputs()
+            a, b, c = create_inputs()
 
-        with self.assertRaises(Exception):
-            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.addmm(c, a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.addmm(c, a, b, out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.mm(a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.mm(a, b, out_dtype=torch.float32)
 
-        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
 
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
