From bd1dbcb0f08842e71fd8e8de03f7d53a99fffb85 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Feb 2025 11:43:54 -0800 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- test/onnx/torchlib/README.md | 80 +++ test/onnx/torchlib/ops_test_common.py | 701 ++++++++++++++++++++++++++ test/onnx/torchlib/ops_test_data.py | 685 +++++++++++++++++++++++++ 3 files changed, 1466 insertions(+) create mode 100644 test/onnx/torchlib/README.md create mode 100644 test/onnx/torchlib/ops_test_common.py create mode 100644 test/onnx/torchlib/ops_test_data.py diff --git a/test/onnx/torchlib/README.md b/test/onnx/torchlib/README.md new file mode 100644 index 00000000000000..0ea8c6c524d474 --- /dev/null +++ b/test/onnx/torchlib/README.md @@ -0,0 +1,80 @@ +# Test op correctness by comparing with PyTorch results using OpInfo + +`OpInfo` is PyTorch's standard mechanism for composing test data for operators. +Read more about them on https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362. + +## Usage + +```bash +# All +python -m pytest test_ops.py + +# To run tests on a specific operator (e.g. torch.ceil): +python -m pytest test_ops.py -k ceil + +# To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention): +python -m pytest test_ops.py -k nn_functional_scaled_dot_product_attention +``` + +### Environment variables + +1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults +in onnxruntime by running the inference sessions in a separate process. +2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. + + ```bash + CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/torchlib/test_ops.py -k div_mode_int + ``` + +## How to add a new operator test + +See _usage_ in [`ops_test_data.py`](./ops_test_data.py) + +## How to add custom OpInfo tests + +Sometimes, there is no existing OpInfo that fits our need to test an operator. You want to create a custom OpInfo for it. + +Follow the steps below to create new OpInfo tests: + +1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py) + + ```py + opinfo_core.OpInfo( + "ops.aten.slice_scatter", + aten_name="slice_scatter", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_slice_scatter, + supports_out=False, + ), + ``` + + - The first argument should be the operator name under the `torch.ops` namespace. For example, if you want to test the `prims.var` op, then put `"ops.prims.var"`. It should almost always start with `ops.`. + - Follow existing examples to specify the `dtypes` you want to test the op on. + - Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068. + + ```py + opinfo_core.OpInfo( + "ops.aten.bernoulli.p_deterministic", + op=torch.ops.aten.bernoulli.p, + ``` + + The op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique in a test suite. When `op` is not specified, it will look for the op in `torch.` using its name. + +2. 
Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268) + 1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a torch.Tensor. Alternatively you could also use `torch.tensor` to generate the tensor yourself. Be sure to double check the dtype and device. Finally yield each test cases with + + ```py + yield opinfo_core.SampleInput(input, args=(...), kwargs={...}) + ``` + + `input` is the first arg. The rest of the args are in `args`. +3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py) + 1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116) + + ```py + TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter) + ``` + + You can additionally specify dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590). + +Now that the test is added, you may run the test like mentioned above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view failing input combinations should any test case fails. diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py new file mode 100644 index 00000000000000..8b49e94e3da64b --- /dev/null +++ b/test/onnx/torchlib/ops_test_common.py @@ -0,0 +1,701 @@ +# Owner(s): ["module: onnx"] +"""Common utils for testing operators.""" + +from __future__ import annotations + +import contextlib +import copy +import dataclasses +import multiprocessing +import os +import pprint +import sys +import unittest +import warnings +from typing import ( + Any, + Callable, + Collection, + Iterable, + Mapping, + Optional, + Sequence, + TypeVar, +) + +import error_reproduction +import numpy as np + +import onnx +import onnxruntime as ort +import onnxruntime.capi.onnxruntime_pybind11_state +import onnxscript +import onnxscript.evaluator +import pytest +from onnxscript import ir + +import torch +from torch.onnx._internal.exporter import _building, _ir_passes, _tensors +from torch.testing._internal.opinfo import core as opinfo_core + + +T = TypeVar("T") + + +# Convenience tuples for creating dtype lists when skipping or xfailing tests + +BOOL_TYPES = (torch.bool,) + +INT_TYPES = ( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, +) + +FLOAT_TYPES = ( + torch.float16, + torch.float32, + torch.float64, +) + +TEST_OPSET_VERSION = 18 +IS_MACOS = sys.platform.startswith("darwin") +IS_WINDOWS = os.name == "nt" + + +@dataclasses.dataclass +class DecorateMeta: + """A dataclass for storing information about a test case to skip or xfail. + + Adapted from functorch: functorch/test/common_utils.py + """ + + op_name: str + variant_name: str + decorator: Callable[..., Any] + dtypes: Optional[Collection[torch.dtype]] + device_type: Optional[str] + reason: str + test_behavior: str + matcher: Optional[Callable[[Any], bool]] = None + enabled_if: bool = True + # The test_class_name to apply the decorator to. If None, the decorator is + # applied to all test classes. 
+ test_class_name: Optional[str] = None + + +def xfail( + op_name: str, + variant_name: str = "", + *, + reason: str, + dtypes: Optional[Collection[torch.dtype]] = None, + device_type: Optional[str] = None, + matcher: Optional[Callable[[Any], Any]] = None, + enabled_if: bool = True, + test_class_name: Optional[str] = None, +) -> DecorateMeta: + """Expects an OpInfo test to fail. + + Args: + op_name: The name of the operator. + variant_name: Optional OpInfo variant_test_name. + reason: The reason for the failure. + dtypes: The dtypes to expect the failure. + device_type: Device type. E.g. "cpu", "cuda". + matcher: A function that matches the test sample input. It is used only when + the xfail is in the SKIP_XFAIL_SUBTESTS list. + enabled_if: Whether the xfail is enabled. + test_class_name: The test class name to apply the xfail to. If None, the + xfail is applied to all test classes. + """ + return DecorateMeta( + op_name=op_name, + variant_name=variant_name, + decorator=unittest.expectedFailure, + dtypes=dtypes, + device_type=device_type, + matcher=matcher, + reason=reason, + enabled_if=enabled_if, + test_class_name=test_class_name, + test_behavior="xfail", + ) + + +def skip( + op_name: str, + variant_name: str = "", + *, + reason: str, + dtypes: Optional[Collection[torch.dtype]] = None, + device_type: Optional[str] = None, + matcher: Optional[Callable[[Any], Any]] = None, + enabled_if: bool = True, + test_class_name: Optional[str] = None, +) -> DecorateMeta: + """Skips an OpInfo test. + + Args: + op_name: The name of the operator. + variant_name: Optional OpInfo variant_test_name. + reason: The reason for skipping. + dtypes: The dtypes to skip. + device_type: Device type. E.g. "cpu", "cuda". + matcher: A function that matches the test sample input. It is used only when + the skip is in the SKIP_XFAIL_SUBTESTS list. + enabled_if: Whether the skip is enabled. + test_class_name: The test class name to apply the skip to. If None, the skip + is applied to all test classes. + """ + return DecorateMeta( + op_name=op_name, + variant_name=variant_name, + decorator=unittest.skip(f"Skip: {reason}"), + dtypes=dtypes, + device_type=device_type, + reason=reason, + matcher=matcher, + enabled_if=enabled_if, + test_class_name=test_class_name, + test_behavior="skip", + ) + + +def add_decorate_info( + all_opinfos: Sequence[opinfo_core.OpInfo], + test_class_name: str, + base_test_name: str, + skip_or_xfails: Iterable[DecorateMeta], +) -> Callable[[T], T]: + """Decorates OpInfo tests with decorators based on the skip_or_xfails list.""" + ops_mapping = {(info.name, info.variant_test_name): info for info in all_opinfos} + for decorate_meta in skip_or_xfails: + opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name)) + if opinfo is None and not decorate_meta.enabled_if: + # If the OpInfo doesn't exist and it is not enabled, we skip the OpInfo + # because it could be an OpInfo that is in torch-nightly but not older versions. + continue + assert ( + opinfo is not None + ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" 
+ decorators = list(opinfo.decorators) + new_decorator = opinfo_core.DecorateInfo( + decorate_meta.decorator, + decorate_meta.test_class_name or test_class_name, + base_test_name, + dtypes=decorate_meta.dtypes, + device_type=decorate_meta.device_type, + active_if=decorate_meta.enabled_if, + ) + decorators.append(new_decorator) + opinfo.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + + return wrapped + + +def duplicate_opinfo( + opinfos: list[opinfo_core.OpInfo], name: str, new_names: tuple[str, ...] +): + """Duplicate an opinfo in the opinfo database and give it a new name.""" + duplicated = [] + all_info_names = {opinfo.name for opinfo in opinfos} + for opinfo in opinfos: + if opinfo.name == name: + for new_name in new_names: + if new_name in all_info_names: + # NOTE: Avoid duplicating an opinfo that already exists in the database. + # New opinfos are expected to be added in torch-nightly. + warnings.warn( + f"OpInfo {new_name} already exists in the database.", + stacklevel=1, + ) + continue + new_opinfo = copy.deepcopy(opinfo) + new_opinfo.name = new_name + duplicated.append(new_opinfo) + opinfos.extend(duplicated) + + +def duplicate_opinfo_for_prims( + opinfos: list[opinfo_core.OpInfo], name: str, prims_name: str | None = None +): + """Duplicate an opinfo in the opinfo database for a prims op. + + The function sets the new OpInfo to use the variation torch.ops.prims. + The new OpInfo will have the name "prims_{prims_name}" where `prims_name` is the + name of the prims op. If `prims_name` is None, it will be set to "prims_{name}". + + Args: + opinfos: The list of opinfo_core.OpInfo to add the new opinfo to. + name: The name of the opinfo to duplicate. + prims_name: The name of the prims op. If None, it will be set to `name`. + """ + if prims_name is None: + prims_name = name + # The name of the new OpInfo + new_name = f"prims_{prims_name}" + all_info_names = {opinfo.name for opinfo in opinfos} + for opinfo in opinfos: + if opinfo.name == name: + if new_name in all_info_names: + # NOTE: Avoid duplicating an opinfo that already exists in the database. + warnings.warn( + f"OpInfo {new_name} already exists in the database.", stacklevel=1 + ) + continue + new_opinfo = copy.deepcopy(opinfo) + new_opinfo.name = new_name + new_opinfo.op = getattr(torch.ops.prims, prims_name) + opinfos.append(new_opinfo) + return + raise RuntimeError(f"OpInfo '{name}' not found in the database.") + + +TORCH_TYPE_TO_ONNX = { + torch.bool: onnx.TensorProto.BOOL, + torch.uint8: onnx.TensorProto.UINT8, + torch.int8: onnx.TensorProto.INT8, + torch.int16: onnx.TensorProto.INT16, + torch.int32: onnx.TensorProto.INT32, + torch.int64: onnx.TensorProto.INT64, + torch.float16: onnx.TensorProto.FLOAT16, + torch.float32: onnx.TensorProto.FLOAT, + torch.float64: onnx.TensorProto.DOUBLE, + torch.complex64: onnx.TensorProto.COMPLEX64, + torch.complex128: onnx.TensorProto.COMPLEX128, + torch.bfloat16: onnx.TensorProto.BFLOAT16, +} + + +def convert_tensor_to_numpy(input: Any) -> Any: + if isinstance(input, torch.Tensor): + if torch.is_complex(input): + # from complex to real representation + input = torch.view_as_real(input) + return input.detach().cpu().numpy() + if isinstance(input, complex): + return torch.view_as_real(torch.tensor(input)).detach().cpu().numpy() + if isinstance(input, list): + if len(input) == 0: + return np.array((), dtype=np.int64) + if any(isinstance(x, torch.Tensor) for x in input): + # The list can be Optional[Tensor], e.g. 
[None, Tensor, None] etc. + return [convert_tensor_to_numpy(x) for x in input] + if isinstance(input[0], bool): + return np.array(input, dtype=np.bool_) + + # Just a sequence of numbers + if isinstance(input[0], int): + return np.array(input, dtype=np.int64) + if isinstance(input[0], float): + return np.array(input) + + return input + + +def convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]: + """Converts kwargs to be compatible with ONNX Runtime.""" + new_kwargs = {} + for key, value in kwargs.items(): + if key == "device": + continue + if key == "dtype": + value = TORCH_TYPE_TO_ONNX[value] + if isinstance(value, torch.Tensor): + value = np.array(value.cpu()) + new_kwargs[key] = value + return new_kwargs + + +class OrtAbortedError(RuntimeError): + """ONNX Runtime Aborted.""" + + +def _ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]): + """Run a model with ONNX Runtime.""" + + # Disable all ORT optimizations + session_options = onnxruntime.SessionOptions() + session_options.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL + ) + session = ort.InferenceSession( + serialized_model, session_options, providers=("CPUExecutionProvider",) + ) + return session.run(None, ort_inputs) + + +def _ort_session_run_return_dict( + serialized_model: bytes, ort_inputs: Mapping[str, Any], return_dict +) -> None: + """Run a model with ONNX Runtime and store the results in return_dict.""" + + try: + return_dict["results"] = _ort_session_run(serialized_model, ort_inputs) + return_dict["error"] = None + except Exception as e: # pylint: disable=broad-except + return_dict["results"] = None + return_dict["error"] = e + + +def _safe_ort_session_run(serialized_model: bytes, ort_inputs: Mapping[str, Any]): + """Run a model with ONNX Runtime in a separate process. + + Args: + serialized_model: Serialized ONNX model proto. + ort_inputs: Inputs to the model. + + Returns: + The inference result. + + Raises: + OrtAbortedError if the process did not execute successfully. 
+ """ + manager = multiprocessing.Manager() + return_dict = manager.dict() + process = multiprocessing.Process( + target=_ort_session_run_return_dict, + args=(serialized_model, ort_inputs, return_dict), + ) + process.start() + process.join() + process.close() + if not return_dict: + raise OrtAbortedError + if return_dict["error"] is not None: + raise return_dict["error"] + return return_dict["results"] + + +def _format_model_and_input_information(onnx_model, inputs): + return ( + f"Inputs:\n{pprint.pformat(inputs)}\nModel:\n{onnx.printer.to_text(onnx_model)}" + ) + + +_TORCH_DTYPE_TO_ONNX_STRING = { + torch.bool: "tensor(bool)", + torch.uint8: "tensor(uint8)", + torch.int8: "tensor(int8)", + torch.int16: "tensor(int16)", + torch.int32: "tensor(int32)", + torch.int64: "tensor(int64)", + torch.float16: "tensor(float16)", + torch.float32: "tensor(float)", + torch.float64: "tensor(double)", + torch.complex64: "tensor(complex64)", + torch.complex128: "tensor(complex128)", + torch.bfloat16: "tensor(bfloat16)", +} + +_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { + torch.bfloat16: ir.DataType.BFLOAT16, + torch.bool: ir.DataType.BOOL, + torch.complex128: ir.DataType.COMPLEX128, + torch.complex64: ir.DataType.COMPLEX64, + torch.float16: ir.DataType.FLOAT16, + torch.float32: ir.DataType.FLOAT, + torch.float64: ir.DataType.DOUBLE, + torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, + torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, + torch.float8_e5m2: ir.DataType.FLOAT8E5M2, + torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, + torch.int16: ir.DataType.INT16, + torch.int32: ir.DataType.INT32, + torch.int64: ir.DataType.INT64, + torch.int8: ir.DataType.INT8, + torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, +} + + +def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: + """Checks if the dtype is compatible with the schema. + + When a dtype is "compatible" with the schema, it means we can use the dtype + to create sample inputs by OpInfo to test the ONNX function and expect outputs to match. + + Args: + dtype: The torch dtype used to create sample inputs by OpInfo. + schema: The ONNX schema of the function. + + Returns: + True if the dtype is compatible with the schema. + """ + if not schema.inputs: + # If there are no inputs, we can't check compatibility. Assume it is compatible. + # e.g. aten_randn has only attributes. + return True + if schema.inputs[0].name not in {"self", "input"}: + # If the name of the first input is not "self" or "input", + # it is usually an input that is not of the same type as the output. + # We assume support in this case. + # + # For example, `aten_ones(size: IntType, dtype: int = FLOAT.dtype)` + # has the first input as `size`, which is an integer, but it can support + # any dtype. + return True + + # Otherwise we check the type constraints of the first input. 
+ # For example, when dtype=torch.float32, and the op being tested has the schema + # ``` + # OpSchema( + # name='aten_abs', + # domain='pkg.onnxscript.torch_lib', + # since_version=1, + # doc='abs(Tensor self) -> Tensor', + # type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', + # allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', + # 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', + # 'tensor(bfloat16)'], description='')], + # inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', + # description='', param_option=, + # is_homogeneous=True, min_arity=1, + # differentiation_category=)], + # outputs=[OpSchema.FormalParameter(name='return_val', + # type_str='TReal', description='', + # param_option=, is_homogeneous=True, + # min_arity=1, differentiation_category=)], + # attributes={} + # ) + # ``` + # we see the first input type is "TReal", corresponding to the type constraint + # with allowed types ['tensor(float)', 'tensor(int8)', 'tensor(int16)', + # 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', + # 'tensor(bfloat16)']. + # Since torch.float32 (tensor(float)) is in the allowed types, we return True. + + first_input_type_name = schema.inputs[0].type_str + # Find the type constraint for the first input by matching the parameter name + first_input_type_constraint = next( + ( + x + for x in schema.type_constraints + if first_input_type_name in x.type_param_str + ), + None, + ) + assert first_input_type_constraint is not None + allowed_type_strs = first_input_type_constraint.allowed_type_strs + # Here we consider seq(tensor(float)) compatible with tensor(float) as well + return any( + _TORCH_DTYPE_TO_ONNX_STRING[dtype] in type_str for type_str in allowed_type_strs + ) + + +def graph_executor( + test_name: str, + outputs: Sequence[Any], +) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]: + """Eagerly executes a function.""" + + def _capture_graph_and_evaluate_torch_script_evaluator( + function: Callable, args, kwargs + ) -> tuple[Any, onnx.ModelProto]: + """Captures the graph of a function and evaluates it using TorchScriptEvaluator.""" + + # Initialize the ONNX graph + graph = ir.Graph( + (), + (), + nodes=(), + opset_imports={"": 18, "pkg.torch.onnx": 1}, + name="main_graph", + ) + opset = onnxscript.opset18 + tracer = _building.OpRecorder(opset, {}) + ort_inputs = {} + onnxscript_args: list[Any] = [] + onnxscript_kwargs = {} + for i, arg in enumerate(args): + if isinstance(arg, np.ndarray): + input_name = f"input_{i}" + input = _tensors.SymbolicTensor( + opset=opset, + name=input_name, + shape=ir.Shape(arg.shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(arg).dtype]), + ) + graph.inputs.append(input) + onnxscript_args.append(input) + ort_inputs[input_name] = arg + elif isinstance(arg, (list, tuple)): + # str is also a sequence but we do not want to treat it as a tensor + sequence_input = [] + for j, subarg in enumerate(arg): + if isinstance(subarg, np.ndarray): + input_name = f"input_{i}_{j}" + tensor = torch.tensor(subarg) + input = _tensors.SymbolicTensor( + opset=opset, + name=input_name, + shape=ir.Shape(tensor.shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[tensor.dtype]), + ) + graph.inputs.append(input) + sequence_input.append(input) + ort_inputs[input_name] = subarg + else: + # Include non-numpy inputs as-is + # For example, it could be a None value that we want to keep + sequence_input.append(subarg) + onnxscript_args.append(sequence_input) + else: + 
onnxscript_args.append(arg) + for key, value in kwargs.items(): + if isinstance(value, np.ndarray): + input = _tensors.SymbolicTensor( + opset=opset, + name=key, + shape=ir.Shape(torch.tensor(value).shape), + type=ir.TensorType(_TORCH_DTYPE_TO_ONNX[torch.tensor(value).dtype]), + ) + graph.inputs.append(input) + ort_inputs[key] = value + onnxscript_kwargs[key] = input + else: + onnxscript_kwargs[key] = value + + with onnxscript.evaluator.default_as(tracer): + symbolic_outputs = function(*onnxscript_args, **onnxscript_kwargs) + if not isinstance(symbolic_outputs, Sequence): + symbolic_outputs = (symbolic_outputs,) + + # We need to set the size of the output tensors for the ONNX model to be valid + for output, symbolic_output in zip(outputs, symbolic_outputs): + if isinstance(output, Sequence): + # Output is a sequence + elem_dtype = _TORCH_DTYPE_TO_ONNX[output[0].dtype] + symbolic_output.type = ir.SequenceType(ir.TensorType(elem_dtype)) + continue + output = ( + output + if isinstance(output, torch.Tensor) + else torch.tensor(output, device="cpu") + ) + symbolic_output.shape = ir.Shape(output.shape) + symbolic_output.dtype = _TORCH_DTYPE_TO_ONNX[output.dtype] + + graph.outputs.extend(symbolic_outputs) + graph.extend(tracer.nodes) + onnx_model = ir.Model(graph, ir_version=10, producer_name="torch_test") + for identifier, onnxscript_function in tracer.functions.items(): + if identifier in onnx_model.functions: + continue + if isinstance(onnxscript_function, ir.Function): + ir_function = onnxscript_function + else: + # TODO: Get IR function directly when onnxscript is updated + proto = onnxscript_function.to_function_proto() + ir_function = ir.serde.deserialize_function(proto) + onnx_model.functions[identifier] = ir_function + _ir_passes.add_torchlib_common_imports(onnx_model) + _ir_passes.add_opset_imports(onnx_model) + # Make sure the model is valid + model_proto = ir.to_proto(onnx_model) + try: + onnx.checker.check_model(model_proto, full_check=True) + except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e: + raise AssertionError(f"ONNX model is invalid. 
Model:\n{onnx_model}") from e + model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True) + try: + if ( + os.environ.get("CATCH_ORT_SEGFAULT") == "1" + or os.environ.get("CREATE_REPRODUCTION_REPORT") == "1" + ): + # Use an individual process to run ONNX Runtime to catch segfaults + return _safe_ort_session_run( + model_proto.SerializeToString(), ort_inputs + ), model_proto + + return _ort_session_run( + model_proto.SerializeToString(), ort_inputs + ), model_proto + except ( + # pylint: disable=c-extension-no-member + onnxruntime.capi.onnxruntime_pybind11_state.Fail, + onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException, + onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument, + onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph, + onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented, + # pylint: enable=c-extension-no-member + ) as e: + if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": + error_reproduction.create_reproduction_report( + test_name, + model_proto, + ort_inputs, + e, + "test/onnx/torchlib/test_ops.py", + ) + raise RuntimeError( + "ONNX Runtime failed to evaluate:\n" + + _format_model_and_input_information(model_proto, ort_inputs) + ) from e + except OrtAbortedError as e: + if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": + # Save the model and inputs to a file for reproduction + error_reproduction.create_reproduction_report( + test_name, + model_proto, + ort_inputs, + e, + "test/onnx/torchlib/test_ops.py", + ) + raise OrtAbortedError( + "ONNX Runtime aborted:\n" + + _format_model_and_input_information(model_proto, ort_inputs) + ) from e + except Exception as e: + if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": + error_reproduction.create_reproduction_report( + test_name, + model_proto, + ort_inputs, + e, + "test/onnx/torchlib/test_ops.py", + ) + raise + + return _capture_graph_and_evaluate_torch_script_evaluator + + +@contextlib.contextmanager +def normal_xfail_skip_test_behaviors( + test_behavior: Optional[str] = None, reason: Optional[str] = None +): + """This context manager is used to handle the different behaviors of xfail and skip. + + Args: + test_behavior (optional[str]): From DecorateMeta name, can be 'skip', 'xfail', or None. + reason (optional[str]): The reason for the failure or skip. + + Raises: + e: Any exception raised by the test case if it's not an expected failure. + """ + + # We need to skip as soon as possible, as SegFault might also be a case. + if test_behavior == "skip": + pytest.skip(reason=reason) + + try: + yield + # We could use `except (AssertionError, RuntimeError, ...) as e:`, but it needs + # to go over all test cases to find the right exception type. + except Exception: # pylint: disable=broad-exception-caught + if test_behavior is None: + raise + if test_behavior == "xfail": + pytest.xfail(reason=reason) + else: + if test_behavior == "xfail": + pytest.fail("Test unexpectedly passed") diff --git a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py new file mode 100644 index 00000000000000..bbb86d810cab91 --- /dev/null +++ b/test/onnx/torchlib/ops_test_data.py @@ -0,0 +1,685 @@ +# Owner(s): ["module: onnx"] +"""Test op correctness by comparing with PyTorch results. + +## Usage + +1. Set the env var CATCH_ORT_SEGFAULT to catch segfaults from ONNX Runtime. + +## How to add a new operator test + +This test use PyTorch's OpInfo mechanism to generate test cases for each operator. 
+You may find all OpInfos in https://github.com/pytorch/pytorch/blob/7ec0d6f006fdd2c9b978dc6aa4923144684a3f51/torch/testing/_internal/common_methods_invocations.py#L8804
+
+1. To enable test cases for an operator
+    Add a `TorchLibOpInfo` entry to `TESTED_TORCHLIB_OPS` in `ops_test_data.py`.
+    Specify `complex` if the function is designed for complex inputs.
+
+    The `op_info_name` in `TorchLibOpInfo` needs to be unique in the TESTED_TORCHLIB_OPS
+    list, but complex=True ops can share the same name with non-complex ops
+    because they are tested separately.
+
+2. Add `.skip` and/or `.xfail` to skip or xfail tests.
+    Prefer xfail over skip when possible because that allows us to monitor the behavior
+    and update the test when it passes.
+
+    2a. If a test now fails because it unexpectedly passes (xpass) after previous errors
+    were fixed, remove the corresponding xfail.
+
+3. If sample inputs of the OpInfo need to be adjusted to fit the aten signature, create an input
+wrangler function. See `_mean_input_wrangler` for an example.
+
+4. To test different ONNX functions that are registered as overloads of the same
+    op, use `ops_test_common.duplicate_opinfo` to create new OpInfos with new names and map each
+    to one overload.
+"""
+# flake8: noqa
+
+from __future__ import annotations
+
+import copy
+import dataclasses
+import functools
+from typing import Any, Callable, Collection, Optional
+from typing_extensions import Self
+
+import numpy as np
+import ops_test_common
+
+import torch
+from torch.testing._internal import common_methods_invocations
+from torch.testing._internal.opinfo import definitions as opinfo_definitions
+
+
+# Create a copy of the op_db to modify
+OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
+
+# Append extra op_db into the op database for testing
+OPS_DB.extend(opinfo_definitions.signal.op_db)
+
+
+@dataclasses.dataclass
+class TorchLibOpInfo:
+    """A dataclass to store the information to test a torchlib op."""
+
+    # The name of the op_info, e.g. "add"
+    op_info_name: str
+    # The torchlib ONNX Function to test
+    op: Callable[..., Any]
+    # The input wrangler function to adjust the input to fit the aten signature
+    input_wrangler: Optional[
+        Callable[[list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]]
+    ] = None
+    # Whether the op is non-deterministic
+    nondeterministic: bool = False
+    # Whether to compare only the shape for output[index]
+    # For example: (1,2) means compare value for output[0] and shape for output[1] and [2]
+    # We may be able to combine this with the nondeterministic option
+    compare_shape_only_for_output: tuple[int, ...] = ()
+    # Whether the function is designed for complex inputs
+    complex: bool = False
+    # The acceptable tolerance of the inference result difference between PyTorch and ORT.
+    # Format: {dtype: (rtol, atol)}.
+ # For example: {torch.float16: (1e-3, 1e-3)} + tolerance: dict[torch.dtype, tuple[float, float]] = dataclasses.field( + default_factory=dict + ) + # Expected skips or fails for the test and/or subtests + skips_or_fails: list[ops_test_common.DecorateMeta] = dataclasses.field( + default_factory=list + ) + + def get_tolerance(self, dtype: torch.dtype) -> tuple[float | None, float | None]: + """Returns the (rtol, atol) tolerance for the given dtype.""" + if (tolerance := self.tolerance.get(dtype)) is not None: + return tolerance + + # Use the PyTorch default if not specified + # https://pytorch.org/docs/stable/testing.html + return (None, None) + + def skip( + self, + variant_name: str = "", + *, + reason: str, + dtypes: Optional[Collection[torch.dtype]] = None, + device_type: Optional[str] = None, + matcher: Optional[Callable[[Any], Any]] = None, + enabled_if: bool = True, + test_class_name: Optional[str] = None, + ) -> Self: + """Skips an OpInfo test. + + Args: + variant_name: Optional OpInfo variant_test_name. + reason: The reason for skipping. + dtypes: The dtypes to skip. + device_type: Device type. E.g. "cpu", "cuda". + matcher: A function that matches the test sample input. It is used only when + the skip is in the SKIP_XFAIL_SUBTESTS list. + enabled_if: Whether the skip is enabled. + test_class_name: The test class name to apply the skip to. If None, the skip + is applied to all test classes. + """ + self.skips_or_fails.append( + ops_test_common.skip( + self.op_info_name, + variant_name, + reason=reason, + dtypes=dtypes, + device_type=device_type, + matcher=matcher, + enabled_if=enabled_if, + test_class_name=test_class_name, + ) + ) + return self + + def xfail( + self, + variant_name: str = "", + *, + reason: str, + dtypes: Optional[Collection[torch.dtype]] = None, + device_type: Optional[str] = None, + matcher: Optional[Callable[[Any], Any]] = None, + enabled_if: bool = True, + test_class_name: Optional[str] = None, + ) -> Self: + """Expects an OpInfo test to fail. + + Args: + variant_name: Optional OpInfo variant_test_name. + reason: The reason for the failure. + dtypes: The dtypes to expect the failure + device_type: Device type. E.g. "cpu", "cuda".. + matcher: A function that matches the test sample input. It is used only when + the xfail is in the SKIP_XFAIL_SUBTESTS list. + enabled_if: Whether the xfail is enabled. + test_class_name: The test class name to apply the xfail to. If None, the + xfail is applied to all test classes. 
+ """ + self.skips_or_fails.append( + ops_test_common.xfail( + self.op_info_name, + variant_name, + reason=reason, + dtypes=dtypes, + device_type=device_type, + matcher=matcher, + enabled_if=enabled_if, + test_class_name=test_class_name, + ) + ) + return self + + +# Modify this section ########################################################## + + +def _amin_amax_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "dim" not in kwargs: + # Supply an empty dim to match the aten signature + kwargs["dim"] = np.array([], dtype=np.int64) + else: + # Convert dim to a numpy array + kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64).reshape((-1,)) + return args, kwargs + + +def _avg_pool_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "dim" not in kwargs: + if len(args) > 6: + kwargs["divisor_override"] = args.pop(6) + if len(args) > 5: + kwargs["count_include_pad"] = args.pop(5) + if len(args) > 4: + kwargs["ceil_mode"] = args.pop(4) + if len(args) > 3: + padding = args.pop(3) + if isinstance(padding, np.ndarray): + # Cannot using list(padding) here, because the element will be numpy.int64 instead of int + padding = padding.tolist() + kwargs["padding"] = padding + if len(args) > 2: + stride = args.pop(2) + if isinstance(stride, np.ndarray): + stride = stride.tolist() + kwargs["stride"] = stride + kernel_size = args.pop(1) + if isinstance(kernel_size, np.ndarray): + kernel_size = kernel_size.tolist() + kwargs["kernel_size"] = kernel_size + return args, kwargs + + +def _cross_entropy_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "reduction" in kwargs: + reduction_vals = ["none", "mean", "sum"] + value = kwargs["reduction"] + idx = reduction_vals.index(value) + kwargs["reduction"] = idx + return args, kwargs + + +def _dropout_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "training" in kwargs: + kwargs["train"] = kwargs["training"] + kwargs.pop("training") + return args, kwargs + + +def _einsum_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Swap the equation and tensors to revert the special handling in the OpInfo + return [args[1], args[0]], kwargs + + +def _embedding_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + """Remove arguments not present in the aten op signature.""" + kwargs.pop("max_norm", None) + kwargs.pop("norm_type", None) + return args, kwargs + + +def _empty_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + """Remove arguments not present in the aten op signature.""" + kwargs.pop("requires_grad", None) + return args, kwargs + + +def _grid_sample_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Convert string attriute to int as input + inter_mode_options = {"bilinear": 0, "nearest": 1, "bicubic": 2} + padding_mode_options = {"zeros": 0, "border": 1, "reflection": 2} + args.append(inter_mode_options[kwargs["mode"]]) + args.append(padding_mode_options[kwargs["padding_mode"]]) + args.append(kwargs["align_corners"]) + kwargs.clear() + return args, kwargs + + +def _im2col_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Move kernel_size, dilation, padding and stride from args to kwargs + if len(args) == 5: + # 
Handle stride + stride = args.pop() + if isinstance(stride, np.ndarray): # convert stride to list[int] + stride = stride.tolist() + kwargs["stride"] = stride + # Handle padding + padding = args.pop() + if isinstance(padding, np.ndarray): # convert padding to list[int] + padding = padding.tolist() + kwargs["padding"] = padding + # Handle dilation + dilation = args.pop() + if isinstance(dilation, np.ndarray): # convert dilation to list[int] + dilation = dilation.tolist() + kwargs["dilation"] = dilation + # Handle kernel_size + kernel_size = args.pop() + if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int] + kernel_size = kernel_size.tolist() + kwargs["kernel_size"] = kernel_size + + return args, kwargs + + +def _index_put_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args[1] = [np.array(elem) for elem in args[1]] + return args, kwargs + + +def _max_pool_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Remove return_indices argument because this op doesn't accept it + kwargs.pop("return_indices", None) + return args, kwargs + + +def _mean_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Make the dims as tensor + if "dim" in kwargs: + kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64) + return args, kwargs + + +def _mse_loss_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "reduction" in kwargs: + reduction_vals = ["none", "mean", "sum"] # [0,1,2], default=1 + value = kwargs["reduction"] + idx = reduction_vals.index(value) + kwargs["reduction"] = idx + return args, kwargs + + +def _nll_loss_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if "reduction" in kwargs: + # aten_nll_loss can only accept integer argument instead of string + reduction_vals = ["none", "mean", "sum"] + value = kwargs["reduction"] + kwargs["reduction"] = reduction_vals.index(value) + return args, kwargs + + +def _nonzero_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + kwargs.pop("as_tuple", None) + return args, kwargs + + +def _reflection_pad2d_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args.pop(2) # remove 'reflect' arg + return args, kwargs + + +def _replication_pad2d_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args.pop(2) # remove 'replicate' arg + return args, kwargs + + +def _replication_pad3d_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args.pop(2) # remove 'replicate' arg + return args, kwargs + + +def _roll_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if len(args) >= 3: + if isinstance(args[2], np.ndarray): # convert dims to list[int] + # Change dims from args to kwargs to keep tuple/list type + dims = args.pop(2) + kwargs["dims"] = dims.tolist() + elif isinstance(args[2], int): # convert dims to list[int] + dims = args.pop(2) + kwargs["dims"] = [] + kwargs["dims"].append(dims) + if len(args) >= 2: + if isinstance(args[1], np.ndarray): # convert shift to list[int] + shifts = args.pop(1) + kwargs["shifts"] = shifts.tolist() + elif isinstance(args[1], int): + shifts = args.pop(1) + kwargs["shifts"] = [] + kwargs["shifts"].append(shifts) + return args, kwargs 
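+
+
+# Illustration (not executed): for a hypothetical sample whose input tensor is `x`,
+# _roll_input_wrangler([x, 1, 0], {}) returns ([x], {"dims": [0], "shifts": [1]}),
+# i.e. the shift and dim arguments are moved into kwargs as list[int] to match aten::roll.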
+ + +def _scalar_tensor_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + kwargs.pop("requires_grad", None) + return args, kwargs + + +def _scatter_reduce_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Put the string into kwargs, otherwise FullGraph mode could not find get 'reduce' argument + kwargs["reduce"] = args.pop(4) + return args, kwargs + + +def _sum_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + if kwargs.get("dim") is not None: + kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64) + return args, kwargs + + +def _unflatten_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + args[1] = np.array(args[1], dtype=np.int64) + return args, kwargs + + +def _where_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # The aten::where op takes condition, x, y as inputs + # Swap the first two inputs + args[0], args[1] = args[1], args[0] + return args, kwargs + + +# Ops to be tested for numerical consistency between onnx and pytorch +# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py +TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = () + +ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) +ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims")) +ops_test_common.duplicate_opinfo( + OPS_DB, "arange", ("arange_start", "arange_start_step") +) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) +ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) +ops_test_common.duplicate_opinfo( + OPS_DB, + "bitwise_left_shift", + ( + "bitwise_left_shift_int8", + "bitwise_left_shift_int16", + "bitwise_left_shift_int32", + "bitwise_left_shift_int64", + ), +) +ops_test_common.duplicate_opinfo( + OPS_DB, + "bitwise_right_shift", + ( + "bitwise_right_shift_int8", + "bitwise_right_shift_int16", + "bitwise_right_shift_int32", + "bitwise_right_shift_int64", + ), +) +ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) +ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) +ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) +ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) +ops_test_common.duplicate_opinfo( + OPS_DB, + "nn.functional.pad", + ( + "nn.functional.reflection_pad2d", + "nn.functional.replication_pad2d", + "nn.functional.replication_pad3d", + ), +) +ops_test_common.duplicate_opinfo( + OPS_DB, + "nn.functional.scaled_dot_product_attention", + ("nn.functional.scaled_dot_product_attention_bool_mask",), +) 
+ops_test_common.duplicate_opinfo( + OPS_DB, + "nn.functional.celu", + ("nn.functional.celu_type_promoted",), +) +ops_test_common.duplicate_opinfo( + OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) +) +ops_test_common.duplicate_opinfo( + OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",) +) +ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) +ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) +ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) +ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",)) +ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",)) + +# MARK: End edits here + + +# These ops are not deterministic, so we check shape and dtype only +NONDETERMINISTIC_OPS: frozenset[str] = frozenset( + info.op_info_name for info in TESTED_TORCHLIB_OPS if info.nondeterministic +) + +COMPARE_SHAPE_ONLY_OPS: dict[ + str, + set, +] = { + info.op_info_name: set(info.compare_shape_only_for_output) + for info in TESTED_TORCHLIB_OPS +} + +TORCHLIB_OPINFO_MAPPING: dict[ + str, + TorchLibOpInfo, +] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if not info.complex} + +TESTED_OPS = frozenset(TORCHLIB_OPINFO_MAPPING) + +EXPECTED_SKIPS_OR_FAILS: tuple[ops_test_common.DecorateMeta, ...] = tuple( + functools.reduce( + # Flatten the list + lambda a, b: [*a, *b], + [ + [meta for meta in info.skips_or_fails if meta.matcher is None] + for info in TESTED_TORCHLIB_OPS + ], + ) +) + +SKIP_XFAIL_SUBTESTS: tuple[ops_test_common.DecorateMeta, ...] = tuple( + functools.reduce( + # Flatten the list + lambda a, b: [*a, *b], + [ + [meta for meta in info.skips_or_fails if meta.matcher is not None] + for info in TESTED_TORCHLIB_OPS + ], + ) +) + +# MARK: Complex supported functions +COMPLEX_FUNCTION_MAPPING: dict[ + str, + TorchLibOpInfo, +] = {info.op_info_name: info for info in TESTED_TORCHLIB_OPS if info.complex} + + +# Call dir(torch.ops.prims) and compare with entries in OPS_DB to create OpInfo for newly added prims ops +PRIMS_OPS_WITH_OP_INFO = ( + "abs", + "acos", + "acosh", + "add", + "amax", + "amin", + "as_strided", + "as_strided_scatter", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_not", + "bitwise_or", + "bitwise_xor", + "cat", + "ceil", + "clone", + "conj", + "conj_physical", + "cos", + "cosh", + "digamma", + "div", + "empty", + "eq", + "erf", + "erfc", + "exp", + "exp2", + "expm1", + "fill", + "floor", + "fmax", + "fmin", + "fmod", + "full", + "full_like", + "gcd", + "ge", + "gt", + "hypot", + "igamma", + "igammac", + "imag", + "isfinite", + "le", + "lgamma", + "log", + "log10", + "log1p", + "log2", + "lt", + "maximum", + "minimum", + "mul", + "ne", + "neg", + "nextafter", + "normal", + "pow", + "prod", + "real", + "reciprocal", + "remainder", + "reshape", + "round", + "rsqrt", + "scalar_tensor", + "sign", + "signbit", + "sin", + "sinh", + "sqrt", + "squeeze", + "sub", + "sum", + "svd", + "tan", + "tanh", + "transpose", + "trunc", + "uniform", + "where", +) + +for op in PRIMS_OPS_WITH_OP_INFO: + # Duplicate opinfo for prim ops. The new names all start with "prims_". E.g. "abs" -> "prims_abs". 
+ ops_test_common.duplicate_opinfo_for_prims(OPS_DB, op) + +# Duplicate cases where the prims op name is different from the torch op name +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "i0", "bessel_i0") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j0", "bessel_j0") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.bessel_j1", "bessel_j1") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.erfcx", "erfcx") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i0e", "bessel_i0e") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1", "bessel_i1") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.i1e", "bessel_i1e") +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.ndtri", "ndtri") +ops_test_common.duplicate_opinfo_for_prims( + OPS_DB, "special.spherical_bessel_j0", "spherical_bessel_j0" +) +ops_test_common.duplicate_opinfo_for_prims(OPS_DB, "special.zeta", "zeta") + +OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS) +ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) +# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB +assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB" +assert NONDETERMINISTIC_OPS.issubset( + TESTED_OPS +), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS" From b06c470840aedbe0686dd5cab80d6ad3f3da8b5d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Feb 2025 11:50:31 -0800 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- test/onnx/torchlib/test_ops.py | 353 +++++++++++++++++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 test/onnx/torchlib/test_ops.py diff --git a/test/onnx/torchlib/test_ops.py b/test/onnx/torchlib/test_ops.py new file mode 100644 index 00000000000000..0033648d293267 --- /dev/null +++ b/test/onnx/torchlib/test_ops.py @@ -0,0 +1,353 @@ +# Owner(s): ["module: onnx"] +"""Test op correctness by comparing with PyTorch results. + +Usage: + + pytest test_ops.py + + To run tests on a specific operator (e.g. torch.ceil): + + pytest test_ops.py -k ceil + + To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention): + + pytest test_ops.py -k nn_functional_scaled_dot_product_attention + +## Environment variables + +1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults +in onnxruntime by running the inference sessions in a separate process. + +2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of +errors. +""" + +from __future__ import annotations + +import os +from typing import Callable, Optional, Sequence, Tuple, TYPE_CHECKING + +import error_reproduction +import numpy as np + +import onnx +import onnxruntime as ort +import onnxscript +import ops_test_common +import ops_test_data +import parameterized + +import torch +from torch.testing._internal import common_device_type, common_utils +from torch.utils import _pytree as pytree + + +if TYPE_CHECKING: + import unittest + + from torch.testing._internal.opinfo import core as opinfo_core + +# All dtypes will be tested on the generated symbolic functions. +# complex64 will be flattened to float32. 
+TESTED_DTYPES = ( + torch.float16, + torch.float32, + # Uncomment below item when we really need testing it + # torch.bfloat16, + # torch.float64, + torch.bool, + # torch.int8, + # torch.int16, + torch.int32, + torch.int64, + # torch.uint8, +) +# NOTE: torch.complex32 is experimental in torch +COMPLEX_TYPES = (torch.complex64,) + + +def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]: + """Returns all dtypes except the ones specified.""" + return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes) + + +def _should_skip_xfail_test_sample( + op_name: str, sample, dtype: torch.dtype, device_type: str +) -> Tuple[Optional[str], Optional[str]]: + """Returns a reason if a test sample should be skipped.""" + if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS: + return None, None + for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS: + # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small. + if decorator_meta.op_name == op_name: + assert decorator_meta.matcher is not None, "Matcher must be defined" + if not decorator_meta.enabled_if: + # Do not skip the test if the decorator meta is not enabled + continue + if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes: + # Not applicable for this dtype + continue + if ( + decorator_meta.device_type is not None + and decorator_meta.device_type != device_type + ): + # Not applicable for this device_type + continue + if decorator_meta.matcher(sample): + return decorator_meta.test_behavior, decorator_meta.reason + return None, None + + +class TestFunctionValidity(common_utils.TestCase): + @parameterized.parameterized.expand( + [ + (info.op.name, info) + for info in ops_test_data.TESTED_TORCHLIB_OPS + if isinstance(info.op, onnxscript.OnnxFunction) + ] + ) + def test_script_function_passes_checker( + self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo + ): + function_proto = torchlib_op_info.op.to_function_proto() + onnx.checker.check_function(function_proto) # type: ignore[attr-defined] + + +def run_test_output_match( + test_suite: unittest.TestCase, + device: str, + dtype: torch.dtype, + op: opinfo_core.OpInfo, + function_executor: Callable, + tested_op_mapping: dict[ + str, + ops_test_data.TorchLibOpInfo, + ], +): + """Base test method for testing each opset, used by instantiate_device_type_tests. + + Args: + test_suite: The test class instance. + device: The PyTorch device. instantiate_device_type_tests provides this. + dtype: The PyTorch dtype. instantiate_device_type_tests provides this. + op: The OpInfo instance. instantiate_device_type_tests provides this. + function_executor: The function executor. This is a function that takes + a function and its arguments and returns the output of the function. + tested_op_mapping: The mapping of op name to the tested op. + """ + samples = op.sample_inputs( + device, + dtype, + requires_grad=False, + ) + + torchlib_op_info = tested_op_mapping[op.name] + # Obtain the input_wrangler that manipulates the OpInfo inputs + # to match the aten operator signature + # An example is nn.functional.upsample_nearest2d, which has a different signature + # than the aten operator upsample_nearest2d + onnx_function = torchlib_op_info.op + input_wrangler = torchlib_op_info.input_wrangler + if ( + not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema) + and dtype not in COMPLEX_TYPES + ): + test_suite.skipTest( + f"dtype '{dtype}' is not supported by the op '{op.name}'. 
" + f"Type constraints: {onnx_function.op_schema.type_constraints}" + ) + + # Obtain the tolerance for the op + rtol, atol = torchlib_op_info.get_tolerance(dtype) + for i, cpu_sample in enumerate(samples): + inputs = (cpu_sample.input, *cpu_sample.args) + # Provide the repr to subtest because tensors are not serializable in parallel test runs + with test_suite.subTest( + sample_num=i, + inputs=repr( + [ + f"Tensor<{inp.shape}, dtype={inp.dtype}>" + if isinstance(inp, torch.Tensor) + else inp + for inp in inputs + ] + ), + kwargs=repr(cpu_sample.kwargs), + ): + try: + device_type = cpu_sample.args[0].device.type + except (AttributeError, IndexError): + device_type = "cpu" + test_behavior, reason = _should_skip_xfail_test_sample( + op.name, cpu_sample, dtype, device_type + ) + + with ops_test_common.normal_xfail_skip_test_behaviors( + test_behavior, reason + ): + input_onnx = [ + ops_test_common.convert_tensor_to_numpy(x) for x in inputs + ] + kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs) + if input_wrangler: + input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx) + torch_output = op(*inputs, **cpu_sample.kwargs) + + if isinstance(torch_output, torch.Tensor) and torch.is_complex( + torch_output + ): + torch_output = torch.view_as_real(torch_output.resolve_conj()) + + reference_torch_outputs, _ = pytree.tree_flatten(torch_output) + if ( + op.name.startswith("split") + or op.name.startswith("chunk") + or op.name.startswith("unbind") + or op.name + in { + "atleast_1d_Sequence", + "atleast_2d_Sequence", + "atleast_3d_Sequence", + } + ): + # Hack for handling split, chunk and unbind which relies on SplitToSequence op. + # Split returns a Sequence that should be treats as a single + # value. So we wrap it into a tuple. + # TODO(justinchuby): Find a more general solution + reference_torch_outputs = [reference_torch_outputs] + + test_name = test_suite.id() + function_output, model_proto = function_executor( + test_name, reference_torch_outputs + )(onnx_function, input_onnx, kwargs_onnx) + # Finally we re-flatten everything + # TODO: add pytree structure comparison. 
+ flattened_torch_outputs, _ = pytree.tree_flatten(torch_output) + flattened_function_outputs, _ = pytree.tree_flatten(function_output) + + assert flattened_torch_outputs + assert len(flattened_torch_outputs) == len(flattened_function_outputs) + + for j, (torch_output, function_output) in enumerate( + zip(flattened_torch_outputs, flattened_function_outputs) + ): + actual = torch.tensor(function_output) + expected = ( + torch_output + if isinstance(torch_output, torch.Tensor) + else torch.tensor(torch_output) + ) + + if ( + op.name in ops_test_data.NONDETERMINISTIC_OPS + or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name] + ): + # Check shape and dtype only for ops that are known to be + # nondeterministic + test_suite.assertEqual(actual.shape, expected.shape) + test_suite.assertEqual(actual.dtype, expected.dtype) + continue + + # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match + try: + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=True, + check_device=False, + ) + except AssertionError as e: + if ( + os.environ.get("CREATE_REPRODUCTION_REPORT") == "1" + and test_behavior is None + ): + error_reproduction.create_mismatch_report( + test_name, + i, + model_proto, + inputs, + cpu_sample.kwargs, + actual, + expected, + e, + __file__, + ) + if len(flattened_torch_outputs) > 1: + raise AssertionError(f"Output {j} mismatch") from e + raise + + +class TestOutputConsistencyFullGraph(common_utils.TestCase): + """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode. + + This is a parameterized test suite. + """ + + def setUp(self) -> None: + torch.manual_seed(42) + np.random.seed(42) + ort.set_seed(42) + + @ops_test_common.add_decorate_info( + ops_test_data.OPS_DB, + "TestOutputConsistencyFullGraph", + "test_output_match_opinfo_", + skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS, + ) + @common_device_type.ops( # type: ignore[misc] + [ + info + for info in ops_test_data.OPS_DB + if info.name in ops_test_data.TESTED_OPS + ], + allowed_dtypes=TESTED_DTYPES, + ) + def test_output_match_opinfo_( + self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo + ): + # Base test method for testing each op by running the full ONNX graph. 
+ run_test_output_match( + self, + device, + dtype, + op, + ops_test_common.graph_executor, + ops_test_data.TORCHLIB_OPINFO_MAPPING, + ) + + @ops_test_common.add_decorate_info( + ops_test_data.OPS_DB, + "TestOutputConsistencyFullGraph", + "test_complex_output_match_opinfo_", + skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS, + ) + @common_device_type.ops( # type: ignore[misc] + [ + info + for info in ops_test_data.OPS_DB + if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING + ], + allowed_dtypes=COMPLEX_TYPES, + ) + def test_complex_output_match_opinfo_( + self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo + ): + """Base test method for testing each op by running the full ONNX graph.""" + run_test_output_match( + self, + device, + dtype, + op, + ops_test_common.graph_executor, + ops_test_data.COMPLEX_FUNCTION_MAPPING, + ) + + +common_device_type.instantiate_device_type_tests( + TestOutputConsistencyFullGraph, globals(), only_for=["cpu"] +) + +if __name__ == "__main__": + common_utils.run_tests() From fd457ff9095334d216df91ff9bfb3df7a66692e8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 18 Feb 2025 19:44:50 -0800 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- test/onnx/torchlib/test_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/onnx/torchlib/test_ops.py b/test/onnx/torchlib/test_ops.py index 0033648d293267..13ccae6185efb9 100644 --- a/test/onnx/torchlib/test_ops.py +++ b/test/onnx/torchlib/test_ops.py @@ -104,7 +104,8 @@ class TestFunctionValidity(common_utils.TestCase): (info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS if isinstance(info.op, onnxscript.OnnxFunction) - ] + ], + skip_on_empty=True, ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
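
For reference, once torchlib implementations are registered, an entry in `TESTED_TORCHLIB_OPS` (`ops_test_data.py`) would be shaped roughly like the sketch below. This is illustrative only and not part of the patch: `core_ops.aten_add` is a placeholder for an importable torchlib ONNX implementation, and the tolerance, matcher, and reason shown are hypothetical.

```py
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
    TorchLibOpInfo(
        "add",
        core_ops.aten_add,  # placeholder for the registered torchlib implementation
        tolerance={torch.float16: (1e-3, 1e-3)},  # per-dtype (rtol, atol) override
    ).xfail(
        matcher=lambda sample: sample.kwargs.get("alpha", 1) != 1,
        reason="illustrative: alpha != 1 not handled by the hypothetical implementation",
    ),
)
```

Entries whose skips or xfails carry a `matcher` are collected into `SKIP_XFAIL_SUBTESTS` and applied per sample input; those without a matcher go into `EXPECTED_SKIPS_OR_FAILS` and decorate the whole generated test.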