 assert torch.get_default_dtype() is torch.float32


-@contextlib.contextmanager
-def blas_library_context(backend):
-    prev_backend = torch.backends.cuda.preferred_blas_library()
-    torch.backends.cuda.preferred_blas_library(backend)
-    try:
-        yield
-    finally:
-        torch.backends.cuda.preferred_blas_library(prev_backend)
-
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -150,10 +141,8 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool =
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, False)
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, False)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -162,10 +151,8 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, True)
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, True)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -174,10 +161,8 @@ def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, bac
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    @parametrize("backend", ["cublas", "cublaslt"])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
-        with blas_library_context(backend):
-            self.cublas_addmm(size, dtype, False, True)
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
+        self.cublas_addmm(size, dtype, False, True)

     @onlyCUDA
     @skipIfRocm
@@ -471,48 +456,49 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        with blas_library_context(backend):
-            def create_inputs(B=None):
-                if B is None:
-                    a = torch.randn(M, K, device=device, dtype=dtype)
-                    b = torch.randn(K, N, device=device, dtype=dtype)
-                else:
-                    a = torch.randn(B, M, K, device=device, dtype=dtype)
-                    b = torch.randn(B, K, N, device=device, dtype=dtype)
-                return a, b
+        torch.backends.cuda.preferred_blas_library(backend)

-            a, b = create_inputs(batch_size)
+        def create_inputs(B=None):
+            if B is None:
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+            else:
+                a = torch.randn(B, M, K, device=device, dtype=dtype)
+                b = torch.randn(B, K, N, device=device, dtype=dtype)
+            return a, b
+
+        a, b = create_inputs(batch_size)

-            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)

-            output_dtypes = [torch.float32]
+        output_dtypes = [torch.float32]

-            if input_dtype != torch.float32:
-                output_dtypes.append(input_dtype)
+        if input_dtype != torch.float32:
+            output_dtypes.append(input_dtype)

-            for output_dtype in output_dtypes:
-                # Catch edge case of incompat with bfloat16 and major version < 8
-                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                    if output_dtype == torch.bfloat16:
-                        continue
+        for output_dtype in output_dtypes:
+            # Catch edge case of incompat with bfloat16 and major version < 8
+            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                if output_dtype == torch.bfloat16:
+                    continue

-                    if batch_size:
-                        with self.assertRaises(RuntimeError):
-                            torch.bmm(a, b, out_dtype=output_dtype)
-                    else:
-                        with self.assertRaises(RuntimeError):
-                            torch.mm(a, b, out_dtype=output_dtype)
+                if batch_size:
+                    with self.assertRaises(RuntimeError):
+                        torch.bmm(a, b, out_dtype=output_dtype)
                 else:
-                    if batch_size:
-                        out = torch.bmm(a, b, out_dtype=output_dtype)
-                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                    else:
-                        out = torch.mm(a, b, out_dtype=output_dtype)
-                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+                    with self.assertRaises(RuntimeError):
+                        torch.mm(a, b, out_dtype=output_dtype)
+            else:
+                if batch_size:
+                    out = torch.bmm(a, b, out_dtype=output_dtype)
+                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                else:
+                    out = torch.mm(a, b, out_dtype=output_dtype)
+                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)

-                    self.assertEqual(out.dtype, output_dtype)
+                self.assertEqual(out.dtype, output_dtype)

-                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


     @onlyCUDA
@@ -526,56 +512,51 @@ def create_inputs(B=None):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        with blas_library_context(backend):
-            def create_inputs(B=None):
-                if B is None:
-                    a = torch.randn(M, K, device=device, dtype=dtype)
-                    b = torch.randn(K, N, device=device, dtype=dtype)
-                    c = torch.randn(M, N, device=device, dtype=dtype)
-                else:
-                    a = torch.randn(B, M, K, device=device, dtype=dtype)
-                    b = torch.randn(B, K, N, device=device, dtype=dtype)
-                    c = torch.randn(B, M, N, device=device, dtype=dtype)
+        torch.backends.cuda.preferred_blas_library(backend)

-                return a, b, c
+        def create_inputs(B=None):
+            if B is None:
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+            else:
+                a = torch.randn(B, M, K, device=device, dtype=dtype)
+                b = torch.randn(B, K, N, device=device, dtype=dtype)
+                c = torch.randn(B, M, N, device=device, dtype=dtype)

-            a, b, c = create_inputs(batch_size)
+            return a, b, c

-            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+        a, b, c = create_inputs(batch_size)

-            output_dtypes = [torch.float32]
+        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)

-            if input_dtype != torch.float32:
-                output_dtypes.append(input_dtype)
+        output_dtypes = [torch.float32]

-            for output_dtype in output_dtypes:
-                # Catch edge case of incompat with bfloat16 and major version < 8
-                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                    if output_dtype == torch.bfloat16:
-                        continue
+        if input_dtype != torch.float32:
+            output_dtypes.append(input_dtype)

-                    if batch_size:
-                        with self.assertRaises(RuntimeError):
-                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    else:
-                        with self.assertRaises(RuntimeError):
-                            torch.addmm(c, a, b, out_dtype=output_dtype)
+        for output_dtype in output_dtypes:
+            # Catch edge case of incompat with bfloat16 and major version < 8
+            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                if output_dtype == torch.bfloat16:
+                    continue
+
+                if batch_size:
+                    with self.assertRaises(RuntimeError):
+                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                else:
+                    with self.assertRaises(RuntimeError):
+                        torch.addmm(c, a, b, out_dtype=output_dtype)
+            else:
+                if batch_size:
+                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
                 else:
-                    if batch_size:
-                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                        if output_dtype == torch.float32:
-                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
-                        else:
-                            baseline = torch.baddbmm(c, a, b)
-                    else:
-                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                        if output_dtype == torch.float32:
-                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
-                        else:
-                            baseline = torch.addmm(c, a, b)
-
-                    self.assertEqual(out.dtype, output_dtype)
-                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
+
+                self.assertEqual(out.dtype, output_dtype)
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


     @onlyCUDA
@@ -586,36 +567,35 @@ def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        with blas_library_context(backend):
-            torch.backends.cuda.preferred_blas_library(backend)
+        torch.backends.cuda.preferred_blas_library(backend)

-            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+        torch.backends.cuda.matmul.allow_fp16_accumulation = True

-            def create_inputs():
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-                return a, b, c
+        def create_inputs():
+            a = torch.randn(M, K, device=device, dtype=dtype)
+            b = torch.randn(K, N, device=device, dtype=dtype)
+            c = torch.randn(M, N, device=device, dtype=dtype)
+            return a, b, c

-            def expand(tensor):
-                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+        def expand(tensor):
+            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)

-            a, b, c = create_inputs()
+        a, b, c = create_inputs()

-            with self.assertRaises(Exception):
-                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)

-            with self.assertRaises(Exception):
-                torch.addmm(c, a, b, out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.addmm(c, a, b, out_dtype=torch.float32)

-            with self.assertRaises(Exception):
-                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)

-            with self.assertRaises(Exception):
-                torch.mm(a, b, out_dtype=torch.float32)
+        with self.assertRaises(Exception):
+            torch.mm(a, b, out_dtype=torch.float32)

-            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum

 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
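For reference, below is a minimal sketch of the helper this diff removes, reconstructed from the deleted lines above. It saved the current cuBLAS/cuBLASLt preference, switched to the requested backend, and restored the saved value on exit. After this change the tests instead call torch.backends.cuda.preferred_blas_library(backend) directly and do not restore the previous preference.

    # Sketch reconstructed from the removed lines above; not part of the new file.
    import contextlib

    import torch


    @contextlib.contextmanager
    def blas_library_context(backend):
        # Save the current cuBLAS/cuBLASLt preference, switch to `backend`,
        # and restore the saved preference when the block exits.
        prev_backend = torch.backends.cuda.preferred_blas_library()
        torch.backends.cuda.preferred_blas_library(backend)
        try:
            yield
        finally:
            torch.backends.cuda.preferred_blas_library(prev_backend)


    # After this commit the tests set the preference directly instead, e.g.:
    #     torch.backends.cuda.preferred_blas_library("cublaslt")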