 assert torch.get_default_dtype() is torch.float32


+@contextlib.contextmanager
+def blas_library_context(backend):
+    prev_backend = torch.backends.cuda.preferred_blas_library()
+    torch.backends.cuda.preferred_blas_library(backend)
+    try:
+        yield
+    finally:
+        torch.backends.cuda.preferred_blas_library(prev_backend)
+
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -141,8 +150,10 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool =
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -151,8 +162,10 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, True)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -161,8 +174,10 @@ def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False, True)

     @onlyCUDA
     @skipIfRocm
@@ -456,49 +471,48 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-            return a, b
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                return a, b

-        a, b = create_inputs(batch_size)
+            a, b = create_inputs(batch_size)

-        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)

-        output_dtypes = [torch.float32]
+            output_dtypes = [torch.float32]

-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)

-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue

-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.bmm(a, b, out_dtype=output_dtype)
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.bmm(a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.mm(a, b, out_dtype=output_dtype)
                 else:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.bmm(a, b, out_dtype=output_dtype)
-                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                else:
-                    out = torch.mm(a, b, out_dtype=output_dtype)
-                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+                    if batch_size:
+                        out = torch.bmm(a, b, out_dtype=output_dtype)
+                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    else:
+                        out = torch.mm(a, b, out_dtype=output_dtype)
+                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)

-                self.assertEqual(out.dtype, output_dtype)
+                    self.assertEqual(out.dtype, output_dtype)

-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


     @onlyCUDA
@@ -512,51 +526,56 @@ def create_inputs(B=None):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-                c = torch.randn(B, M, N, device=device, dtype=dtype)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                    c = torch.randn(M, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                    c = torch.randn(B, M, N, device=device, dtype=dtype)

-            return a, b, c
+                return a, b, c

-        a, b, c = create_inputs(batch_size)
+            a, b, c = create_inputs(batch_size)

-        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)

-        output_dtypes = [torch.float32]
+            output_dtypes = [torch.float32]

-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)

-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue

-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.addmm(c, a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.addmm(c, a, b, out_dtype=output_dtype)
                 else:
-                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
-
-                self.assertEqual(out.dtype, output_dtype)
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    if batch_size:
+                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.baddbmm(c, a, b)
+                    else:
+                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.addmm(c, a, b)
+
+                    self.assertEqual(out.dtype, output_dtype)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


     @onlyCUDA
@@ -567,35 +586,36 @@ def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        torch.backends.cuda.preferred_blas_library(backend)
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)

-        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-        torch.backends.cuda.matmul.allow_fp16_accumulation = True
+            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True

-        def create_inputs():
-            a = torch.randn(M, K, device=device, dtype=dtype)
-            b = torch.randn(K, N, device=device, dtype=dtype)
-            c = torch.randn(M, N, device=device, dtype=dtype)
-            return a, b, c
+            def create_inputs():
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+                return a, b, c

-        def expand(tensor):
-            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+            def expand(tensor):
+                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)

-        a, b, c = create_inputs()
+            a, b, c = create_inputs()

-        with self.assertRaises(Exception):
-            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)

-        with self.assertRaises(Exception):
-            torch.addmm(c, a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.addmm(c, a, b, out_dtype=torch.float32)

-        with self.assertRaises(Exception):
-            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)

-        with self.assertRaises(Exception):
-            torch.mm(a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.mm(a, b, out_dtype=torch.float32)

-        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum

 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
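The hunks above replace bare calls to torch.backends.cuda.preferred_blas_library(backend) with the new blas_library_context helper, so the previously preferred BLAS backend is restored even when an assertion fails mid-test. Below is a minimal, self-contained sketch of the intended usage; the tensor shapes, dtype, and the CUDA-availability guard are illustrative assumptions, not taken from the test file.

    import contextlib

    import torch


    @contextlib.contextmanager
    def blas_library_context(backend):
        # Remember the currently preferred CUDA BLAS library, switch to `backend`,
        # and restore the previous preference on exit, even if the body raises.
        prev_backend = torch.backends.cuda.preferred_blas_library()
        torch.backends.cuda.preferred_blas_library(backend)
        try:
            yield
        finally:
            torch.backends.cuda.preferred_blas_library(prev_backend)


    if torch.cuda.is_available():
        # Illustrative shapes/dtype only.
        a = torch.randn(64, 32, device="cuda", dtype=torch.float16)
        b = torch.randn(32, 16, device="cuda", dtype=torch.float16)
        with blas_library_context("cublaslt"):
            out = a @ b  # matmul dispatched while cuBLASLt is the preferred backend
        # After the block, the prior preference is in effect again.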