Add more GPU architectures support #112
Conversation
Thanks for this. As we don't have SM100 GPUs, let's make it a draft and wait for community feedback :)
I'll help verify. Stay tuned.
CMakeLists.txt (outdated)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_CXX_STANDARD 17)
Why change back to 17 from 20?
This is a conflict after rebasing. I'll resolve this later.
The Blackwell part looks good to me. I believe the remaining work involves making it compatible with Hopper. Thanks for this great work!
deep_gemm/utils/layout.py (outdated)

    return sf


def transform_sf_into_required_layout(sf: torch.Tensor,
It would be great to allow skipping the (INT, 128, 128) case, where the data is already pre-transformed by other fused kernels. For example, my patch (a consolidated form of the condition is sketched after the snippet):
should_skip_transform = (
    (sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == '100a')
    or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == '100a')
)
if not should_skip_transform:
    # Pre-transform checks
    check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
...
if should_skip_transform:
    return check_sf_layout(sf, mn=mn, k=k, gran=(1, 128), num_groups=num_groups, tma_stride_check=True,
                           type_check=torch.int)
...
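As a side note, both branches of the condition differ only in the granularity, so the check could be collapsed into a small helper. A sketch (get_device_arch and the '100a' string come from the patch above; the helper name and signature are hypothetical):

```python
import torch

def should_skip_transform(sf: torch.Tensor, gran: tuple, arch: str) -> bool:
    # Skip the layout transform when the scale factor was already pre-transformed
    # to INT by an upstream fused kernel on SM100; callers would pass
    # get_device_arch() for `arch`, as in the patch above.
    return sf.dtype == torch.int and arch == '100a' and gran in ((1, 128), (128, 128))
```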
update: please refer to sgl-project#1
That makes sense. I'm thinking about how to skip transform in the most appropriate way, but for now, let me merge your patch as a workaround.
Yes, there may be better methods; my patch is just a quick hack to ensure SGLang can work.
    k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
    ceil_div,
    set_num_sms, get_num_sms,
    get_col_major_tma_aligned_tensor,
Is get_col_major_tma_aligned_tensor deprecated?
Also, if we don't import the function here, users will have to refactor their code, e.g.:

from deep_gemm import calc_diff, ceil_div

becomes

from deep_gemm.testing.numeric import calc_diff
from deep_gemm.utils.math import ceil_div

Is this expected?
> Is get_col_major_tma_aligned_tensor deprecated?

It's not used in SM100; it will be added to utils/layout.py as part of SM90 support.

> Is this expected?

Yes, I think they are not core functions of DeepGEMM and should not be exposed at the top level. These two functions are also simple, so I don't think this will have much impact on users.
Makes sense to me, thanks!
@RayWang96 When you say it's not used in SM100, does this mean it will be handled automatically on SM100, so we don't need code like https://github.com/vllm-project/vllm/blob/dac8cc49f43f7d2639d873532a408949169821a9/vllm/model_executor/layers/quantization/fp8.py#L654?
fp8_gemm_nt can implicitly handle the transformation, or this step can be skipped so that it only validates the layout. It's recommended to have a preceding kernel transform the data into the required layout first, and then call the GEMM.
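To make the two call patterns concrete, here is a rough sketch. It assumes fp8_gemm_nt takes (fp8_tensor, scale) pairs for the LHS/RHS as in the test snippet quoted later in this thread, that the output is BF16, and that get_col_major_tma_aligned_tensor stays importable from the top level (its final location is exactly what is being discussed here); all shapes and values are illustrative:

```python
import torch
import deep_gemm

# Illustrative shapes/dtypes only: per-token (1 x 128) scales for the LHS and
# 128 x 128 block scales for the RHS, BF16 output.
m, n, k = 128, 4096, 7168
a_fp8 = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
a_sf = torch.rand(m, k // 128, device='cuda', dtype=torch.float32)
b_fp8 = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
b_sf = torch.rand(n // 128, k // 128, device='cuda', dtype=torch.float32)
d = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

# Option 1: pass the scale as-is and let the wrapper transform its layout implicitly.
deep_gemm.fp8_gemm_nt((a_fp8, a_sf), (b_fp8, b_sf), d)

# Option 2 (recommended): pre-transform the scale once, ideally fused into the
# quantization kernel, so the GEMM call only needs to validate the layout.
# (On Blackwell the scale should additionally be a power of two / E8M0, as
# discussed further down in this thread.)
a_sf_aligned = deep_gemm.get_col_major_tma_aligned_tensor(a_sf)
deep_gemm.fp8_gemm_nt((a_fp8, a_sf_aligned), (b_fp8, b_sf), d)
```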
So do you mean it is recommended that we call get_col_major_tma_aligned_tensor each time before we call any GEMM operation?
setuptools.setup(
    name='deep_gemm',
    version='1.0.0' + revision,
For large API changes, should we bump the version to 2.0?
Thanks! We will sync up with the DeepGEMM team to decide on the version number.
@RayWang96 Hi, will this PR be merged soon? I think my integration has been merged into the vLLM main branch by accident (vllm-project/vllm#19820). But if we merge this PR now, I think it would be fine.
@yewentao256 Sorry, not now. We still need 2-3 weeks to merge with the Hopper code.
OK, thanks for letting me know.
vllm-project/vllm#20090: this unit test passes on H100, but with your integration it does not pass on B200, so perhaps you can use it to validate accuracy as well. Even if we use the exact same tensor input and just stop at the first ...
@yewentao256 I read and tried to debug the vLLM unit test, and I think the problem is that the scale is not being converted to E8M0; DeepGEMM Blackwell currently supports only E8M0 scales. The unit test is also comparing the precision of two different setups: Triton FP8 using FP32 scales vs. DeepGEMM using E8M0 scales. As a result, the assert_close check fails, which I think is expected. We should benchmark Triton FP8 and DeepGEMM FP8 against Triton BF16 instead; in that scenario, I found their precision loss was very similar.
So the correctness of DeepGEMM narrows down to a smaller scope? I'm thinking that since this is supported on H100, to ensure correctness for all of the models that use DeepGEMM, it should still pass on B200. Actually, this makes all of the unit tests corresponding to DeepGEMM in vLLM fail.
First, if you directly pass an FP32 scale into DeepGEMM Blackwell, it will first be converted to an integer and then transformed into the format required by the input (see here). This is actually incorrect: the scale needs to be pre-converted to a power of two (see here), which might be one reason why the unit test failed. On this generation of GPUs, the scale must use E8M0 to fully utilize the computing power (refer to this). In our tests, we found that this does not have much impact on the accuracy of the DeepSeek model, so using the E8M0 scale will be our most recommended approach. Another reason the unit test might fail is that, if GEMM with FP32 scales is used as the baseline, the results of GEMM with E8M0 scales will not align numerically. Therefore, I think the correct way to write the unit test is to compare against the BF16 GEMM results, rather than against the FP32-scale GEMM results.
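For reference, a minimal sketch of what "pre-converting the scale to a power of two" can look like, built from the ceil_to_ue8m0 and per_token_cast_to_fp8 helpers quoted later in this thread; the 1 x 128 block size and the 448 (FP8 E4M3 max) divisor come from that code, while the clamp and function names here are assumptions of this sketch:

```python
import torch

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two so it is exactly
    # representable as a UE8M0 exponent (same formula as in this PR's utils/math.py).
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

def per_token_cast_to_fp8_ue8m0(x: torch.Tensor):
    # Sketch of per-token (1 x 128 block) FP8 quantization with power-of-two scales.
    # 448.0 is the FP8 E4M3 max; the clamp is an added safeguard against log2(0).
    m, k = x.shape
    assert k % 128 == 0
    x_view = x.view(m, k // 128, 128)
    x_amax = x_view.abs().float().amax(dim=2, keepdim=True).clamp(min=1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_fp8 = (x_view / sf).to(torch.float8_e4m3fn).view(m, k)
    return x_fp8, sf.squeeze(-1)
```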
@RayWang96 Thanks for your reply!
def ceil_to_ue8m0(x: torch.Tensor):
    assert x.view(-1).amax().item() > 0
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
Here, .item() will cause a CUDA graph issue:
File "/home/wentao/.wentao_env/lib/python3.12/site-packages/deep_gemm-1.1.0-py3.12.egg/deep_gemm/utils/math.py", line 33, in per_token_cast_to_fp8
sf = ceil_to_ue8m0(x_amax / 448.0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wentao/.wentao_env/lib/python3.12/site-packages/deep_gemm-1.1.0-py3.12.egg/deep_gemm/utils/math.py", line 24, in ceil_to_ue8m0
assert x.view(-1).amax().item() > 0
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wentao/.wentao_env/lib/python3.12/site-packages/torch/utils/_device.py", line 104, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Perhaps delete the assert statement?
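For illustration, a capture-safe variant could simply drop the synchronizing assert; the clamp below is an extra assumption of this sketch (to keep log2 defined for zero inputs), not part of the original helper:

```python
import torch

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Same math as the original helper, but without the .item()-based assert,
    # which forces a host sync and is not permitted during CUDA graph capture.
    x = x.abs().clamp(min=torch.finfo(torch.float32).tiny)
    return torch.pow(2.0, torch.ceil(torch.log2(x)))
```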
Thanks, this function is primarily intended for testing purposes. For performance optimization, a custom implementation should be considered, potentially fused into other kernels.
Thanks for letting me know
tests/test_core.py (outdated)

    f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')

# noinspection PyShadowingNames
def test_func():
    deep_gemm.fp8_gemm_nt(a, b, d)
Should it be deep_gemm.fp8_gemm_nt(a, b, d, c=c)?
Good catch, thanks. We did fix this in our internal version.
Are there any other changes in the internal version, especially to the interfaces? We are currently working on integrating this version into vLLM.
Mostly performance optimizations; they will not break the old interfaces.
Sounds great, thanks!
Hi @RayWang96,

export VLLM_ALL2ALL_BACKEND="deepep_high_throughput"
VLLM_USE_DEEP_GEMM=1 lm_eval --model vllm --model_args "pretrained=deepseek-ai/DeepSeek-R1,data_parallel_size=8,gpu_memory_utilization=0.95,max_model_len=16384,enable_expert_parallel=True" --tasks gsm8k --batch_size auto --num_fewshot 5

# With DeepGEMM
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9492|± | 0.006|
| | |strict-match | 5|exact_match|↑ |0.9492|± | 0.006|

# No DeepGEMM
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9530|± |0.0058|
| | |strict-match | 5|exact_match|↑ |0.9522|± |0.0059|

Another question is about E2E accuracy: these are the results we have on vLLM. Do they match your internal results?
Force-pushed from 61f65d6 to 8dfa329.
We've added Hopper support; it's ready to be merged into the main branch, @LyricZhao.
Anonymous cloning is more difficult with git@ (it requires per-user SSH keys), which makes CI automation harder. This commit reverts to using https:// as it was before deepseek-ai#112.
Hi DeepGEMM Team,
This PR is submitted by NVIDIA and introduces support for more GPU architectures.
Key Points
Ongoing and coming soon
Long-term optimization
We are actively working on these optimizations, aiming to continuously support NVIDIA's latest GPU hardware in DeepGEMM so that the community can perform efficient and scalable training and inference with DeepSeek models for next-generation AI workloads.
Thanks!