How to solve the error of converting Qwen onnx_model to tensorRT_model? #37408
@dearwind153

Description


1. The Transformers Qwen ONNX model has been exported successfully.

2. Converting the ONNX model to a TensorRT engine with trtexec fails (a sketch of the kind of command involved is shown below).
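
The report does not include the exact trtexec command. For illustration only, a conversion without shape profiles (file names assumed, launched from Python here to match the rest of the code) might look like the sketch below; when the ONNX model has dynamic axes and no shapes are supplied, trtexec typically overrides them to 1, which matches the "Expected dimensions are [7,8,1,128]" in the error that follows.

```python
# Hypothetical reproduction of the failing conversion; file names are assumptions.
# With no --minShapes/--optShapes/--maxShapes, trtexec usually builds a static
# engine in which the dynamic 'seq_len' axes are overridden to 1.
import subprocess

subprocess.run(
    [
        "trtexec",
        "--onnx=qwen.onnx",        # exported Qwen ONNX model (assumed name)
        "--saveEngine=qwen.plan",  # serialized TensorRT engine (assumed name)
    ],
    check=True,
)
```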

### Error info

[04/10/2025-11:04:52] [E] Error[3]: IExecutionContext::setInputShape: Error Code 3: API Usage Error (Parameter check failed, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape for key_cache.1. Set dimensions are [7,8,32,128]. Expected dimensions are [7,8,1,128].)
[04/10/2025-11:04:52] [E] The engine was built with static shapes for input tensor key_cache.1 but the provided shapes do not match the static shapes!
[04/10/2025-11:04:52] [E] Inference set up failed

Because the Qwen implementation in Transformers uses the DynamicCache class to handle the KV cache, the error is most likely attributable to DynamicCache.
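
Whatever the root cause on the export side, the immediate symptom is that the engine was built with the seq_len axis frozen to 1. One direction to try (a sketch only, not verified against this model) is to rebuild the engine with an explicit optimization profile that keeps seq_len dynamic, either through trtexec's --minShapes/--optShapes/--maxShapes flags or through the TensorRT Python API. The input names and the 1024 hidden size below come from the model check that follows; the file names and the min/opt/max ranges are assumptions.

```python
# Sketch: build the engine with a dynamic 'seq_len' axis via an optimization
# profile. File names and the min/opt/max ranges are assumptions.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)  # deprecated no-op on recent TensorRT
)
parser = trt.OnnxParser(network, logger)
with open("qwen.onnx", "rb") as f:  # assumed file name
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parse failed")

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# Keep the 'seq_len' axis dynamic (dim 1 of x_s, dims 1-2 of attn_mask,
# dim 2 of the cache tensors); the 1024-token maximum is an assumption.
profile.set_shape("x_s", min=(1, 1, 1024), opt=(1, 32, 1024), max=(1, 1024, 1024))
profile.set_shape("attn_mask", min=(1, 1, 1), opt=(1, 32, 32), max=(1, 1024, 1024))
profile.set_shape("key_cache.1", min=(7, 8, 1, 128), opt=(7, 8, 32, 128), max=(7, 8, 1024, 128))
profile.set_shape("value_cache.1", min=(7, 8, 1, 128), opt=(7, 8, 32, 128), max=(7, 8, 1024, 128))
config.add_optimization_profile(profile)

serialized_engine = builder.build_serialized_network(network, config)
with open("qwen.plan", "wb") as f:  # assumed output name
    f.write(serialized_engine)
```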

### ONNX model check OK

The model is well-formed and valid!
=======================Model1 inputs:
x_s [1, 'seq_len', 1024]
attn_mask [1, 'seq_len', 'seq_len']
key_cache.1 [7, 8, 'seq_len', 128]
value_cache.1 [7, 8, 'seq_len', 128]
=======================Model1 outputs:
y_pred [1, 'seq_len', 1024]
key_cache [7, 8, 'seq_len', 128]
value_cache [7, 8, 'seq_len', 128]
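
For reference, a printout like the one above can be produced with the onnx package; a minimal sketch, with the file name assumed:

```python
# Check the exported model and print input/output shapes; symbolic dimensions
# such as 'seq_len' are shown by name. The file name is an assumption.
import onnx

model = onnx.load("qwen.onnx")
onnx.checker.check_model(model)
print("The model is well-formed and valid!")

def dims(value_info):
    # dim_param holds the symbolic name (e.g. 'seq_len'), dim_value the static size.
    return [d.dim_param if d.dim_param else d.dim_value
            for d in value_info.type.tensor_type.shape.dim]

print("=======================Model1 inputs:")
for inp in model.graph.input:
    print(inp.name, dims(inp))
print("=======================Model1 outputs:")
for out in model.graph.output:
    print(out.name, dims(out))
```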

### Export forward

from typing import Tuple

import torch
from transformers.cache_utils import DynamicCache


def injected_forward(
    self,
    xs: torch.Tensor,
    att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    key_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), dtype=torch.float32),
    value_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), dtype=torch.float32),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Turn the boolean mask into an additive attention mask.
    att_mask = ~att_mask.unsqueeze(1) * torch.finfo(xs.dtype).min

    # Rebuild a DynamicCache from the flat [num_layers, num_kv_heads, seq_len, head_dim]
    # cache tensors so the decoder layers can update it in place.
    past_key_values = DynamicCache(self.config.num_hidden_layers)
    for i in torch.arange(self.config.num_hidden_layers):
        past_key_values.key_cache[i] = key_cache[i].unsqueeze(0)
        past_key_values.value_cache[i] = value_cache[i].unsqueeze(0)

    past_seen_tokens = past_key_values.get_seq_length()
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + xs.shape[1], device=xs.device)
    position_ids = cache_position.unsqueeze(0)

    hidden_states = xs
    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=att_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=False,
            use_cache=True,
            cache_position=cache_position,
        )
        hidden_states = layer_outputs[0]

    xs = self.norm(hidden_states)
    # Stack the per-layer caches back into flat tensors for the ONNX graph outputs.
    new_key_cache = torch.cat(past_key_values.key_cache, dim=0)
    new_value_cache = torch.cat(past_key_values.value_cache, dim=0)

    return xs, new_key_cache, new_value_cache
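
The call to torch.onnx.export itself is not shown in the report. For completeness, here is a hypothetical helper (names, dummy lengths, and opset are assumptions) that would mark the seq_len axes as dynamic, matching the symbolic dimensions in the model check above:

```python
# Hypothetical export helper, not taken from the original report. `model` is the
# Qwen decoder module whose forward has been replaced by injected_forward above.
import torch

def export_qwen_onnx(model, path: str = "qwen.onnx") -> None:
    cfg = model.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    past, new = 16, 16  # arbitrary dummy lengths used only for tracing

    dummy_xs = torch.zeros(1, new, cfg.hidden_size)
    dummy_mask = torch.ones(1, new, past + new, dtype=torch.bool)  # covers past + new tokens
    dummy_key = torch.zeros(cfg.num_hidden_layers, cfg.num_key_value_heads, past, head_dim)
    dummy_value = torch.zeros_like(dummy_key)

    torch.onnx.export(
        model,
        (dummy_xs, dummy_mask, dummy_key, dummy_value),
        path,
        input_names=["x_s", "attn_mask", "key_cache.1", "value_cache.1"],
        output_names=["y_pred", "key_cache", "value_cache"],
        dynamic_axes={
            "x_s": {1: "seq_len"},
            "attn_mask": {1: "seq_len", 2: "seq_len"},
            "key_cache.1": {2: "seq_len"},
            "value_cache.1": {2: "seq_len"},
            "y_pred": {1: "seq_len"},
            "key_cache": {2: "seq_len"},
            "value_cache": {2: "seq_len"},
        },
        opset_version=17,  # assumed
    )
```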
