[Bug] Phi-4-multimodal audio processor failed to process multiple audios with close length #38098
@Isotr0py

System Info

None

Who can help?

@zucchini-nlp @eustlb

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

While migrating the Phi-4-MM implementation to the HF format in vLLM, I found that the audio processor fails to process multiple audios of similar length, which is probably a bug in the audio processor.

Reproducible example:

import numpy as np
from transformers import AutoProcessor

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, revision="refs/pr/70")

audio = [np.zeros(512), np.zeros(620)]
inputs = processor(text="<|audio|><|audio|>", audio=audio, sampling_rate=16000)

display(inputs)

This raises the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_35/3211980912.py in <cell line: 0>()
      9 
     10 audio = [np.zeros(512), np.zeros(620)]
---> 11 inputs = processor(text="<|audio|><|audio|>", audio=audio, sampling_rate=16000)
     12 
     13 display(inputs)

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/processing_phi4_multimodal.py in __call__(self, text, images, audio, **kwargs)
    117 
    118         image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
--> 119         audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}
    120 
    121         # We pop here for images as we don't need it later

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py in __call__(self, raw_speech, sampling_rate, pad_to_multiple_of, padding, max_length, truncation, return_tensors, return_attention_mask, device, **kwargs)
    245         audio_lengths = padded_inputs.audio_lengths
    246 
--> 247         input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
    248 
    249         feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py in _torch_extract_fbank_features(self, waveform, audio_lengths, device)
    310                 )
    311                 mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
--> 312                 masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
    313                 frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
    314         # ---

RuntimeError: output with shape [1, 1, 400] doesn't match the broadcast shape [1, 3, 400]
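
The failure is at the in-place masked_fill_ on line 312: the mask has been expanded to more frame rows than the frames slice it is applied to, so the in-place broadcast is rejected. A minimal PyTorch sketch (shapes copied from the error message above; it only illustrates the mismatch, not the feature extractor itself):

import torch

# Stand-ins for frames[to_mask_batch_idxs, offset_idx:max_idx] and the
# expanded mask from the traceback.
frames_slice = torch.zeros(1, 1, 400)            # sliced frames: 1 frame row
mask = torch.ones(1, 3, 400, dtype=torch.bool)   # expanded mask: 3 frame rows

# In-place masked_fill_ requires the broadcast of tensor and mask to keep the
# tensor's own shape, so this raises:
# RuntimeError: output with shape [1, 1, 400] doesn't match the broadcast shape [1, 3, 400]
frames_slice.masked_fill_(mask, 0)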

If we change these two audios to length 51000 and length 51100 respectively, another error is raised at the same line:

import numpy as np
from transformers import AutoProcessor

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, revision="refs/pr/70")

audio = [np.zeros(51000), np.zeros(51100)]
inputs = processor(text="<|audio|><|audio|>", audio=audio, sampling_rate=16000)

display(inputs)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_35/4075216193.py in <cell line: 0>()
      9 
     10 audio = [np.zeros(51000), np.zeros(51100)]
---> 11 inputs = processor(text="<|audio|><|audio|>", audio=audio, sampling_rate=16000)
     12 
     13 display(inputs)

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/processing_phi4_multimodal.py in __call__(self, text, images, audio, **kwargs)
    117 
    118         image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
--> 119         audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}
    120 
    121         # We pop here for images as we don't need it later

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py in __call__(self, raw_speech, sampling_rate, pad_to_multiple_of, padding, max_length, truncation, return_tensors, return_attention_mask, device, **kwargs)
    245         audio_lengths = padded_inputs.audio_lengths
    246 
--> 247         input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
    248 
    249         feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1

/usr/local/lib/python3.11/dist-packages/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py in _torch_extract_fbank_features(self, waveform, audio_lengths, device)
    310                 )
    311                 mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
--> 312                 masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
    313                 frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
    314         # ---

RuntimeError: The size of tensor a (0) must match the size of tensor b (2) at non-singleton dimension 1
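
This looks like the complementary case: here the slice frames[to_mask_batch_idxs, offset_idx:max_idx] apparently comes out empty (zero frame rows, presumably because offset_idx >= max_idx), while the mask still has two rows, so the shapes cannot be broadcast at all. A minimal sketch of the same mismatch (the shapes are assumptions based on the error message, not taken from the extractor):

import torch

# Assumed stand-ins: an empty frames slice against a two-row mask.
frames_slice = torch.zeros(1, 0, 400)
mask = torch.ones(1, 2, 400, dtype=torch.bool)

# Raises:
# RuntimeError: The size of tensor a (0) must match the size of tensor b (2)
# at non-singleton dimension 1
frames_slice.masked_fill_(mask, 0)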

Expected behavior

However, with the original processor (loaded with trust_remote_code=True), these audios are processed as expected:

import numpy as np
from transformers import AutoProcessor

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

audio = [(np.zeros(512), 16000), (np.zeros(620), 16000)]
inputs = processor(text="<|audio_1|><|audio_2|>", audios=audio)

print("audio_embed_sizes:", inputs.audio_embed_sizes)


audio = [(np.zeros(51000), 16000), (np.zeros(51100), 16000)]
inputs = processor(text="<|audio_1|><|audio_2|>", audios=audio)

print("audio_embed_sizes:", inputs.audio_embed_sizes)

And the outputs are reasonable:

audio_embed_sizes: tensor([1, 1])
audio_embed_sizes: tensor([40, 40])
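
For context, the feature-length formula visible in the traceback, feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1, puts these pairs of close lengths right at the boundary where the two audios frame to almost the same number of windows, which seems to be the regime the masking branch mishandles. A rough check, assuming win_length=400 (suggested by the [..., 400] shapes in the traceback) and a standard 10 ms hop at 16 kHz (hop_length=160):

import numpy as np

win_length = 400   # suggested by the traceback shapes; assumption
hop_length = 160   # assumption: 10 ms hop at 16 kHz

for audio_lengths in (np.array([512, 620]), np.array([51000, 51100])):
    feature_lengths = (audio_lengths - win_length) // hop_length + 1
    print(audio_lengths, "->", feature_lengths)

# Under these assumed parameters:
# [ 512  620] -> [1 2]        (frame counts differ by one)
# [51000 51100] -> [317 317]  (frame counts are equal)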
