upstream apex.normalization.FusedRMSNorm
#72643
Comments
(if it's a variant of LayerNorm, should it be supported by LayerNorm natively, e.g. if weight is not None and bias is None, it should call this new fused kernel?)
From https://arxiv.org/abs/1910.07467
So it's the same as LayerNorm, but:
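For reference, the definition from that paper keeps only a learned gain g_i and drops both the mean subtraction and the bias that LayerNorm has:

```math
\bar{a}_i = \frac{a_i}{\operatorname{RMS}(\mathbf{a})}\, g_i,
\qquad
\operatorname{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}
```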
@vadimkantorov the problem is that you also want to ignore the average term (E[x]) there?
It'd be awesome to have this available in PyTorch!! It was used again, in this big paper: https://arxiv.org/pdf/2112.11446.pdf
Another big model using it: https://arxiv.org/abs/2302.13971
Also I've noticed that there seems to be a test for a fused version: https://github.com/pytorch/pytorch/blame/5dd52e250f66a5e3377eb39228cd929871f1eb5d/test/functorch/test_memory_efficient_fusion.py#L155
Hey!
But if this request is put on hold for another year it might come through. Which is fine too, since we do have
tbh, if I could just use
Just FYI, RMSNorm isn't just LayerNorm without bias (otherwise we could have used the functional method that allows not passing a bias). You also need to remove the mean estimation. Totally fine with the plan to use
@albanD (also
yes
@PetrochukM this is the whole point of torch.compile (and the big difference with existing JITs in ML frameworks): it is designed to work with partial graphs and small pieces!
Thank you for clarifying that one doesn't have to compile the whole model to fuse just one component, @albanD. That's great! I disagree that this should be left to users. You want PyTorch to be the winning framework, correct? Make it great out of the box. Expecting users to figure out that they need to torch.compile things themselves is a lot to ask. If it works, why not?
I tend to agree with @stas00, as pytorch core/domain libraries are an important source of idioms that are adopted by the community. So if an RMSNorm module can be implemented trivially in core by torch.compiling a simple impl, it's great to have it in core + tests + perf tests. Then this could also be a showcase/doc reference of a use case where torch.compile works great (kind of dogfooding).
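A minimal sketch of such a trivially compilable module (illustrative only, not an actual core API; it assumes normalization over the last dimension and an eps of 1e-6):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Sketch of RMSNorm: scale by 1/RMS(x) over the last dim; no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt of the mean of squares; keepdim so it broadcasts over the last dim
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight


# Wrapping the module in torch.compile lets Inductor fuse the reduction
# and the elementwise ops instead of relying on a hand-written kernel.
device = "cuda" if torch.cuda.is_available() else "cpu"
norm = torch.compile(RMSNorm(1024).to(device))
out = norm(torch.randn(8, 128, 1024, device=device))
```

Keeping the eager definition this simple is what makes the "just torch.compile it in core" option attractive.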
Yeah, if compiling works, why not offer out-of-the-box PyTorch modules that have been compiled together? Especially modules that are used in state-of-the-art models. Most state-of-the-art models cannot be built in plain PyTorch because they require many layers that are not readily available and are hard to implement. The community ends up needing to use libraries like
@albanD the problem is that we need to do a zero-to-one of having torch.compile inside regular torch library code. This is a reasonable thing to want to do, but there hasn't really been any emphasis on it (since most of the effort has been on torch.compile with full libraries). Because there is no emphasis on this style of use case, there are big perf gaps (e.g., guard evaluation overhead matters a lot more in this regime). So to make your suggestion into reality, we need to also spend some time making this work well. Or we just enlist the OSS community's help in adding the FusedRMSNorm into the framework directly and kick the can a few more months.
This is definitely the long-term vision where compile will be always there, optimizing what it can and leaving the rest as python code. Right now we're in a weird transition period though since, as you said, this is not stable enough to be in this always-on state. And so the question is still open on how to best spend our resources: add more fused one-off kernels that will be obsolete when compile is the default, or work on making compile the default. To come back to this particular issue, adding the new flag to be able to do RMSNorm is a definite "yes we want it", but adding the fused implementation is more "we would accept a simple PR but no one on the core team is working on it".
I think having certain things in core that are torch.compile'd demonstrates the maturity of the technology to users, so it's a good goal to have by itself. Plus, in core you could have automated perf tests comparing it against the legacy apex fused kernels. If wanted, this stuff could go into some separate experimental pytorch package (from which kernels can graduate into core). As long as it's built/released/tested along with the rest of pytorch, it's already strictly better than apex. Once pytorch core has torch.compile'd kernels, it will show true commitment to the technology.
If things are still unstable and it'd be wasteful to allocate resources to such porting because an automatic solution is "imminent", then I think it's perfectly fine to close this feature request and tell the user to use torch.compile.
We unify the Pre-LN Transformer and the Pre-RMSNorm Transformer. Considering that RMSNorm offers superior efficiency compared to LayerNorm in theory, we believe that providing an official RMSNorm API would greatly benefit the community, allowing them to harness this improvement in both training and inference effectively. We also release our implementation at https://github.com/ZixuanJiang/pre-rmsnorm-transformer for reference. Thanks for your consideration.
Following up on the discussion here, from further discussion, it seems like LayerNorm computes (x - E[x]) / sqrt(Var[x] + eps) * weight + bias, and RMSNorm computes x / sqrt(E[x^2] + eps) * weight, where Var[x] = E[x^2] - E[x]^2, and so the two coincide exactly when E[x] = 0 and bias = 0.
If my understanding here is correct, we would accept the addition of
It also seems that there's an RMSNorm impl in flash-attention (fused with dropout): https://github.com/Dao-AILab/flash-attention/blob/4f285b354796fb17df8636485b9a04df3ebbb7dc/flash_attn/ops/rms_norm.py#L11
Am I also understanding correctly that semantically
There is also an impl of RMSNorm in Triton at https://github.com/kakaobrain/trident/blob/main/trident/kernel/rms_norm.py - maybe it can be incorporated into core?
+1
Any progress on this issue? cc: @albanD
As a matter of fact, the RMSNorm from FasterTransformer is much faster than Apex's.
Could you please share some factual information to support this claim, @vince62s?
@stas00 I can try to make a standalone snippet, but I can tell you that in this PR https://github.com/OpenNMT/OpenNMT-py/pull/2539/files it had a huge impact on inference tok/sec for a Mistral-7B LM, for instance (separately from the other change, the kv_cache from flash2, which also had an impact).
Should probably be closed now.
The PR adds the API but not the fused kernel. So we can keep this open in case someone wants to investigate if we would get a benefit from a manually fused kernel. And if so, wants to upstream such an implementation.
huggingface/transformers#30236 Is it a problem with my algorithm or a problem with the RMSNorm implementation? self.weight is bf16, hidden_states is fp32, input_dtype is bf16: `return self.weight * hidden_states.to(input_dtype)  # bf16 * bf16` and I get a loss spike.
This is quite interesting, maybe there needs to be some more expressive torch.mul args to allow some upcasts during the multiplication itself (compute dtype), e.g.
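For context, the dtype handling that Transformers-style RMSNorm implementations typically use looks roughly like the sketch below (illustrative, not the exact upstream code; `hidden_size` and the eps default are placeholders): the reduction runs in fp32 and the result is cast back to the input dtype only for the final elementwise product with the weight.

```python
import torch
from torch import nn


class RMSNormWithCast(nn.Module):
    # Illustrative sketch of the usual mixed-precision handling:
    # do the reduction in fp32 for stability, cast back before multiplying by the weight.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype                  # e.g. bf16 under mixed precision
        hidden_states = hidden_states.to(torch.float32)    # upcast for the reduction
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # cast back first, so the final product is bf16 * bf16 as in the snippet above
        return self.weight * hidden_states.to(input_dtype)
```

Whether that final product should instead run in a higher compute dtype is exactly the open question raised above.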
It also appears that Intel IPEX has some optimized code paths for RMSNorm and LayerNorm on CPUs (mentioning for pointers for a potential eventual upstream): https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/inference/python/llm-modeling
Another fused RMSNorm impl in Triton: https://x.com/hsu_byron/status/1827072742291861975 at https://github.com/linkedin/Liger-Kernel @msaroufim The max-mem improvements in Liger are massive. Curious if some of these optimizations could be added to core (like linear layer + softmax + cross-entropy; Inductor would not create such fusions now?) or if these Triton kernels could be added to PyTorch (at least as some known fusion patterns for Inductor). And of course the best would be to be able to generate these fusions with Inductor.
As an extension, maybe it would be nice to have some way to extend Inductor by registering into it some known fusion pattern to be matched, along with user-provided Triton code. This way, the Triton code could be contributed first e.g. in some out-of-tree package like HF or torchao and the original model code could sometimes be kept intact. And maybe then these patterns can be brought in-tree.
RoPE speedup is also pretty wild.
So for this kind of work we basically have 4 options, sorted in decreasing difficulty:
Honestly I'd just vote for 4 since that's easiest and useful now. 3 isn't that crazy nowadays since it's easier to package Triton kernels, given they get JIT'd vs CUDA. 1 and 2 sound like the right long-term thing to do, but I'd love to hear more from @Chillee or @eellison on this: how hard is it to go from a solid Triton kernel to getting it working with the pattern matcher today? Ideally, if there are a few n00b-friendly reference examples it could help us move faster here.
I think the main advantages of in-core are discoverability and the resources to keep it tested for new accelerators, to be notified if any regressions happen (on pytorch's large bench), and avoiding forcing other packages to take on another dependency. Maybe torchao can be a "staging" area before inclusion in core, but probably major packages like HF could be wary of taking a dependency on it too (or maybe not?).
The problem with 4 is that it keeps the speeding-up packages fragmented and often not tested against / plugging into each other. E.g. another speeding-up package announced yesterday: https://x.com/DAlistarh/status/1826538436225806468, https://github.com/IST-DASLab/Sparse-Marlin. Currently, wide adoption happens via domination of either PyTorch core or some other ultra-popular package like HF. In the end, IMO everyone wins if good, continuously tested Triton kernels somehow get widely adopted - either by PyTorch core (eventually) or via some other ultra-popular package (I hoped it would be torchao, but I think for now it's just one of tons of other packages experimenting with quantization APIs and kernels). So IMO what's lacking now is a PyTorch-proper mechanism for wide adoption of speed-ups besides inclusion in core.
Also, including some more fused kernels / pattern matchings in core sets a new baseline for all newer packages, simplifying the benchmarking work for them. Currently there are tons of repos with fast SwiGLU or fast RoPE (including xformers or trident or apex or oneflow). For any such repo, benchmarking against all the others is not easy and is error-prone, so having a better baseline in-core is very useful.
Regarding the bar on kernel inclusion, I propose to have it lower for kernels used in popular LLM models, as the benefit from having the speedups in-core is large. Another question is whether these proposed kernels are benchmarked against vanilla torch.compile versions...
From my testing, I remember that vanilla torch.compile has on-par or better perf compared to handwritten Triton kernels for RMSNorm.
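A rough way to reproduce that kind of comparison on your own shapes (a sketch: it only compares eager vs. torch.compile, assumes a CUDA device, and the sizes/dtype are arbitrary; a handwritten Triton kernel could be slotted in as a third entry):

```python
import torch
from torch.utils import benchmark


def rms_norm(x, weight, eps=1e-6):
    # eager reference: normalize by the root mean square over the last dim
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


dim = 4096
x = torch.randn(8192, dim, device="cuda", dtype=torch.bfloat16)
w = torch.ones(dim, device="cuda", dtype=torch.bfloat16)

compiled = torch.compile(rms_norm)
compiled(x, w)  # warm up once so compilation time is excluded

for label, fn in [("eager", rms_norm), ("torch.compile", compiled)]:
    timer = benchmark.Timer(stmt="fn(x, w)", globals={"fn": fn, "x": x, "w": w})
    print(label, timer.timeit(100))
```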
Maybe the only interesting thing there is the fused linear + softmax + loss (in terms of max-mem reduction). Or another question is why it's showing such big speedups against HF, assuming HF can now use torch.compile... Or maybe the baseline/config in those benchmarks is not very strong? Or is the problem that torch.compile does not decompose/compile RMSNorm / applying RoPE into Triton?
The chunked linear + crossentropy makes a big difference in terms of peak memory, and for finetuning where you're often batch-size constrained, I've seen it make a huge difference in terms of total throughput. |
So I guess maybe what could be cool is inclusion in PyTorch core of some extremely popular shortcuts from popular LLM models (like RoPE / SwiGLU) and having torch.compile applied to them by PyTorch itself. At least it would raise the eager baseline for any newly proposed fused ops...
It's also an interesting question how to ensure that the Inductor-generated code for these useful shortcuts does not regress over time. So maybe one way is to copy-paste its generated Triton code into core and only regenerate it once in a while? (or just torch.compile it always...)
For having torch.compile calls in the standard PyTorch nn library, it would probably be useful to have predictable performance, at least having some assurance that a user would not, without an explicit wish, modify global Inductor options in client code and impact the core-provided modules.
Unsloth also includes Triton kernels for
I wonder if this means that PyTorch's native torch.compile Inductor/Triton codegen does not produce fast enough code for LayerNorm/RMSNorm? They patch out Torch's impl with their own in https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/patching_utils.py#L70
Also, RMSNorm got added to ONNX:
🚀 The feature, motivation and pitch
All T5 models and their derivatives (t5, mt5, t0, etc.) use `RMSNorm` instead of `LayerNorm`. The former is a subset of the latter: it only scales and doesn't shift.
The original need was a discovery that all HF Transformers t5-based models were somewhat slow under mixed precision, because of a "manual" implementation of `T5LayerNorm` where manual up/down-casting was causing a significant bottleneck.
While researching this I have run into other users who wanted to use a fast `RMSNorm` (but didn't save the references).
NVIDIA/apex recently implemented `apex.normalization.FusedRMSNorm`, but building apex is far from easy for a lay person.
I have benchmarked it in an ensemble and it gives a pretty significant gain - about 10% improvement on the full back-to-back application: huggingface/transformers#14656 - so clearly multiple times faster on just the norm part.
So to ease users' path to faster t5-based models, if possible it'd be great to have this subset of `LayerNorm`'s functionality available in pytorch. It's already in the nvfuser branch: csarofeen#1428
I will see if I can find other users who may want a fast `RMSNorm`.
Thank you!
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345