Better error messages by xadupre · Pull Request #73 · sdpython/onnx-array-api · GitHub

Better error messages #73

Merged
merged 12 commits on Feb 7, 2024
improves robustness
xadupre committed Feb 7, 2024
commit e38b73bf147777fd210457c7d35dfd260f0d2036
17 changes: 12 additions & 5 deletions onnx_array_api/reference/evaluator_yield.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Iterator, Optional, Tuple
+from typing import Any, Dict, List, Iterator, Optional, Tuple, Union
 from enum import IntEnum
 import numpy as np
 from onnx import ModelProto, TensorProto, ValueInfoProto
@@ -424,7 +424,7 @@ def generate_inputs(model: ModelProto) -> List[np.ndarray]:
 def compare_onnx_execution(
     model1: ModelProto,
     model2: ModelProto,
-    inputs: Optional[List[Any]] = None,
+    inputs: Optional[Union[List[Any], Tuple[Dict[str, Any]]]] = None,
     verbose: int = 0,
     raise_exc: bool = True,
 ) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
@@ -436,7 +436,8 @@ def compare_onnx_execution(
 
     :param model1: first model
     :param model2: second model
-    :param inputs: inputs to use
+    :param inputs: inputs to use, a list of inputs if both models have
+        the same number of inputs or two dictionaries, one for each model
     :param verbose: verbosity
     :param raise_exc: raise exception if the execution fails or stop at the error
     :return: four results, a sequence of results for the first model and the second model,
@@ -446,8 +447,14 @@
         print("[compare_onnx_execution] generate inputs")
     if inputs is None:
         inputs = generate_inputs(model1)
-    feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
-    feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
+    if isinstance(inputs, tuple):
+        assert len(inputs) == 2, f"Unexpected number {len(inputs)} of inputs."
+        feeds1, feeds2 = inputs
+    else:
+        feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
+        feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
+    assert isinstance(feeds1, dict), f"Unexpected type {type(feeds1)} for inputs"
+    assert isinstance(feeds2, dict), f"Unexpected type {type(feeds2)} for inputs"
     if verbose:
         print(f"[compare_onnx_execution] got {len(inputs)} inputs")
         print("[compare_onnx_execution] execute first model")
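
The diff above lets compare_onnx_execution accept its inputs either as a single list of arrays, fed positionally to both models, or as a tuple of two feed dictionaries, one per model. As an illustration only, here is a minimal usage sketch; the model files, input names, and shapes are placeholders and are not taken from the pull request.

import numpy as np
import onnx
from onnx_array_api.reference.evaluator_yield import compare_onnx_execution

# Placeholder models: any two ONNX models whose execution should be compared.
model1 = onnx.load("model1.onnx")
model2 = onnx.load("model2.onnx")

x = np.random.rand(2, 3).astype(np.float32)

# Form 1: a list of arrays, zipped positionally with each model's graph inputs.
results = compare_onnx_execution(model1, model2, inputs=[x], verbose=1)

# Form 2 (added by this commit): a tuple of two feed dictionaries, one per model,
# useful when the models do not share input names ("X" and "input" are hypothetical).
results = compare_onnx_execution(model1, model2, inputs=({"X": x}, {"input": x}))

With the tuple form, the new assertions raise a clearer message when something other than two dictionaries is passed, which is the robustness improvement this commit targets.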