Fix error in calculating `cache_position` with past_length for Chatglm and Mamba model by kailixu-x · Pull Request #38134 · huggingface/transformers · GitHub

Fix error in calculating cache_position with past_length for Chatglm and Mamba model #38134


Closed
kailixu-x wants to merge 1 commit

Conversation

kailixu-x

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

…m and Mamba model

Signed-off-by: Kaili Xu <kaili.xu@intel.com>
@github-actions github-actions bot marked this pull request as draft May 15, 2025 03:00
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@kailixu-x kailixu-x marked this pull request as ready for review May 15, 2025 03:29
@github-actions github-actions bot requested a review from gante May 15, 2025 03:29
@Rocketknight1
Member

cc @gante for generation, but is there an issue or an explanation somewhere for why these changes are needed?

@kailixu-x
Author

> is there an issue or an explanation somewhere for why these changes are needed?

I am working with the chatglm and mamba models:
https://huggingface.co/THUDM/chatglm3-6b
https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1
The past_length calculation is not handled correctly for these two models: the first one needs shape[0] of its cache tensors, and the second one exposes its cache as cache_params, so I added this workaround.
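
For illustration, a minimal, self-contained sketch of the shape mismatch being described. The ChatGLM cache layout below is an assumption based on the diff hunk quoted later in this thread (past_length read from shape[0]), not taken from the Hub repo:

import torch

# Illustrative tensors only -- shapes are assumptions, not pulled from either model repo.
# Standard legacy cache entry: [batch, num_heads, seq_len, head_dim]
std_key = torch.zeros(1, 32, 7, 128)
print(std_key.shape[2])   # 7  -> generic code that reads dim 2 gets the past length

# ChatGLM-style cache entry (assumed layout): [seq_len, batch, num_heads, head_dim]
glm_key = torch.zeros(7, 1, 32, 128)
print(glm_key.shape[2])   # 32 -> dim 2 now holds num_heads, so past_length would be wrong
print(glm_key.shape[0])   # 7  -> the dimension this PR reads when "ChatGLM" is in the class name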

@gante
Member
gante commented May 22, 2025

@manueldeprada with the PR you have open for mambacache, this won't be needed, correct?

Comment on lines +1731 to +1732
if "ChatGLM" in self.__class__.__name__:
past_length = cache[0][0].shape[0]
Member


This would be needed because ChatGLM, a custom model, also uses a custom cache format.

We don't add logic for custom models' custom choices in transformers; my advice would be to open a PR in the Hub repo to fix your issue 🤗
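
For what it's worth, a hedged sketch of what such a Hub-side fix could look like: rather than patching transformers, the remote modeling code could hand its past key/values to generation in the standard layout, e.g. via a small converter into DynamicCache. The helper name and the assumed [seq_len, batch, num_heads, head_dim] layout are hypothetical:

from transformers import DynamicCache

def to_standard_cache(chatglm_past):
    """Hypothetical helper: convert a tuple-of-tuples cache whose tensors are
    assumed to be [seq_len, batch, num_heads, head_dim] into a DynamicCache,
    so the generic past_length / cache_position logic works unchanged."""
    cache = DynamicCache()
    for layer_idx, (key, value) in enumerate(chatglm_past):
        # permute to the standard [batch, num_heads, seq_len, head_dim] layout
        cache.update(key.permute(1, 2, 0, 3), value.permute(1, 2, 0, 3), layer_idx)
    return cache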

@manueldeprada
Contributor
manueldeprada commented May 22, 2025

yes, with #38086 the cache position calculation is more consistent with mamba models. @kailixu-x could you provide a simple snippet that shows the bug for mamba? thanks for reporting!!

elif model_kwargs.get("cache_params") is not None:
    cache = model_kwargs["cache_params"]
    past_length = 0
    if hasattr(cache, "seqlen_offset"):
Contributor


where does seqlen_offset come from? A custom MambaCache implementation? I can only find it on falcon_h1 in the transformers codebase.

@kailixu-x
Author

My modification is in src/transformers/generation/utils.py :: _get_initial_cache_position().
My mamba models are mamba and mamba2.
My use case is calling generate() with my mamba2 cache: in the Mamba2Output layout it is "cache_params", not "past_key_values" -> https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/mamba2/modeling_mamba2.py#L762
seqlen_offset comes from ->
https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/mamba2/modeling_mamba2.py#L132
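
Putting the description above into code, a rough sketch of the kind of repro snippet being asked for (transformers v4.46.x per the links; the failure mode is as described in this thread, not re-verified here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mamba-Codestral-7B-v0.1"  # checkpoint from the links above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("def quicksort(arr):", return_tensors="pt")

# Prefill pass: the Mamba2 output carries the cache under `cache_params`,
# and the number of processed tokens is tracked in `cache_params.seqlen_offset`.
prefill = model(**inputs, use_cache=True)
cache = prefill.cache_params

# Feeding the cache back into generate(): _get_initial_cache_position() only
# inspects `past_key_values`, so past_length stays 0 and cache_position restarts
# at 0 instead of seqlen_offset -- the mismatch this PR tries to work around.
out = model.generate(**inputs, cache_params=cache, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))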

@kailixu-x kailixu-x closed this Jul 1, 2025