error out on negative offs or on K=0 in group gemm #153226
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153226
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 48df5d4 with merge base 9d00f2b.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
torch._check(offs.dtype == torch.int, lambda: "Offsets have to be int32")

# Check matrix sizes
torch._check(
```
Should some of these be torch._check_value or torch._check_assert?
I'm using the same macros as are used in the _scaled_mm meta; none of them are check_assert.
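For reference, a minimal standalone sketch of how the two check flavors mentioned above behave (the tensor here is just a stand-in; the only difference exercised is the exception type raised on failure):

```python
import torch

offs = torch.tensor([16, 32, 48], dtype=torch.int64)  # deliberately the wrong dtype

try:
    # torch._check raises RuntimeError when the condition is False
    torch._check(offs.dtype == torch.int32, lambda: "Offsets have to be int32")
except RuntimeError as e:
    print(type(e).__name__, e)

try:
    # torch._check_value raises ValueError instead
    torch._check_value(offs.dtype == torch.int32, lambda: "Offsets have to be int32")
except ValueError as e:
    print(type(e).__name__, e)
```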
```diff
- delta % align == 0 &&
-     "expected dynamic dimension byte size to be multiple of 16 \n");
+ delta >= 0 && delta % align == 0 &&
+     "expected dynamic dimension byte size to be non-negative multiple of 16 \n");
```
Does this actually print the error message, or just `&&` a string with the condition :/? CUDA_KERNEL_ASSERT_MSG seems to be the macro for that?
CUDA_KERNEL_ASSERT_MSG is sugar over this: the device-side assert prints the stringified failing expression, and since a string literal is always non-null, `&& "msg"` leaves the condition's truth value unchanged while embedding the message in that output.
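As an aside, the corresponding idiom in Python illustrates the reviewer's concern (analogy only; Python's assert does not stringify its expression, so the `and "msg"` trick silently loses the message there):

```python
# Python analogy to the C idiom under discussion. In C, assert(cond && "msg")
# shows the message because assert prints the source text of the expression;
# Python has no such stringification, so the message must be attached explicitly.
delta, align = -16, 16

try:
    # the `and "msg"` trick: truthiness is unchanged, but the message is lost
    assert delta >= 0 and delta % align == 0 and "expected non-negative multiple of 16"
except AssertionError as e:
    print(repr(e))  # AssertionError() -- no message attached

try:
    # the form that actually attaches the message to the exception
    assert delta >= 0 and delta % align == 0, "expected non-negative multiple of 16"
except AssertionError as e:
    print(repr(e))  # AssertionError('expected non-negative multiple of 16')
```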
```diff
-        self.assertEqual(bgrad, b.grad)
+        if agrad is not None:
+            self.assertEqual(agrad, a.grad)
+            self.assertEqual(bgrad, b.grad)

     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
```
I'm surprised this did not fire as a type error: I think the expressions in
`and b_stride[0] == (b_shape[1] * b_shape[2])`
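For context, the stride relation being questioned is easy to verify standalone (the shapes below are made up for illustration, not the test's tensors):

```python
import torch

# For a contiguous 3D tensor of shape (groups, K, N), the stride of dim 0 is
# exactly K * N, which is what the expression above asserts.
b = torch.randn(4, 8, 16)
b_shape, b_stride = b.shape, b.stride()
assert b_stride[0] == (b_shape[1] * b_shape[2])  # 8 * 16 == 128
print(b_stride)  # (128, 16, 1)
```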
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased; force-pushed from dfa53a0 to 48df5d4.
Thanks for adding the warning!
@pytorchbot merge -f "test failure unrelated"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Error out if K=0 in one of the grouped gemms, to avoid the hangs reported in #152668.
Also adds a meta function for _scaled_grouped_mm (TODO: do the same for _grouped_mm, unless it's done already).
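To make the rejected inputs concrete, here is a minimal sketch of the two conditions named in the title, written as a hypothetical standalone helper (the helper name and error wording are assumptions, not the PR's actual code, which lives in the op's meta/CUDA paths):

```python
import torch

def check_grouped_offs(offs: torch.Tensor) -> None:
    # Hypothetical stand-in for the PR's validation. `offs` carves a packed
    # dimension into per-group slices, so consecutive deltas are group sizes.
    torch._check(offs.dtype == torch.int32, lambda: "Offsets have to be int32")
    sizes = torch.diff(offs, prepend=torch.zeros(1, dtype=offs.dtype, device=offs.device))
    # A negative delta means the offsets decrease (out-of-bounds reads);
    # a zero delta means a K=0 group, which hung the kernel (#152668).
    if bool((sizes < 0).any()):
        raise RuntimeError("offsets must be non-decreasing")
    if bool((sizes == 0).any()):
        raise RuntimeError("K=0 group gemm is not supported")

try:
    check_grouped_offs(torch.tensor([16, 32, 32], dtype=torch.int32))
except RuntimeError as e:
    print(e)  # K=0 group gemm is not supported
```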
One weird thing I'm seeing: when running all grouped_gemm tests, I'm erroring out with a sympy error about evaluating a relational, which is weird because there's no relational that sympy has to evaluate in `offs is not None`, and when running this test separately (`test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda`) it passes. I suspect some autotuning cache has to be reset between runs, but I don't know what to look for.

Edit: that error is "fixed" by setting `dynamic=False`; now with the correct meta function something's wrong with dynamic shapes.

cc @ptrblck @msaroufim @eqy @jerryzh168
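For reference, the `dynamic=False` workaround from the edit above looks roughly like this in a test (a sketch only; the wrapped function and its call are placeholders, not the actual test code):

```python
import torch

def run(a, b, offs):
    # stand-in for the grouped-gemm call under test
    return torch._grouped_mm(a, b, offs=offs)

# dynamic=False forces static shapes, so no sympy guard (e.g. on `offs is not
# None`) has to be re-evaluated across differently-shaped runs.
compiled = torch.compile(run, dynamic=False)
```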