error out on negative offs or on K=0 in group gemm by ngimel · Pull Request #153226 · pytorch/pytorch

error out on negative offs or on K=0 in group gemm #153226


Closed
wants to merge 4 commits

Conversation

@ngimel ngimel (Collaborator) commented May 9, 2025

Error out if K=0 in one of the grouped gemms, to avoid the hangs reported in #152668.
Also adds a meta function for _scaled_grouped_mm (TODO: do the same for _grouped_mm, unless that's done already).
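As a minimal sketch (not the code in this PR; the function name, signature, and exact checks are assumed for illustration), the meta function boils down to torch._check-style validation plus shape propagation on meta tensors:

```python
import torch

def scaled_grouped_mm_meta_sketch(mat_a, mat_b, offs, out_dtype=torch.bfloat16):
    # Assumed 2D x 3D grouped case: mat_a is (M, K), mat_b is (G, K, N),
    # and offs marks group boundaries along M.
    torch._check(offs.dtype == torch.int32, lambda: "Offsets have to be int32")
    torch._check(offs.dim() == 1, lambda: "Offsets have to be 1D")
    torch._check(
        mat_a.size(-1) == mat_b.size(-2),
        lambda: f"K mismatch: {mat_a.size(-1)} vs {mat_b.size(-2)}",
    )
    # A meta function only computes output metadata; no kernel is launched.
    return mat_a.new_empty(mat_a.size(0), mat_b.size(-1), dtype=out_dtype)

a = torch.empty(256, 64, device="meta", dtype=torch.float8_e4m3fn)
b = torch.empty(4, 64, 128, device="meta", dtype=torch.float8_e4m3fn)
offs = torch.empty(4, device="meta", dtype=torch.int32)
print(scaled_grouped_mm_meta_sketch(a, b, offs).shape)  # torch.Size([256, 128])
```

Value-dependent conditions (negative offsets, a zero-size group) can't be enforced here, since the meta function never sees the data in offs; that is what the kernel-side assert discussed below is for.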

One weird thing I'm seeing: when running all the grouped_gemm tests, I'm erroring out with

  File "/data/users/ngimel/pytorch/torch/_inductor/graph.py", line 1246, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/data/users/ngimel/pytorch/torch/_inductor/lowering.py", line 445, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 444, in tuned_scaled_grouped_mm
    if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 375, in can_use_triton_kernel
    offs is not None
  File "/home/ngimel/.conda/envs/pytorch_monarch/lib/python3.10/site-packages/sympy/core/relational.py", line 516, in __bool__
    raise TypeError("cannot determine truth value of Relational")
torch._inductor.exc.InductorError: LoweringException: TypeError: cannot determine truth value of Relational

which is weird: there's no relational for sympy to evaluate in `offs is not None`, and when this test is run separately (test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda) it passes. I suspect some autotuning cache has to be reset between runs, but I don't know what to look for.
Edit: that error is "fixed" by setting dynamic=False; now, with the correct meta function, something's wrong with dynamic shapes.
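(Context for the TypeError above, an illustration rather than code from this PR: sympy refuses to coerce an undecided symbolic comparison to a Python bool, which is easy to reproduce standalone.)

```python
import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)
rel = sympy.Eq(s0 * 16, 128)  # stays a symbolic Relational, not True/False

try:
    bool(rel)  # what Python does when a Relational lands in an `and`/`if`
except TypeError as e:
    print(e)  # cannot determine truth value of Relational
```

So presumably one of the size/stride comparisons in that multi-line condition produced a Relational under dynamic shapes, and the traceback just points at the first clause, `offs is not None`, which is itself harmless.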

cc @ptrblck @msaroufim @eqy @jerryzh168

@ngimel ngimel requested review from eqy and syed-ahmed as code owners May 9, 2025 01:43
pytorch-bot bot commented May 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153226

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 48df5d4 with merge base 9d00f2b:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ngimel ngimel added the module: cuda and release notes: cuda labels, and removed the module: cuda label, May 9, 2025
torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")

# Check matrix sizes
torch._check(
Collaborator
Should some of these be torch._check_value or torch._check_assert?

Collaborator Author
I'm using the same macros as are used in _scaled_mm meta, none of them are check_assert.
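(Side note, an illustration rather than code from this PR: torch._check raises an ordinary RuntimeError built from the lambda's message when the condition is false, so these checks surface as regular eager-mode errors.)

```python
import torch

offs = torch.zeros(4, dtype=torch.int64)  # deliberately the wrong dtype
try:
    torch._check(offs.dtype == torch.int32, lambda: "Offsets have to be int32")
except RuntimeError as e:
    print(e)  # Offsets have to be int32
```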

delta % align == 0 &&
"expected dynamic dimension byte size to be multiple of 16 \n");
delta >= 0 && delta % align == 0 &&
"expected dynamic dimension byte size to be non-negative multiple of 16 \n");
Collaborator
Does this actually print the error message, or does it just && a string onto the condition?
CUDA_KERNEL_ASSERT_MSG seems to be the macro for that.

Collaborator Author
CUDA_KERNEL_ASSERT_MSG is sugar over this; the string literal becomes part of the stringified assert expression, so the message does get printed on failure.

self.assertEqual(bgrad, b.grad)
if agrad is not None:
self.assertEqual(agrad, a.grad)
self.assertEqual(bgrad, b.grad)

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
Contributor
I'm surprised this did not fire as a type error.

I think the expressions in

and b_stride[0] == (b_shape[1] * b_shape[2])

should be using the statically_known variants from sizevars
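A hedged sketch of that pattern (a hypothetical stand-in, not inductor's actual sizevars implementation): decide the comparison statically when possible and return a plain bool, never hand a symbolic Relational to Python's `and`.

```python
import sympy

def statically_known_equals_sketch(lhs, rhs):
    # Hypothetical stand-in for sizevars.statically_known_equals: return True
    # only when equality can be decided symbolically, otherwise fall back to
    # False instead of raising on an undecided comparison.
    return sympy.simplify(lhs - rhs) == 0  # structural equality, a plain bool

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(statically_known_equals_sketch(s0 * 16, s0 * 16))  # True
print(statically_known_equals_sketch(s0 * 16, 128))      # False, no exception
```

The real helpers also consult the shape environment, but the key property is the same: they return a plain bool instead of raising "cannot determine truth value of Relational".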

@huydhn huydhn (Contributor) commented May 9, 2025

@pytorchbot rebase

@pytorchmergebot

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot

Successfully rebased ngimel/zero_k_grouped_gemm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout ngimel/zero_k_grouped_gemm && git pull --rebase)

@kwen2501 kwen2501 (Contributor) left a comment

Thanks for adding the warning!

@ngimel ngimel (Collaborator, Author) commented May 10, 2025

@pytorchbot merge -f "test failure unrelated"

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here
