DTensor support for fused qkv matmul #140069
Comments
Depends on DTensor Strided Sharding:
cc: @XilunWu
@HDCharles sorry for catching this issue late. As @ad8e said, this can be done via the Strided Sharding feature in DTensor. I'll try the TP code you shared in DTensor and see if there's any implementation gap. Thanks for submitting the feature request!
@XilunWu Hey, wondering if there's any progress on this feature? Thank you!
Hi @yzhangcs, sorry, we currently haven't made a plan to support this, but I'll notify you if we have more updates.
@yzhangcs For fused QKV, indeed you might need strided sharding as a way to express things. It is likely that for your use case, some op support related to …
Sorry, I think this might be because I mentioned earlier that strided sharding could support fused QKV, which might have led the discussion toward strided sharding. After looking into this more, I think strided sharding and fused QKV sharding are two orthogonal problems. Strided sharding only describes the sharding order when sharding one tensor on multiple device mesh dimensions. Fused QKV sharding is different: it tries to shard a combined linear (one big tensor) as if it sharded three linears separately on one device mesh dimension. So essentially it shards three tensors into one big sharded tensor. This is not a regular sharded tensor, but one might want to treat it as a regular sharded tensor during runtime. Moreover, the three tensors might have different shapes, which need to be tracked manually. Given this is not a regular sharded tensor, we should not try to force it into a regular one during sharding initialization; instead we should treat it faithfully as sharding three tensors into one big sharded tensor, and just treat it as a single big DTensor during runtime. We can easily implement the fused QKV feature in the TP layer, i.e. something like the sketch below should work:
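A minimal sketch of that idea using only existing DTensor APIs (`DTensor.from_local`, `Shard`); the helper name `shard_fused_qkv_colwise`, the split sizes, and the module names in the usage comment are illustrative assumptions, not the actual TP-layer API:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard


def shard_fused_qkv_colwise(linear: nn.Linear, q_dim: int, kv_dim: int, mesh) -> None:
    """Shard q/k/v separately along the output dim, then re-fuse the local
    shards into one DTensor so downstream ops see a single sharded weight.
    (Bias handling omitted for brevity.)"""
    rank = mesh.get_local_rank()
    world = mesh.size()
    # Split the fused weight back into its three logical pieces.
    q, k, v = linear.weight.detach().split([q_dim, kv_dim, kv_dim], dim=0)
    # Take this rank's row slice of each projection independently, so the
    # local fused weight is [q_r; k_r; v_r] rather than an arbitrary slice
    # of the naively concatenated weight.
    local = torch.cat(
        [t.chunk(world, dim=0)[rank] for t in (q, k, v)], dim=0
    ).contiguous()
    # Treat the re-fused local shard as one big Shard(0) DTensor at runtime.
    linear.weight = nn.Parameter(DTensor.from_local(local, mesh, [Shard(0)], run_check=False))


# Illustrative usage:
# mesh = init_device_mesh("cuda", (world_size,))
# shard_fused_qkv_colwise(model.attn.wqkv, q_dim=n_head * head_dim,
#                         kv_dim=n_kv_head * head_dim, mesh=mesh)
```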
cc @HDCharles @ad8e
My understanding of the proposed scheme is: the most representative shape of …
@ad8e Yep, that is right! Fused QKV sharding really needs to stay aligned between the pretraining and finetune/inference stages (normal sharding does not), i.e. if pretraining uses fused QKV sharding, then finetune/inference requires no change. But if pretraining uses separate QKV sharding, there needs to be surgery to convert the sharding, since fused QKV treats the whole QKV as one weight while the non-fused version keeps three separate weights. From a checkpointable-state perspective they are different.
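To make that "surgery" concrete, here is a hedged sketch of converting a checkpoint with separate q/k/v projections into the fused layout; the state-dict key names (`attn.wq.weight`, etc.) are hypothetical placeholders:

```python
import torch


def fuse_qkv_state_dict(sd: dict) -> dict:
    """Convert separate q/k/v projection weights into one fused weight.
    Key names here are hypothetical; adapt them to the actual model."""
    out = dict(sd)
    wq = out.pop("attn.wq.weight")
    wk = out.pop("attn.wk.weight")
    wv = out.pop("attn.wv.weight")
    # The fused layout stacks the three projections along the output dim,
    # matching a single nn.Linear that produces [q; k; v].
    out["attn.wqkv.weight"] = torch.cat([wq, wk, wv], dim=0)
    return out
```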
This PR adds fused QKV sharding in the TP layer. There should be no "strided" sharding involved, as a fused QKV linear layer is more about combining three layers into one. See design and discussion in #140069 (comment). Resolves #140069.
🚀 The feature, motivation and pitch
For the transformer architecture (for example https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L195-L211) it tends to be most performant to merge the q, k, and v matrices together. If you try to shard this concatenated tensor, the subsequent SDPA op won't be sharded correctly, since you need each column of q sharded together with the corresponding columns of k and v ([q1, k1, v1, ...]), but by default the sharding will be [q1, q2, q3, ...]. When not using DTensor this is relatively easy to get to work: https://github.com/pytorch-labs/gpt-fast/blob/main/tp.py#L73
But for DTensor the way to enable this is really unclear. Is there a way to handle this type of operation with DTensor parallelization, or should we just stick to normal tensor parallel support and figure out how to get it to work with our APIs?
This is currently blocking tensor parallel support in torchAO, so I wanted to centralize the discussion in a single location. A small illustration of the alignment issue is sketched below.
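To make the column-alignment issue described above concrete, here is a small self-contained illustration with made-up toy shapes (this is not code from gpt-fast); it contrasts a naive contiguous slice of the fused weight with the per-projection slices each rank actually needs:

```python
import torch

dim, world = 8, 2  # toy sizes

# Build toy q/k/v weights and the fused [q; k; v] weight.
q = torch.arange(dim * dim, dtype=torch.float32).reshape(dim, dim)
k = q + 1000
v = q + 2000
fused = torch.cat([q, k, v], dim=0)  # shape [3 * dim, dim]

# Naive dim-0 sharding: rank 0 gets fused[:12] == all of q plus half of k,
# so its local output no longer holds matching q/k/v slices for its heads.
naive_rank0 = fused.chunk(world, dim=0)[0]

# What rank 0 actually needs: the first half of each of q, k, and v,
# re-concatenated locally -- i.e. [q_0; k_0; v_0].
correct_rank0 = torch.cat([t.chunk(world, dim=0)[0] for t in (q, k, v)], dim=0)
```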
Alternatives
Don't use DTensor for tensor parallel.
Additional context
No response
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu