Enable fp8 rowwise scaling kernel on cuda, TAKE 2: #125204 #128989
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128989
Note: Links to docs will display an error until the docs builds have been completed.

❌ 25 New Failures, 2 Unrelated Failures as of commit 4c6fe75 with merge base 9a7e251.

NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from e6d341a to 252c489 (Compare).
Force-pushed from 252c489 to 4c6fe75 (Compare).
@pytorchbot merge
Merge failed. Reason: Approvers from one of the following sets are needed:
@pytorchbot merge -f "unrelated failures"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR seems to break the FBGEMM runtime: https://github.com/pytorch/benchmark/actions/runs/9704891003/job/26785961181

Error message:
@xuzhao9 this is interesting and I imagine it has to do with: https://github.com/pytorch/pytorch/pull/128989/files#diff-ac10d47f44dcf2fc2ec547d3dcdf796dea0498b4c3461e820152afb4cbdfae75R16-R58
@xuzhao9 the following fixes the linker error for PyTorch, which arises because cutlass uses the driver API:

pytorch/aten/src/ATen/native/cuda/RowwiseScaledMM.cu, lines 14 to 44 at commit 0ffb175

Perhaps fbgemm needs to add something similar? But we cannot link PyTorch against libcuda.
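For context, here is a minimal sketch of that workaround, assuming the CUDA 12 driver API; it is not the actual RowwiseScaledMM.cu code. The idea is to define a same-named shadow of the driver symbol cutlass references (cuTensorMapEncodeTiled) in our own translation unit and resolve the real entry point lazily via dlopen, so no link-time dependency on libcuda.so is introduced:

```cpp
#include <cuda.h>   // CUresult, CUtensorMap, cuuint32_t/cuuint64_t (CUDA 12+)
#include <dlfcn.h>

// Shadow definition: satisfies cutlass's link-time reference so the binary
// does not need to link against libcuda.so (which ships with the driver,
// not the CUDA toolkit). The real symbol is looked up at runtime instead.
extern "C" CUresult cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  using Fn = CUresult (*)(
      CUtensorMap*, CUtensorMapDataType, cuuint32_t, void*,
      const cuuint64_t*, const cuuint64_t*, const cuuint32_t*,
      const cuuint32_t*, CUtensorMapInterleave, CUtensorMapSwizzle,
      CUtensorMapL2promotion, CUtensorMapFloatOOBfill);
  // Resolve once, on first use; libcuda.so.1 is present wherever a driver is.
  static Fn real = []() -> Fn {
    if (void* handle = dlopen("libcuda.so.1", RTLD_LAZY)) {
      return reinterpret_cast<Fn>(dlsym(handle, "cuTensorMapEncodeTiled"));
    }
    return nullptr;
  }();
  if (real == nullptr) {
    return CUDA_ERROR_NOT_INITIALIZED;
  }
  return real(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
              globalStrides, boxDim, elementStrides, interleave, swizzle,
              l2Promotion, oobFill);
}
```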
Summary
The first PR (#125204) was reverted and needed a redo.
This pull request introduces an fp8 row-wise scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:

- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row". A reference sketch of the scaling semantics follows below.

The following two PRs were required to enable local builds:
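As a reference for the scaling semantics above, here is a minimal sketch in plain C++ (not the cutlass kernel; the function name and layout assumptions are illustrative only): the fp32 accumulator of the fp8 product is scaled per output element by `x`'s row scale and `y`'s column scale before the final cast.

```cpp
#include <cstdint>

// Hypothetical reference epilogue. acc is the [M, N] fp32 accumulator of the
// fp8 product x @ y; scale_x has length M (one scale per row of x) and
// scale_y has length N (one scale per column of y). A production kernel
// would fuse this scaling into the GEMM epilogue before casting to the
// output dtype.
void rowwise_scale_epilogue(float* acc,
                            const float* scale_x,
                            const float* scale_y,
                            int64_t M,
                            int64_t N) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      acc[m * N + n] *= scale_x[m] * scale_y[n];
    }
  }
}
```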
Todo
We still do not build our Python wheels with this architecture.
@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
ifdef

I tried to use:

```cpp
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \
    defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900
```
to gate the building of the kernel. I was having a hell of a time with this, so I am not really sure of the right way to do it; a sketch of the underlying pitfall is below.

Kernel Credit: @jwfromm
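One note on the gating pitfall, offered as a hedged sketch rather than a diagnosis of this PR's build trouble: `__CUDA_ARCH__` is only defined during nvcc's device compilation passes, never during the host pass, so a host-visible declaration wrapped in the `#if` above disappears on the host side and its launch no longer compiles. A common pattern is to guard only the kernel body and pick the dispatch path at runtime:

```cuda
#include <cuda_runtime.h>

// Guard only the device-side body. On the host pass __CUDA_ARCH__ is
// undefined, but the kernel symbol and its launch stub still exist, so
// host code that references the kernel keeps compiling.
__global__ void rowwise_kernel_sketch() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  // sm_90-specific implementation would go here.
#endif
  // On older architectures the body compiles empty and is never dispatched to.
}

// Hypothetical host-side check used to decide whether to launch the kernel.
bool device_is_sm90_or_newer() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device);
  return prop.major >= 9;
}
```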
cc @yanbing-j @vkuzo @albanD @kadeng