[transformers x vLLM] standardize processors by zucchini-nlp · Pull Request #37915 · huggingface/transformers · GitHub

[transformers x vLLM] standardize processors #37915


Merged: 19 commits merged into huggingface:main on May 27, 2025

Conversation

zucchini-nlp (Member) commented May 1, 2025

What does this PR do?

Part of #37780. The design was tested on different model types:

OK, I verified that inference works for all models, unless I forgot about some new ones. Here is the list I tested. A few models (blip2, gotOcr, gemma3) won't be supported in the first release. Gemma3 is already planned for after we merge the first version of the integration; it requires bigger changes on our side to support bidirectional attention with token_type_ids.

model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "chameleon": run_chameleon, # NOTE: ready but needs to add suppress token in hub saved generation config
    "emu3": run_emu,
    "fuyu": run_fuyu, # Almost there, needs new attn interface for Persimmon LM backend in new PR
    "got_ocr": run_got_ocr, # More complex as it needs to add boxes/etc. Might support later
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "pixtral": run_pixtral,
    "llava_next": run_llava_next,
    "llava_onevision": run_llava_onevision,
    "mllama": run_mllama, # Cross attn not yet supported
    "mistral3": run_mistral3,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "vipllava": run_vipllava,
}

I will open a subsequent PR with the rest of the changes for the modeling code. That's pretty much all that's left.
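A rough sketch (not the actual vLLM example script) of how one entry in the map above might be exercised offline; the model id, prompt format, and sampling settings here are illustrative assumptions:

    # Hedged sketch only: model id, prompt template and sampling settings are assumptions.
    from PIL import Image
    from vllm import LLM, SamplingParams

    def run_llava_sketch(image_path: str, question: str) -> str:
        # The image is passed alongside the text prompt as multimodal data.
        image = Image.open(image_path).convert("RGB")
        # Assumed LLaVA-style prompt with a single image placeholder.
        prompt = f"USER: <image>\n{question}\nASSISTANT:"

        llm = LLM(model="llava-hf/llava-1.5-7b-hf")
        outputs = llm.generate(
            {"prompt": prompt, "multi_modal_data": {"image": image}},
            SamplingParams(temperature=0.0, max_tokens=64),
        )
        return outputs[0].outputs[0].text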

@github-actions github-actions bot marked this pull request as draft May 1, 2025 13:04
github-actions bot (Contributor) commented May 1, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@zucchini-nlp zucchini-nlp marked this pull request as ready for review May 1, 2025 13:04
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker (Collaborator) left a comment:

Nice! 🤗

Comment on lines +1020 to +1024
if return_mm_token_type_ids:
    array_ids = np.array(text_inputs["input_ids"])
    mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
    mm_token_type_ids[array_ids == self.image_token_id] = 1
    text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
Collaborator:

I am guessing numpy for no torch deps?

zucchini-nlp (Member, Author):

Yep. I know we don't have any new processors that need to support JAX or TF, so torch is probably already installed for users. I did this just for consistency with all the other processors.
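As a standalone illustration of what the snippet above computes (the token ids below are made-up values, not real model ids):

    import numpy as np

    image_token_id = 32000                              # assumed placeholder id
    input_ids = [[1, 32000, 32000, 15, 16], [1, 17, 32000, 18, 2]]

    array_ids = np.array(input_ids)
    mm_token_type_ids = np.zeros_like(array_ids)
    mm_token_type_ids[array_ids == image_token_id] = 1  # 1 marks image placeholder positions

    print(mm_token_type_ids.tolist())                   # [[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]]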

@@ -29,6 +30,7 @@
    ImageInput,
    make_flat_list_of_images,
)
from ..got_ocr2.image_processing_got_ocr2 import get_optimal_tiled_canvas
Collaborator:

Have not checked this one, but is it pre-computed? If not, we probably should pre-compute it!

zucchini-nlp (Member, Author):

It is computed based on the input image, similar to what llava-next does: resize and divide into patches while keeping as much information as possible. I moved the patching-related logic into the image processor class now.
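For intuition, a toy sketch of the general idea (not the actual get_optimal_tiled_canvas implementation, whose signature, budget, and tie-breaking differ): pick the grid of fixed-size tiles whose aspect ratio best matches the input image, within a tile budget.

    # Toy sketch only; max_tiles and the tie-breaking rule are assumptions.
    def best_tile_grid(width: int, height: int, max_tiles: int = 12) -> tuple[int, int]:
        image_ratio = width / height
        candidates = [
            (cols, rows)
            for cols in range(1, max_tiles + 1)
            for rows in range(1, max_tiles + 1)
            if cols * rows <= max_tiles
        ]
        # Closest aspect ratio wins; more tiles break ties to preserve more detail.
        return min(candidates, key=lambda g: (abs(g[0] / g[1] - image_ratio), -g[0] * g[1]))

    print(best_tile_grid(1920, 1080))  # (4, 2): a wide grid for a wide image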

Comment on lines +160 to +162
self.row_col_ids = [
    tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
]
Collaborator:

Would rather just hardcode it if there are only 6!

zucchini-nlp (Member, Author):

It is already hardcoded to be 6. Do you mean in the Hub configs, or where?
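For reference, the comprehension in the snippet above expands to 36 grid-position tokens; the strings below follow directly from the f-string shown:

    tokens = [f"<row_{i + 1}_col_{j + 1}>" for i in range(6) for j in range(6)]
    print(len(tokens))                       # 36
    print(tokens[0], tokens[1], tokens[-1])  # <row_1_col_1> <row_1_col_2> <row_6_col_6>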

Comment on lines +362 to +376
array_ids = np.array(inputs["input_ids"])
mm_token_type_ids = np.zeros_like(array_ids)
for i, seq_lengths in enumerate(batch_image_seq_lengths):
    image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
    j = 0
    for seq_len in seq_lengths:
        if j >= len(image_start_positions):
            break
        start = image_start_positions[j]
        end = start + seq_len
        mm_token_type_ids[i, start:end] = 1
        j = np.searchsorted(image_start_positions, end)

inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

Collaborator:

This can definitely be simplified, no?
It's weird!

zucchini-nlp (Member, Author):

I know, I would love to simply mask out certain token ids. The issue is that the Idefics special image tokens include \n, and we can't mask that by id.

So the expanded sequence is something like {image_id * N}{fake_wrapper_id}\n\n{next_col_id}{image_id * N}.... We could get away with masking only the image ids, but that would make vLLM chunked prefill fail: vLLM assumes it gets contiguous positions for a single image.
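A toy illustration of the point above (all token ids are made up): masking by image id alone leaves holes where the \n and row/col tokens sit, while the span-based loop marks one contiguous block per image.

    import numpy as np

    IMG, FAKE, NL, COL = 7, 8, 9, 10   # assumed ids for image, wrapper, "\n", col tokens
    input_ids = np.array([[1, FAKE, IMG, IMG, NL, NL, COL, IMG, IMG, FAKE, 5]])

    # Masking by image id only: non-contiguous positions, which chunked prefill rejects.
    by_id = (input_ids == IMG).astype(int)
    print(by_id.tolist())              # [[0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0]]

    # Span-based marking: one contiguous block of the full expanded length per image.
    mm_token_type_ids = np.zeros_like(input_ids)
    start = int(np.where(input_ids[0] == FAKE)[0][0])
    seq_len = 9                        # assumed expanded length for this single image
    mm_token_type_ids[0, start:start + seq_len] = 1
    print(mm_token_type_ids.tolist())  # [[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]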

zucchini-nlp (Member, Author):

Though for Idefics specifically, inference freezes forever when the input length is higher than "max-tokens-allowed-per-batch". I will look into it more with Harry.

@zucchini-nlp zucchini-nlp changed the title [WIP] standardize processors for vLLM [transformers x vLLM] standardize processors May 22, 2025
zucchini-nlp (Member, Author) commented:

OK, I verified that inference works for all models, unless I forgot about some new ones. Here is the list I tested. A few models (blip2, gotOcr, gemma3) won't be supported in the first release. Gemma3 is already planned for after we merge the first version of the integration; it requires bigger changes on our side to support bidirectional attention with token_type_ids.

model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "chameleon": run_chameleon, # NOTE: DONE but needs to add suppress token in hub generation config
    "emu3": run_emu,
    "fuyu": run_fuyu, # Almost there, needs new attn interface for Persimmon LM backend in new PR
    "got_ocr": run_got_ocr, # More complex as it needs to add boxes/etc. Might support later
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "pixtral": run_pixtral,
    "llava_next": run_llava_next,
    "llava_onevision": run_llava_onevision,
    "mllama": run_mllama, # Cross attn not yet supported
    "mistral3": run_mistral3,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "vipllava": run_vipllava,
}

I will clean up this PR and open a subsequent PR with the rest of the changes for the modeling code. That's pretty much all that's left.

@@ -131,6 +132,11 @@ def __init__(
        self.img_line_break_token = img_line_break_token
        self.tile_token = tile_token
        self.tile_global_token = tile_global_token
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
zucchini-nlp (Member, Author):

Aya Vision was implemented with different placeholder tokens in the prompt and after processing: image_token is replaced by img_patch_token, which should not happen. We make a lot of assumptions in the codebase that the input placeholder is the same token the model uses.

So the two options are:

  1. Refactor it to be correct and use only img_patch_token, which means we need to update the Hub chat template. Might be very breaking.
  2. The current solution: not intuitive, but image_token_id isn't used anywhere, thus not breaking.
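A minimal sketch of option 2, with made-up token strings (Aya Vision's real tokens and expansion logic differ): the prompt-side placeholder and the processed token are different, so image_token_id is pointed at the token that actually ends up in input_ids.

    image_token = "<image>"          # assumed placeholder written in the chat template
    img_patch_token = "<img_patch>"  # assumed token the processor inserts after expansion

    prompt = f"{image_token} Describe the picture."
    processed = prompt.replace(image_token, img_patch_token * 4)  # toy expansion

    # image_token_id = tokenizer.convert_tokens_to_ids(img_patch_token) then matches
    # the tokens actually present in the processed sequence.
    print(processed)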

ArthurZucker (Collaborator) left a comment:

LGTM! One comment: I would rather use explicit names; when something is not really multimodal, I would not call the dict multimodal, just for explicitness!
Otherwise, nice abstraction 🤗

@zucchini-nlp zucchini-nlp merged commit 9e1017b into huggingface:main May 27, 2025
20 checks passed