8000 Refactor `MambaCache` to `modeling_mamba.py` by manueldeprada · Pull Request #38086 · huggingface/transformers · GitHub

Merged: 78 commits into huggingface:main on Jul 21, 2025

Conversation

manueldeprada
Contributor
@manueldeprada manueldeprada commented May 12, 2025

This PR moves the specialized MambaCache class from cache_utils.py to src/transformers/models/mamba/modeling_mamba.py. This is preliminary work for #38077.

Changes:

  • Moved MambaCache to its own file, aligning with Zamba, Bamba, etc.
  • Removed unnecessary Mamba-specific code from generate.
  • Moved the Mamba cache init from forward() into prepare_inputs_for_generation. See this comment.
    • Why? Bamba, Jamba, GraniteMoeHybrid, Zamba, and Zamba2 had settled on initializing custom caches in prepare_inputs_for_generation; only Mamba, Mamba2, and FalconMamba were doing it in forward, which is bad for torch.compile.
  • We don't break BC with any import (thanks Joao for the idea!)
  • Cleaned up some Mamba and FalconMamba slow tests, which had been failing on main for a long time.
  • Removed DDP Mamba tests. I had a DDP implementation for Mamba in 66b7162 so the tests passed, but removed it since DDP is not needed for Mamba, per Joao's instructions.
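To illustrate the torch.compile point above, here is a minimal, self-contained sketch (not the actual transformers code; all names are hypothetical stand-ins) of why initializing the cache in prepare_inputs_for_generation keeps forward() free of data-dependent object construction:

```python
# Illustrative sketch only: a model-specific state cache is created once in
# prepare_inputs_for_generation, before decoding starts, so that forward()
# never constructs objects mid-trace. ToyMambaCache is a made-up stand-in
# for a Mamba-style cache.

class ToyMambaCache:
    """Minimal stand-in for a Mamba-style state cache."""
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.conv_states = [0.0] * batch_size  # placeholder state buffers
        self.ssm_states = [0.0] * batch_size


def prepare_inputs_for_generation(input_ids, use_cache=True, cache_params=None):
    # The cache is created here, once, outside the compiled forward pass.
    if use_cache and cache_params is None:
        cache_params = ToyMambaCache(batch_size=len(input_ids))
    return {"input_ids": input_ids, "use_cache": use_cache, "cache_params": cache_params}


def forward(input_ids, use_cache=True, cache_params=None):
    # forward() can now assume the cache already exists; a compiler tracing
    # this function sees no conditional allocation.
    assert not use_cache or cache_params is not None
    return cache_params
```

The same shape is what the PR moves Mamba, Mamba2, and FalconMamba toward, matching the other hybrid-cache models.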
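On the "we don't break BC with any import" bullet: one common way to keep an old import path alive after moving a class is a module-level __getattr__ (PEP 562) that lazily forwards to the new location. The sketch below demonstrates the mechanism with made-up module names; it is not necessarily the exact mechanism this PR uses.

```python
# Hedged sketch: keep `from old_cache_utils import MambaCache` working after
# the class moved to another module, via PEP 562 module __getattr__.
# "new_home" and "old_cache_utils" are hypothetical stand-ins built at
# runtime so the example is self-contained.
import sys
import types

# Stand-in for the class's new home (think: modeling_mamba).
new_home = types.ModuleType("new_home")

class MambaCacheStub:
    pass

new_home.MambaCache = MambaCacheStub
sys.modules["new_home"] = new_home

# Stand-in for the old module (think: cache_utils): re-export lazily.
old_module = types.ModuleType("old_cache_utils")

def _module_getattr(name):
    if name == "MambaCache":
        from new_home import MambaCache
        return MambaCache
    raise AttributeError(name)

old_module.__getattr__ = _module_getattr  # PEP 562 hook
sys.modules["old_cache_utils"] = old_module

# The old-style import still resolves to the relocated class.
from old_cache_utils import MambaCache
```

The lazy forward also avoids importing the model file at package-import time, which matters for optional-dependency handling.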

@github-actions github-actions bot marked this pull request as draft May 12, 2025 15:14

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@manueldeprada manueldeprada marked this pull request as ready for review May 12, 2025 15:30
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@manueldeprada manueldeprada requested a review from gante May 12, 2025 15:45
Member
@gante gante left a comment


LGTM, added a few tiny nits 👍

@gante
Member
gante commented May 13, 2025

@manueldeprada have you run slow mamba (and falcon mamba) tests to ensure there are no regressions?

(after we confirm slow tests are okay, let's tag arthur)

@manueldeprada
Contributor Author
manueldeprada commented May 14, 2025

> @manueldeprada have you run slow mamba (and falcon mamba) tests to ensure there are no regressions?
>
> (after we confirm slow tests are okay, let's tag arthur)

I ran the tests locally; there are some failures that were already present before this PR:

FAILED tests/models/mamba/test_modeling_mamba.py::MambaModelTest::test_multi_gpu_data_parallel_forward - TypeError: 'MambaCache' object is not iterable
FAILED tests/models/mamba/test_modeling_mamba.py::MambaIntegrationTests::test_compile_mamba_cache - AssertionError: Attempt to trace forbidden callable <function mark_static_address at 0x7fd585001440>
---
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaModelTest::test_multi_gpu_data_parallel_forward - TypeError: 'MambaCache' object is not iterable
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaIntegrationTests::test_batched_generation - AssertionError: Lists differ: ['Hello today I am going to be talking abo[148 chars]The"] != ["Hello today I'm going to show you how to[149 chars].\n']
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaIntegrationTests::test_generation_4bit - AssertionError: "Hello today Iava,\n\nI'm sorry to hear that you're having trouble with the " != 'Hello today I\'m going to talk about the "C" in the "C-I-'
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaIntegrationTests::test_generation_bf16 - AssertionError: "Hello today Iava,\n\nI'm sorry to hear t[31 chars]the " != 'Hello today I am going to show you how t[47 chars]Step'
FAILED tests/models/falcon_mamba/test_modeling_falcon_mamba.py::FalconMambaIntegrationTests::test_generation_torch_compile - AssertionError: "Hello today Iava,\n\nI'm sorry to hear t[31 chars]the " != 'Hello today I am going to show you how t[47 chars]Step'

I'll run them in CI once I have permissions.

@huggingface huggingface deleted a comment from github-actions bot May 14, 2025
@manueldeprada
Contributor Author

Same failures on CI as locally, @gante. Let me know if those slow tests should be fixed, ignored, or deleted.

Collaborator
@ArthurZucker ArthurZucker left a comment


Thanks! There are still quite a few outstanding issues, notably the imports that are not at the top.

if use_cache:
if cache_params_not_initialized:
Collaborator


Suggested change
if cache_params_not_initialized:
if use_cache and cache_params_not_initialized:

Contributor Author
@manueldeprada manueldeprada Jul 16, 2025


The reason is keeping BC with respect to how generate() did things before, now that we are moving MambaCache out of generate. I just rewrote the same behaviour without a new variable. See this comment: #38086 (comment)

Collaborator


Sorry, isn't it strictly the same, given the indentation?

@manueldeprada manueldeprada changed the title Refactor MambaCache to modeling_mamba.py (parity with Zamba) Refactor MambaCache to modeling_mamba.py Jul 16, 2025
@@ -651,6 +765,8 @@ def prepare_inputs_for_generation(
):
# Overwritten -- uses `cache_params` as opposed to `past_key_values`

if use_cache and cache_params is None:
Contributor Author
@manueldeprada manueldeprada Jul 16, 2025


These 2 new additions are needed for BC and affect all classes that use MambaCache: they emulate the order in which generate() initialized MambaCache. See here

Collaborator


I don't think it is relevant to modify the code to emulate generate. generate is the abstraction that needs to be changed

Collaborator
@ArthurZucker ArthurZucker left a comment


Thanks for proposing the pragma! I think we can try not renaming functions/import sources at all, I don't think any model does that today!

selective_scan_fn,
causal_conv1d_fn,
causal_conv1d_update,
mamba_inner_fn, # modular: no_replace
Collaborator


If (IF) we have the pragma, this should not appear here.

@@ -141,7 +141,12 @@ def leave_Name(self, original_node, updated_node):
return updated_node

def leave_ImportFrom(self, original_node, updated_node):
"""The imports from other file types (configuration, processing etc) should use original model name."""
Contributor Author


Skipping the rename for absolute imports does not affect other models and is probably a reasonable assumption.

Relative imports need to be renamed: from .configuration_mamba import ...MambaConfig
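The distinction above is easy to make mechanically. The modular converter uses libcst, but as an illustration with the stdlib ast module: a relative import carries a nonzero `level` (the count of leading dots), which is what lets a visitor rename only `from .configuration_mamba import ...`-style imports while leaving absolute ones untouched.

```python
# Illustrative sketch (stdlib ast; the real converter's leave_ImportFrom uses
# libcst): relative imports are identified by ImportFrom.level > 0.
import ast


def classify_imports(source):
    """Return (relative, absolute) module names from `from ... import` statements."""
    relative, absolute = [], []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            # node.level counts leading dots: 0 = absolute, 1 = `.`, 2 = `..`, ...
            if node.level > 0:
                relative.append(node.module)
            else:
                absolute.append(node.module)
    return relative, absolute


src = (
    "from .configuration_mamba import MambaConfig\n"
    "from transformers.utils import logging\n"
)
rel, abs_ = classify_imports(src)
```

Only the names in `rel` would be candidates for the model-name rewrite.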

Collaborator
@ArthurZucker ArthurZucker left a comment


A few small things left!



[For maintainers] Suggested jobs to run (before merge)

run-slow: falcon_mamba, mamba, mamba2

@manueldeprada
Contributor Author
manueldeprada commented Jul 21, 2025

Thanks, @ArthurZucker! I ended up refactoring the entire prepare_inputs_for_generation method to make it clearer, rather than just making minimal changes to get the tests passing. I’ll do this from the start going forward!

Let me know if there's anything left!

Collaborator
@ArthurZucker ArthurZucker left a comment


Thanks, that looks better 😉

@manueldeprada manueldeprada merged commit 1aa7256 into huggingface:main Jul 21, 2025
25 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 22, 2025
* Refactor MambaCache to modeling_mamba.py (parity with Zamba)

* ruff

* fix dummies

* update

* update

* remove mamba ref in cache tests

* remove cache_implementation from tests

* update

* ruff

* ruff

* sneaky regression

* model consistency

* fix test_multi_gpu_data_parallel_forward

* fix falcon slow tests

* ruff

* ruff

* add sample false

* try to fix slow tests

* Revert "fix test_multi_gpu_data_parallel_forward"

This reverts commit 66b7162.

* fix tests on nvidia t4, remove dataparallel tests from mamba

* ruff

* remove DDP tests from mamba and falcon_mamba

* add explicit error for MambaCache

* mamba2 also needs to init cache in prepare_inputs_for_generation

* ruff

* ruff

* move MambaCache to its own file

* ruff

* unprotected import fix

* another attempt to fix unprotected imports

* Revert "another attempt to fix unprotected imports"

This reverts commit 2338354.

* fixing unprotected import, attempt 3

* Update src/transformers/cache_utils.py

* ruff's fault

* fix arthur review

* modular falcon mamba

* found a hack

* fix config docs

* fix docs

* add export info

* merge modular falcon branch

* oopsie

* fix fast path failing

* new approach

* oopsie

* fix types

* Revert new pragma in modular

This reverts commit 80b1cf1.

* trying another modular workaround

* review & fix ci

* oopsie

* clear prepare_inputs on mamba/mamba2/falcon_mamba