Refactor `MambaCache` to `modeling_mamba.py` (#38086)
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the `Ready for review` button.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, added a few tiny nits 👍
@manueldeprada have you run the slow tests? (after we confirm slow tests are okay, let's tag arthur)

Ran the tests locally; there are some failures that already happened before the PR. Will run in CI once I have permissions.

Same failures on CI as locally, @gante. Let me know if those slow tests should be fixed, ignored or deleted.
Thanks! There are still quite a few outstanding issues, notably the imports that are not at the top.
```python
if use_cache:
    if cache_params_not_initialized:
```
```diff
-if cache_params_not_initialized:
+if use_cache and cache_params_not_initialized:
```
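The suggested collapse is behaviour-preserving (assuming no `else` branch depends on the outer `if` alone); a quick sketch with hypothetical helper names checking the two forms against every input:

```python
def init_needed_nested(use_cache: bool, cache_params_not_initialized: bool) -> bool:
    # Nested form, as in the quoted diff.
    if use_cache:
        if cache_params_not_initialized:
            return True
    return False

def init_needed_flat(use_cache: bool, cache_params_not_initialized: bool) -> bool:
    # Suggested single-condition form.
    return use_cache and cache_params_not_initialized

# The two forms agree on all four input combinations.
for a in (True, False):
    for b in (True, False):
        assert init_needed_nested(a, b) == init_needed_flat(a, b)
```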
The reason is keeping BC with respect to how `generate()` did things before, now that we are moving `MambaCache` out of `generate`. I just rewrote the same behaviour without a new variable. See this comment: #38086 (comment)
Sorry, it's strictly the same given the indentation, no?
```diff
@@ -651,6 +765,8 @@ def prepare_inputs_for_generation(
 ):
     # Overwritten -- uses `cache_params` as opposed to `past_key_values`

+    if use_cache and cache_params is None:
```
These two new additions are needed for BC and affect all classes that use `MambaCache`: they emulate the order in which `generate()` initialized `MambaCache`. See here.
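The lazy-initialization pattern under discussion can be sketched as follows. This is illustrative only: the stand-in dataclass and the constructor arguments are assumptions, not the real `MambaCache` signature.

```python
from dataclasses import dataclass, field

@dataclass
class MambaCacheSketch:
    # Hypothetical stand-in for the real MambaCache; fields are illustrative.
    batch_size: int
    conv_states: dict = field(default_factory=dict)
    ssm_states: dict = field(default_factory=dict)

def prepare_inputs_for_generation(input_ids, cache_params=None, use_cache=True, **kwargs):
    # Emulates the old generate() behaviour: the cache is created lazily,
    # before the first forward call, and only when caching is requested.
    if use_cache and cache_params is None:
        cache_params = MambaCacheSketch(batch_size=len(input_ids))
    return {"input_ids": input_ids, "cache_params": cache_params, "use_cache": use_cache}

model_inputs = prepare_inputs_for_generation([[1, 2, 3]])
```

On subsequent steps the existing `cache_params` is passed back in, so the `is None` branch only fires once per generation.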
I don't think it is relevant to modify the code to emulate `generate`; `generate` is the abstraction that needs to be changed.
Thanks for proposing the pragma! I think we can try not renaming functions/import sources at all, I don't think any model does that today!
```python
selective_scan_fn,
causal_conv1d_fn,
causal_conv1d_update,
mamba_inner_fn,  # modular: no_replace
```
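For context, these names come from the optional fast-path CUDA kernels. A common guard pattern for such optional dependencies (a sketch in the spirit of the modeling code, not the exact transformers source) looks like:

```python
# Optional CUDA kernels: fall back to None when the packages are absent,
# so the model can route to the slow (pure PyTorch) path at runtime.
try:
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
except ImportError:
    mamba_inner_fn = selective_scan_fn = None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = causal_conv1d_update = None

def fast_path_available() -> bool:
    # All four kernels must be importable for the fused path.
    return None not in (selective_scan_fn, causal_conv1d_fn,
                        causal_conv1d_update, mamba_inner_fn)
```

The `# modular: no_replace` pragma in the quoted diff is about keeping such import names unrenamed when the modular converter generates model files.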
If (IF) we have the pragma, it should not appear here.
```diff
@@ -141,7 +141,12 @@ def leave_Name(self, original_node, updated_node):
         return updated_node

+    def leave_ImportFrom(self, original_node, updated_node):
+        """The imports from other file types (configuration, processing etc) should use original model name."""
```
Skipping the rename for absolute imports does not affect other models and is probably a reasonable assumption. Relative imports need to be renamed: `from .configuration_mamba import MambaConfig`.
A few small things left!
[For maintainers] Suggested jobs to run (before merge): run-slow: falcon_mamba, mamba, mamba2

Thanks, @ArthurZucker! I ended up refactoring the entire thing. Let me know if there's anything left!
Thanks, that looks better 😉
Commits:
* Refactor MambaCache to modeling_mamba.py (parity with Zamba)
* ruff
* fix dummies
* update
* update
* remove mamba ref in cache tests
* remove cache_implementation from tests
* update
* ruff
* ruff
* sneaky regression
* model consistency
* fix test_multi_gpu_data_parallel_forward
* fix falcon slow tests
* ruff
* ruff
* add sample false
* try to fix slow tests
* Revert "fix test_multi_gpu_data_parallel_forward" (reverts commit 66b7162)
* fix tests on nvidia t4, remove dataparallel tests from mamba
* ruff
* remove DDP tests from mamba and falcon_mamba
* add explicit error for MambaCache
* mamba2 also needs to init cache in prepare_inputs_for_generation
* ruff
* ruff
* move MambaCache to its own file
* ruff
* unprotected import fix
* another attempt to fix unprotected imports
* Revert "another attempt to fix unprotected imports" (reverts commit 2338354)
* fixing unprotected import, attempt 3
* Update src/transformers/cache_utils.py
* ruff's fault
* fix arthur review
* modular falcon mamba
* found a hack
* fix config docs
* fix docs
* add export info
* merge modular falcon branch
* oopsie
* fix fast path failing
* new approach
* oopsie
* fix types
* Revert new pragma in modular (reverts commit 80b1cf1)
* trying another modular workaround
* review & fix ci
* oopsie
* clear prepare_inputs on mamba/mamba2/falcon_mamba
This PR moves the specialized `MambaCache` class from `cache_utils.py` to `src/transformers/models/mamba/modeling_mamba.py`. This is preliminary work for #38077.

Changes:
* Moved `MambaCache` to its own file, aligning with Zamba, Bamba, etc.
* Cache initialization no longer happens inside `generate`.
* Removed `preparing_inits_from_generation` from `forward()`. See this comment.
* Only Mamba, Mamba2, and FalconMamba were still doing `preparing_inits_from_generation` in `forward`, which is bad for torch.compile.