upstream `apex.normalization.FusedRMSNorm` · Issue #72643 · pytorch/pytorch · GitHub
upstream apex.normalization.FusedRMSNorm #72643
@stas00

Description

🚀 The feature, motivation and pitch

All T5 models and their derivatives (t5, mt5, t0, etc.) use RMSNorm instead of LayerNorm. The former is a subset of the latter: it only scales and doesn't shift.
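For reference, a rough sketch of the difference (using the usual eps placement; the exact T5 formulation may differ slightly):

```python
import torch

x = torch.randn(2, 8, 512)          # (batch, seq, hidden)
weight, bias, eps = torch.ones(512), torch.zeros(512), 1e-6

# LayerNorm: re-center and re-scale, then apply affine (scale + shift)
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
layer_norm = (x - mean) * torch.rsqrt(var + eps) * weight + bias

# RMSNorm: no mean subtraction and no bias -- only scale by the root mean square
rms_norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight
```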

The original need arose from discovering that all HF Transformers t5-based models were somewhat slow under mixed precision, because the "manual" implementation of T5LayerNorm up/down-casts dtypes, which causes a significant bottleneck.
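The pattern in question looks roughly like this (a sketch in the style of T5LayerNorm, not a verbatim copy of the Transformers code):

```python
import torch
from torch import nn

class T5StyleRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # upcast to fp32 so the variance is numerically stable under fp16/bf16 ...
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # ... then cast back down to match the weight dtype; these extra
        # elementwise casts are what a fused kernel would eliminate
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states
```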

While researching this I ran into other users who wanted a fast RMSNorm (but I didn't save the references).

NVIDIA/apex recently implemented apex.normalization.FusedRMSNorm, but building apex is far from easy for a layperson.

I have benchmarked it in an ensemble and it gives a pretty significant gain: about a 10% improvement on the full end-to-end application (huggingface/transformers#14656), so the norm itself is clearly multiple times faster.

So, to ease users' path to faster t5-based models, it would be great to have this subset of LayerNorm's functionality available in PyTorch.

It's already in the nvfuser branch: csarofeen#1428

I will see if I can find other users who may want a fast RMSNorm

Thank you!

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

Metadata
Labels

feature: A request for a proper, new feature. · module: nn: Related to torch.nn · module: norms and normalization · triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
