[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests by eqy · Pull Request #153655 · pytorch/pytorch · GitHub

[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests #153655


Closed
wants to merge 6 commits into from
240 changes: 129 additions & 111 deletions test/test_matmul_cuda.py
@@ -62,6 +62,15 @@
assert torch.get_default_dtype() is torch.float32


@contextlib.contextmanager
def blas_library_context(backend):
prev_backend = torch.backends.cuda.preferred_blas_library()
torch.backends.cuda.preferred_blas_library(backend)
try:
yield
finally:
torch.backends.cuda.preferred_blas_library(prev_backend)
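
# Illustrative usage sketch (not part of the diff): preferred_blas_library()
# acts as a getter when called with no argument and as a setter when passed a
# backend such as "cublas" or "cublaslt", which is what lets the context
# manager above snapshot and restore the process-wide preference even when the
# wrapped test body raises.
def _blas_context_usage_example():  # hypothetical helper, for illustration only
    a = torch.randn(64, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(64, 64, device="cuda", dtype=torch.float16)
    with blas_library_context("cublaslt"):
        out = torch.mm(a, b)  # matmul dispatched with cuBLASLt preferred
    # On exit the previous preference is restored, so later tests are unaffected.
    return out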

class TestMatmulCuda(TestCase):
def setUp(self):
super().setUp()
@@ -141,8 +150,10 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool =
torch.float32: xtol(atol=1e-1, rtol=1e-1)})
@dtypes(torch.float16, torch.bfloat16, torch.float32)
@parametrize("size", [100, 1000, 10000])
def test_cublas_addmm(self, size: int, dtype: torch.dtype):
self.cublas_addmm(size, dtype, False)
@parametrize("backend", ["cublas", "cublaslt"])
def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
with blas_library_context(backend):
self.cublas_addmm(size, dtype, False)

@onlyCUDA
@skipIfRocmVersionLessThan((5, 2))
@@ -151,31 +162,31 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype):
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
@dtypes(torch.float16, torch.bfloat16)
@parametrize("size", [100, 1000, 10000])
def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
self.cublas_addmm(size, dtype, True)
@parametrize("backend", ["cublas", "cublaslt"])
def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
with blas_library_context(backend):
self.cublas_addmm(size, dtype, True)

@onlyCUDA
@skipIfRocmVersionLessThan((5, 2))
@dtypes(torch.float16)
# m == 4 chooses OUTPUT_TYPE reduction on H200
# m == 8 chooses OUTOUT_TYPE reduction on A100
# m == 8 chooses OUTPUT_TYPE reduction on A100
@parametrize("small_size", [4, 8])
@parametrize("size", [32768])
@parametrize("backend", ["cublaslt", "cublas"])
def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dtype: torch.dtype, backend):
# TODO(eqy): replace with contextlib once that is merged
orig = torch.backends.cuda.preferred_blas_library()
torch.backends.cuda.preferred_blas_library(backend)
orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
m2[size // 2:, :] = -1.0
b = torch.zeros((small_size,), dtype=dtype, device='cuda')
out = torch.addmm(b, m1, m2, beta=1.0)
self.assertEqual(out.sum().item(), 0.0)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
torch.backends.cuda.preferred_blas_library(orig)
with blas_library_context(backend):
torch.backends.cuda.preferred_blas_library(backend)
orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
m2[size // 2:, :] = -1.0
b = torch.zeros((small_size,), dtype=dtype, device='cuda')
out = torch.addmm(b, m1, m2, beta=1.0)
self.assertEqual(out.sum().item(), 0.0)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
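
# Explanatory note (not part of the diff): each output element above sums
# size/2 terms of +65504 followed by size/2 terms of -65504. With a
# full-precision (fp32) reduction the peak partial sum, 65504 * 16384,
# roughly 1.07e9, fits comfortably in fp32, so the two halves cancel to 0.0.
# With a reduced-precision OUTPUT_TYPE (fp16) reduction the partial sums
# would exceed the fp16 maximum of 65504 and overflow, the cancellation would
# be lost, and the assertion above would fail; that is what disabling
# allow_fp16_reduced_precision_reduction guards against here.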

@onlyCUDA
@skipIfRocmVersionLessThan((5, 2))
@@ -184,8 +195,10 @@ def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dty
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
@dtypes(torch.float16, torch.bfloat16)
@parametrize("size", [100, 1000, 10000])
def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
self.cublas_addmm(size, dtype, False, True)
@parametrize("backend", ["cublas", "cublaslt"])
def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
with blas_library_context(backend):
self.cublas_addmm(size, dtype, False, True)

@onlyCUDA
@skipIfRocm
@@ -479,49 +492,48 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
device = "cuda"
dtype = input_dtype
torch.backends.cuda.preferred_blas_library(backend)

def create_inputs(B=None):
if B is None:
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
else:
a = torch.randn(B, M, K, device=device, dtype=dtype)
b = torch.randn(B, K, N, device=device, dtype=dtype)
return a, b
with blas_library_context(backend):
def create_inputs(B=None):
if B is None:
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
else:
a = torch.randn(B, M, K, device=device, dtype=dtype)
b = torch.randn(B, K, N, device=device, dtype=dtype)
return a, b

a, b = create_inputs(batch_size)
a, b = create_inputs(batch_size)

a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)

output_dtypes = [torch.float32]
output_dtypes = [torch.float32]

if input_dtype != torch.float32:
output_dtypes.append(input_dtype)
if input_dtype != torch.float32:
output_dtypes.append(input_dtype)

for output_dtype in output_dtypes:
# Catch edge case of incompat with bfloat16 and major version < 8
if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
if output_dtype == torch.bfloat16:
continue
for output_dtype in output_dtypes:
# Catch edge case of incompat with bfloat16 and major version < 8
if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
if output_dtype == torch.bfloat16:
continue

if batch_size:
with self.assertRaises(RuntimeError):
torch.bmm(a, b, out_dtype=output_dtype)
if batch_size:
with self.assertRaises(RuntimeError):
torch.bmm(a, b, out_dtype=output_dtype)
else:
with self.assertRaises(RuntimeError):
torch.mm(a, b, out_dtype=output_dtype)
else:
with self.assertRaises(RuntimeError):
torch.mm(a, b, out_dtype=output_dtype)
else:
if batch_size:
out = torch.bmm(a, b, out_dtype=output_dtype)
baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
else:
out = torch.mm(a, b, out_dtype=output_dtype)
baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
if batch_size:
out = torch.bmm(a, b, out_dtype=output_dtype)
baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
else:
out = torch.mm(a, b, out_dtype=output_dtype)
baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)

self.assertEqual(out.dtype, output_dtype)
self.assertEqual(out.dtype, output_dtype)

torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


@onlyCUDA
@@ -535,51 +547,56 @@ def create_inputs(B=None):
def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
device = "cuda"
dtype = input_dtype
torch.backends.cuda.preferred_blas_library(backend)

def create_inputs(B=None):
if B is None:
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
c = torch.randn(M, N, device=device, dtype=dtype)
else:
a = torch.randn(B, M, K, device=device, dtype=dtype)
b = torch.randn(B, K, N, device=device, dtype=dtype)
c = torch.randn(B, M, N, device=device, dtype=dtype)
with blas_library_context(backend):
def create_inputs(B=None):
if B is None:
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
c = torch.randn(M, N, device=device, dtype=dtype)
else:
a = torch.randn(B, M, K, device=device, dtype=dtype)
b = torch.randn(B, K, N, device=device, dtype=dtype)
c = torch.randn(B, M, N, device=device, dtype=dtype)

return a, b, c
return a, b, c

a, b, c = create_inputs(batch_size)
a, b, c = create_inputs(batch_size)

a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)

output_dtypes = [torch.float32]
output_dtypes = [torch.float32]

if input_dtype != torch.float32:
output_dtypes.append(input_dtype)
if input_dtype != torch.float32:
output_dtypes.append(input_dtype)

for output_dtype in output_dtypes:
# Catch edge case of incompat with bfloat16 and major version < 8
if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
if output_dtype == torch.bfloat16:
continue
for output_dtype in output_dtypes:
# Catch edge case of incompat with bfloat16 and major version < 8
if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
if output_dtype == torch.bfloat16:
continue

if batch_size:
with self.assertRaises(RuntimeError):
torch.baddbmm(c, a, b, out_dtype=output_dtype)
else:
with self.assertRaises(RuntimeError):
torch.addmm(c, a, b, out_dtype=output_dtype)
else:
if batch_size:
out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
if batch_size:
with self.assertRaises(RuntimeError):
torch.baddbmm(c, a, b, out_dtype=output_dtype)
else:
with self.assertRaises(RuntimeError):
torch.addmm(c, a, b, out_dtype=output_dtype)
else:
out = torch.addmm(c, a, b, out_dtype=output_dtype)
baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)

self.assertEqual(out.dtype, output_dtype)
torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
if batch_size:
out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
if output_dtype == torch.float32:
baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
else:
baseline = torch.baddbmm(c, a, b)
else:
out = torch.addmm(c, a, b, out_dtype=output_dtype)
if output_dtype == torch.float32:
baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
else:
baseline = torch.addmm(c, a, b)

self.assertEqual(out.dtype, output_dtype)
torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)


@onlyCUDA
@@ -590,35 +607,36 @@ def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
M, N, K = 32, 32, 32
device = "cuda"
dtype = torch.float16
torch.backends.cuda.preferred_blas_library(backend)
with blas_library_context(backend):
torch.backends.cuda.preferred_blas_library(backend)

orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_fp16_accumulation = True
orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
torch.backends.cuda.matmul.allow_fp16_accumulation = True

def create_inputs():
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
c = torch.randn(M, N, device=device, dtype=dtype)
return a, b, c
def create_inputs():
a = torch.randn(M, K, device=device, dtype=dtype)
b = torch.randn(K, N, device=device, dtype=dtype)
c = torch.randn(M, N, device=device, dtype=dtype)
return a, b, c

def expand(tensor):
return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
def expand(tensor):
return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)

a, b, c = create_inputs()
a, b, c = create_inputs()

with self.assertRaises(Exception):
torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
with self.assertRaises(Exception):
torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)

with self.assertRaises(Exception):
torch.addmm(c, a, b, out_dtype=torch.float32)
with self.assertRaises(Exception):
torch.addmm(c, a, b, out_dtype=torch.float32)

with self.assertRaises(Exception):
torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
with self.assertRaises(Exception):
torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)

with self.assertRaises(Exception):
torch.mm(a, b, out_dtype=torch.float32)
with self.assertRaises(Exception):
torch.mm(a, b, out_dtype=torch.float32)

torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
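
# Explanatory note (not part of the diff): with allow_fp16_accumulation enabled
# the accumulator is fp16, so requesting out_dtype=torch.float32 would promise
# more precision than the accumulation can deliver; the checks above appear to
# exist so that this contradictory combination raises instead of silently
# producing fp32 outputs with fp16-level accuracy.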

f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"