-
Notifications
You must be signed in to change notification settings - Fork 29.8k
Fix error in calculating cache_position
with past_length for Chatglm and Mamba model
#38134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…m and Mamba model Signed-off-by: Kaili Xu <kaili.xu@intel.com>
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the |
cc @gante for generation, but is there an issue or an explanation somewhere for why these changes are needed? |
I am working with model chatglm and mamba, |
@manueldeprada with the PR you have open for mambacache, this won't be needed, correct? |
if "ChatGLM" in self.__class__.__name__: | ||
past_length = cache[0][0].shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be needed because ChatGLM
, a custom model, also uses a custom cache format.
We don't add logic for custom model's custom choices in transformers
, my advice would be to open a PR in the Hub repo so as to fix your issue 🤗
yes, with #38086 the cache position calculation is more consistent in with mamba models. @kailixu-x could you provide a simple snippet that shows the bug for mamba? thanks for reporting!! |
elif model_kwargs.get("cache_params") is not None: | ||
cache = model_kwargs["cache_params"] | ||
past_length = 0 | ||
if hasattr(cache, "seqlen_offset"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where does seqlen_offset
come from? custom MambaCache implementation? I can only find it on falcon_h1 in transformers codebase.
my modification in src/transformers/generation/utils.py :: _get_initial_cache_position() |
What does this PR do?
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.