🔴🔴🔴 [`Attention`] Refactor Attention Interface for Bart-based Models by vasqu · Pull Request #38108 · huggingface/transformers · GitHub

Merged: 71 commits into main, May 22, 2025

Conversation

@vasqu (Contributor) commented on May 13, 2025:

This PR tackles two things in general (plus a bonus):

  • Flex Attention for all base attention types (encoder, decoder, encoder-decoder cross)
  • A new Attention Interface for a bunch of models (mostly based on Bart's implementation); a rough sketch of the resulting pattern follows this list
  • As a bonus, some models have already been refactored into modular form where it made sense, e.g. PLBart
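
To make the attention-interface bullet concrete, here is a rough, simplified sketch of the dispatch pattern the refactored Bart-based attention modules converge on. This is illustrative only, not the exact diff from this PR: the local `eager_attention_forward` stands in for the per-model default, and `ALL_ATTENTION_FUNCTIONS` is the existing registry (assumed import path `transformers.modeling_utils`) from which the sdpa / flash_attention_2 / flex_attention implementations are looked up.

```python
# Simplified sketch of the refactored dispatch (not the exact code from this PR).
import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # Plain eager attention, kept explicit as the default path.
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def dispatch_attention(module, query, key, value, attention_mask, **kwargs):
    # Eager stays the explicit default; everything else ("sdpa", "flash_attention_2",
    # "flex_attention") is looked up in the registry, since "eager" is not registered there.
    attention_interface = eager_attention_forward
    if module.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
    return attention_interface(module, query, key, value, attention_mask, **kwargs)
```

The review discussion further down settles on exactly this explicit-default shape rather than indexing the registry unconditionally.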

Affected models (will be updated when I have enough time) - probably not 100% accurate

  • Bart
  • Biogpt
  • Mbart
  • Bigbird Pegasus
  • Blenderbot
  • Blenderbot Small
  • Wav2Vec2
  • Data2vec audio
  • Informer
  • Time Series Transformer
  • Hubert
  • M2M 100
  • Marian
  • Musicgen
  • Musicgen Melody
  • Nllb MoE
  • Patchts Mixer
  • Patchtst
  • Pegasus
  • Pegasus X
  • PL Bart
  • Sew
  • Speech to Text
  • Uni Speech
  • Uni Speech SAT

Possibly doable in this PR:

  • Rename the flex attn mask creation utilities; since they currently focus only on decoder-only models, the naming is not really fitting

Worth a discussion (?):

  • Time Series Transformer and Speech to Text theoretically support the new attention implementations, but the tests are unsuitable (they generate inputs with ids). IMO a rewrite/adjustment is too much effort. I disabled them for now, but we could also enable them and give users a warning (untested).
  • Pegasus X has more flaky logits on sdpa - should it still be enabled?
  • Nllb MoE has issues with flash attention, as masks are prepared differently and are consumed by the MoE routing, leading to several issues; sdpa also has more flaky logits --> I don't think it's worth it given the size of the model and its usage.

Worth a discussion TL;DR:

  • Time Series Transformer + Speech to Text do not work with the current testing framework (they need special attention to what they get as input)
  • Pegasus X has more flaky logits on sdpa - should it still be enabled?
  • Nllb MoE is too complicated to enable quickly for Flash Attention (attention mask <-> MoE interaction) + same as Pegasus X on sdpa logits.

Future PRs will address:

  • Other models such as Bert-based models.
  • Other models such as Whisper-based models.
  • More modular.
  • Proper kwargs passing --> multiple attn types should get different kwargs?
  • In combination with the above point, optionally tp plans etc. for attn backend support?
  • FA might be able to work without position ids, with fa_kwargs only - this is a current limitation in the FA modeling utils...
  • Mask refactor, as IIRC the prepare_for... functions are not really neat. However, the way it is written, refactoring should be easy ;)
  • Flex Attn tests and fixes (compile issues)
  • Better datasets versioning? Pipeline issues esp with audio

github-actions bot marked this pull request as draft on May 13, 2025 15:10

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu changed the title from "[Attention] Refactor Attention Interface and Enable Flex Attention" to "[Attention] Refactor Attention Interface for Bart-based Modelsand Enable Flex Attention" on May 14, 2025
@vasqu changed the title from "[Attention] Refactor Attention Interface for Bart-based Modelsand Enable Flex Attention" to "[Attention] Refactor Attention Interface for Bart-based Models and Enable Flex Attention" on May 14, 2025
@ArthurZucker (Collaborator) left a comment:

🔴 for the abstraction of the attention interface!
Let's keep the core logic explicit, and just put each warning in each integration/sdpa|flash|flex.py file!

attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS(
Collaborator (on the code above):
No, here we want one explicit thing: that the default attention is eager_attention_forward!
The core philosophy when we abstract is that we keep the indirections to a minimum, and thus the core logic (eager attention) should be explicit!
But it's very nice to put the rest (warnings etc.) inside each function, just not in ALL_ATTENTION_FUNCTIONS's call!

Thus each of sdpa, flex or flash has its own warning; we should not abstract!

attention_mask=attention_mask,
training=self.training,
dropout=self.dropout,
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
Collaborator (on the code above):
Suggested change:

-    attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+    attention_interface = eager_attn_forward
+    if self.config._attn_implementation != "eager":
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

let's make it explicit what the default is!

@vasqu (author):
You're too fast :D Will change in a sec. It wouldn't work either way with `attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]`, since eager is not registered in the interface.

Collaborator:
haha yep exactly

@vasqu changed the title from "[Attention] Refactor Attention Interface for Bart-based Models and Enable Flex Attention" to "🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models and Enable Flex Attention" on May 21, 2025
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
model = model_class(config).to(device=torch_device, dtype=torch.float16)
# Flex Attention can not use dropout
if hasattr(config, "attention_droput"):
@vasqu (author) commented on May 21, 2025:
Need to fix this, typo (discovered during whisper)

@vasqu (author):
Rerunning flex tests to see which models will fail.

@vasqu (author) commented on May 21, 2025:
Seems like a lot of models fail on flex attention... Not sure if I should submit a torch issue. Disabled flex on them for now - I don't think it's a high priority atm.
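
For reference, here is a hypothetical, self-contained sketch of the kind of per-model flex-attention smoke test being iterated on in this thread. The helper name and assertion are made up; the real check lives in the common test suite, and the snippet quoted above is what it actually looks like there.

```python
# Hypothetical sketch of a flex-attention smoke test (not the actual repository test).
import torch


def run_flex_attention_smoke_test(model_class, config, inputs_dict, torch_device):
    config._attn_implementation = "flex_attention"
    # Flex Attention cannot use dropout, so zero it out if the config defines it
    # (note the correct spelling `attention_dropout`, i.e. the typo flagged above fixed).
    if hasattr(config, "attention_dropout"):
        config.attention_dropout = 0.0
    model = model_class(config).to(device=torch_device, dtype=torch.float16).eval()
    with torch.no_grad():
        outputs = model(**inputs_dict)
    assert outputs is not None
```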

@ArthurZucker (Collaborator) left a comment:
Kudos to you! Let's break a bit and warn without a fallback, WDYT? That way we don't add a new arg!

"Falling back to eager attention because `flash_attention_2` does not support"
" `output_attentions=True` or `head_mask`."
)
return eager_fallback(
Collaborator (on the code above):
If eager_fallback is None we are failing hard! Let's maybe just emit the warning, no? And return None for the output attentions.

"Falling back to eager attention because `flex_attention` does not support"
" `output_attentions=True`, `head_mask`, or `dropout`."
)
return eager_fallback(
Collaborator (on the code above):
Same here! And head_mask is something we kinda deprecated, so let's just return None probably.
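
For context, a minimal, hypothetical sketch of the direction suggested in the two comments above: the check lives inside the integration wrapper itself, it only warns (no eager_fallback argument), and it returns None for the attention weights. The wrapper name mirrors the transformers flash integration, but the kernel call below is a plain SDPA placeholder, not the merged code.

```python
# Hypothetical sketch of "warn, no eager fallback" (not the merged code).
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


def flash_attention_forward(module, query, key, value, attention_mask=None,
                            output_attentions=False, head_mask=None, **kwargs):
    if output_attentions or head_mask is not None:
        # No fallback to eager: warn once and return None for the weights, since the
        # fused kernel never materializes the attention matrix.
        logger.warning_once(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`,"
            " so `attn_weights` will be returned as None."
        )
    # Placeholder for the fused kernel call; only the warning + None-weights contract matters here.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )
    return attn_output.transpose(1, 2).contiguous(), None
```

The later commits in this PR ("make attention interface callable and move warnings there", "remove callable interface and change interface warnings", "no more fallback!") point in this direction, but the merged diff is the source of truth.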

@vasqu changed the title from "🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models and Enable Flex Attention" to "🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models" on May 22, 2025
@vasqu (author) commented on May 22, 2025:

Another day, another merge conflict

@vasqu (author) commented on May 22, 2025:

Checked offline again with Arthur, merging

@vasqu merged commit d95c864 into main on May 22, 2025
21 checks passed
@vasqu deleted the vas-enc-dec-attn-refactor branch on May 22, 2025 15:13
redmoe-moutain pushed a commit to redmoe-moutain/transformers that referenced this pull request Jun 10, 2025
…uggingface#38108)

* starting attn refactor for encoder decoder models via bart (eager + sdpa)
* flash attention works, remove unnecessary code
* flex attention support for bart!, gotta check if the renaming is not too aggressive
* some comments
* skip flex grad test for standalone as done with the other test
* revert flex attn rename (for now), sdpa simplify, and todos
* more todos
* refactor mask creation for reuse
* modular attempt at biogpt
* first batch of other models
* fix attn dropout
* fix autoformer copies
* hubert
* another batch of models
* copies/style + last round of bart models --> whisper next?
* remove unnecessary _reshape function and remove copy to whisper
* add skip for decoder-only models out of enc-dec (same as in bart)
* bring back licences
* remove comment, added to pr read instead
* mostly docs
* disable sew flex attn as it's unclear attn mask for now
* oops
* test fixes for enc-dec
* torch fx fixes + try at flex attn
* skip on mbart
* some more fixes
* musicgen skip / delete old attn class logic + sdpa compose compile skip
* disable flex attn for musicgen, not worth the effort
* more fixes and style
* flex attention test for dropout and encoder decoder that dont have main input names
* informer fixes
* the weirdest thing I've encountered yet...
* style
* remove empty tensor attempt, found core root in previous commits
* disable time series due to tests being very text centric on inputs
* add speech to text to be ignoring the other attns, also due to tests
* update docs
* remaining issues resolved ?
* update docs for current state --> nllb moe and pegasus x sdpa is questionable :D
* some models have not set the is_causal flag...
* change dtype in softmax tol old behaviour + some modular fixes
* I hate it but it is what it is
* fixes from main for bart
* forgot this one
* some model fixes
* style
* current status
* marian works now
* fixing some copies
* some copy fixes + time series x informer
* last models possibly and fixes on style/copies
* some post merge fixes
* more fixes
* make attention interface callable and move warnings there
* style lol
* add comment to "unsupported"
* remove callable interface and change interface warnings + some copies
* fix
* ternary is ugly af, make it simpler
* how did that happen
* fix flex attn test
* failing the test
* no more fallback! fixing copies next
* style + attn fixed
* fixing copies and mask creation
* wrong copy
* fixup tests and disable flex attn for now
* fixup last tests?