Closed
Description
1. The Qwen model from transformers was exported to ONNX successfully.
2. Converting the ONNX model to a TensorRT engine with trtexec failed.
Error info:
[04/10/2025-11:04:52] [E] Error[3]: IExecutionContext::setInputShape: Error Code 3: API Usage Error (Parameter check failed, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape for key_cache.1. Set dimensions are [7,8,32,128]. Expected dimensions are [7,8,1,128].)
[04/10/2025-11:04:52] [E] The engine was built with static shapes for input tensor key_cache.1 but the provided shapes do not match the static shapes!
[04/10/2025-11:04:52] [E] Inference set up failed
Because the Qwen implementation in Transformers uses the DynamicCache class to handle the KV cache, I suspect the error is caused by DynamicCache.
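To see which axis trtexec is objecting to, the error line can be parsed and the two shapes compared. This is only a minimal sketch: `parse_dims` and `mismatched_axes` are hypothetical helper names, not part of trtexec or TensorRT, and the regular expression assumes the message format shown in the log above.

```python
import re
from typing import List, Tuple

# Error text copied from the trtexec log above.
LOG_LINE = (
    "Static dimension mismatch while setting input shape for key_cache.1. "
    "Set dimensions are [7,8,32,128]. Expected dimensions are [7,8,1,128]."
)


def parse_dims(text: str) -> List[int]:
    """Turn '7,8,32,128' into [7, 8, 32, 128]."""
    return [int(part) for part in text.split(",")]


def mismatched_axes(log_line: str) -> List[Tuple[int, int, int]]:
    """Return (axis, set_value, expected_value) for every axis that differs."""
    match = re.search(
        r"Set dimensions are \[([\d,]+)\]\. Expected dimensions are \[([\d,]+)\]",
        log_line,
    )
    if match is None:
        raise ValueError("no dimension mismatch found in log line")
    set_dims = parse_dims(match.group(1))
    expected_dims = parse_dims(match.group(2))
    return [
        (axis, got, want)
        for axis, (got, want) in enumerate(zip(set_dims, expected_dims))
        if got != want
    ]


print(mismatched_axes(LOG_LINE))  # [(2, 32, 1)]
```

Only axis 2, the sequence axis of `key_cache.1`, differs: the engine has it fixed at 1 while the run requested 32. That suggests it may also be worth checking how this axis was specified when the engine was built.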
### ONNX model check OK
The model is well-formed and valid!
=======================Model1 inputs:
x_s [1, 'seq_len', 1024]
attn_mask [1, 'seq_len', 'seq_len']
key_cache.1 [7, 8, 'seq_len', 128]
value_cache.1 [7, 8, 'seq_len', 128]
=======================Model1 outputs:
y_pred [1, 'seq_len', 1024]
key_cache [7, 8, 'seq_len', 128]
value_cache [7, 8, 'seq_len', 128]
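Since the inputs above have a symbolic `seq_len` axis while the error reports an engine built with static shapes, one thing that might be worth trying is passing an explicit shape range to trtexec through its `--minShapes`, `--optShapes` and `--maxShapes` options. The sketch below only builds those flag strings from the listing. The helper names are hypothetical, the lengths 1, 32 and 128 are assumptions (32 is taken from the error message), and every symbolic dimension is set to the same value because the listing uses a single name for all of them.

```python
from typing import Dict, List, Union

Dim = Union[int, str]

# Input signature copied from the listing above.
MODEL_INPUTS: Dict[str, List[Dim]] = {
    "x_s": [1, "seq_len", 1024],
    "attn_mask": [1, "seq_len", "seq_len"],
    "key_cache.1": [7, 8, "seq_len", 128],
    "value_cache.1": [7, 8, "seq_len", 128],
}


def shape_spec(inputs: Dict[str, List[Dim]], seq_len: int) -> str:
    """Format shapes as name:AxBxC,... with every symbolic dim set to seq_len."""
    parts = []
    for name, dims in inputs.items():
        concrete = [seq_len if isinstance(d, str) else d for d in dims]
        parts.append(f"{name}:" + "x".join(str(d) for d in concrete))
    return ",".join(parts)


def profile_flags(
    inputs: Dict[str, List[Dim]], min_len: int, opt_len: int, max_len: int
) -> List[str]:
    """Build the three shape-range flags for one optimization profile."""
    return [
        f"--minShapes={shape_spec(inputs, min_len)}",
        f"--optShapes={shape_spec(inputs, opt_len)}",
        f"--maxShapes={shape_spec(inputs, max_len)}",
    ]


# Assumed lengths; adjust to the range the model should actually support.
for flag in profile_flags(MODEL_INPUTS, min_len=1, opt_len=32, max_len=128):
    print(flag)
```

The printed flags would be appended to the trtexec build command. Whether this resolves the failure here is untested.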
Export forward:
from typing import Tuple

import torch
from transformers import DynamicCache


def injected_forward(
    self,
    xs: torch.Tensor,
    att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    key_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), dtype=torch.float32),
    value_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), dtype=torch.float32),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Convert the boolean mask to an additive one:
    # 0 where att_mask is True, dtype-min where it is False.
    att_mask = ~att_mask.unsqueeze(1) * torch.finfo(xs.dtype).min

    # Rebuild a DynamicCache from the stacked cache tensors, one entry per layer.
    past_key_values = DynamicCache(self.config.num_hidden_layers)
    for i in range(self.config.num_hidden_layers):
        past_key_values.key_cache[i] = key_cache[i].unsqueeze(0)
        past_key_values.value_cache[i] = value_cache[i].unsqueeze(0)

    # Positions of the new tokens continue after the cached ones.
    past_seen_tokens = past_key_values.get_seq_length()
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + xs.shape[1], device=xs.device)
    position_ids = cache_position.unsqueeze(0)

    hidden_states = xs
    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=att_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=False,
            use_cache=True,
            cache_position=cache_position,
        )
        hidden_states = layer_outputs[0]

    xs = self.norm(hidden_states)

    # Stack the updated per-layer caches back into single tensors.
    new_key_cache = torch.cat(past_key_values.key_cache, dim=0)
    new_value_cache = torch.cat(past_key_values.value_cache, dim=0)
    return xs, new_key_cache, new_value_cache
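For reference, the shape bookkeeping of this forward can be written out without torch. The sketch below is an assumption-laden reading of the code above: it takes 7 as the number of layers, 8 as the number of KV heads, 128 as the head dimension and 1024 as the hidden size (all from the ONNX listing), assumes batch size 1, and assumes the mask spans past plus current tokens. `forward_shapes` is a hypothetical helper, not part of transformers.

```python
from typing import Dict, Tuple


def forward_shapes(
    past_len: int,
    cur_len: int,
    num_layers: int = 7,     # assumed from the ONNX listing
    num_kv_heads: int = 8,   # assumed from the ONNX listing
    head_dim: int = 128,
    hidden: int = 1024,
) -> Dict[str, Tuple[int, ...]]:
    """Expected tensor shapes for one call of the exported forward."""
    # The cache update appends the new tokens along the sequence axis.
    total_len = past_len + cur_len
    return {
        "xs": (1, cur_len, hidden),
        # Assumption: the mask covers cached and new tokens together.
        "att_mask": (1, cur_len, total_len),
        "key_cache": (num_layers, num_kv_heads, past_len, head_dim),
        "value_cache": (num_layers, num_kv_heads, past_len, head_dim),
        "y_pred": (1, cur_len, hidden),
        "new_key_cache": (num_layers, num_kv_heads, total_len, head_dim),
        "new_value_cache": (num_layers, num_kv_heads, total_len, head_dim),
    }


# One decoding step with 31 cached tokens and 1 new token.
shapes = forward_shapes(past_len=31, cur_len=1)
print(shapes["key_cache"])      # (7, 8, 31, 128)
print(shapes["new_key_cache"])  # (7, 8, 32, 128)
```

Under these assumptions the input cache length, the number of new tokens, and the output cache length are three different values, even though the ONNX listing labels all of them `seq_len`.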