[ONNX] Create scaffolding for torchlib ops #147401


Closed · wants to merge 1 commit
Update
[ghstack-poisoned]
justinchuby committed Feb 18, 2025
commit ef3975c98a7a3b42735ce18b5c2280de410ab6af
14 changes: 10 additions & 4 deletions test/onnx/torchlib/ops_test_data.py
@@ -46,6 +46,7 @@
import ops_test_common

import torch
+from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
from torch.testing._internal import common_methods_invocations
from torch.testing._internal.opinfo import definitions as opinfo_definitions

@@ -441,7 +442,12 @@ def _where_input_wrangler(

# Ops to be tested for numerical consistency between onnx and pytorch
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
-TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = ()
+TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
+    TorchLibOpInfo("abs", core_ops.aten_abs),
+    TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
+    TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
+    TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
+)

ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
@@ -680,6 +686,6 @@ def _where_input_wrangler(
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
-assert NONDETERMINISTIC_OPS.issubset(
-    TESTED_OPS
-), f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"
+assert NONDETERMINISTIC_OPS.issubset(TESTED_OPS), (
+    f"{NONDETERMINISTIC_OPS - TESTED_OPS} not in TESTED_OPS"
+)
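
For context, each TorchLibOpInfo entry pairs an OpInfo name from torch/testing/_internal/common_methods_invocations.py with the torchlib function under test; complex=True selects complex-valued sample inputs, and tolerance overrides the per-dtype comparison bounds. A minimal sketch of what registering one more op could look like (the "mul" entries and core_ops.aten_mul / core_ops.aten_mul_complex are hypothetical, not part of this PR):

# Hypothetical sketch only: core_ops.aten_mul and core_ops.aten_mul_complex
# are assumed here and are not added by this PR. Each name must match an
# OpInfo in OPS_DB; the asserts at the bottom of the module enforce that.
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
    TorchLibOpInfo("abs", core_ops.aten_abs),
    TorchLibOpInfo("mul", core_ops.aten_mul),
    TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
)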
78 changes: 78 additions & 0 deletions torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py
@@ -0,0 +1,78 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""Typings for function definitions."""

from __future__ import annotations

from typing import TypeVar, Union

from onnxscript import (
    BFLOAT16,
    BOOL,
    COMPLEX128,
    COMPLEX64,
    DOUBLE,
    FLOAT,
    FLOAT16,
    INT16,
    INT32,
    INT64,
    INT8,
    STRING,
    UINT8,
)


# NOTE: We do not care about unsigned types beyond UINT8 because PyTorch does not use them.
# More detail can be found at https://pytorch.org/docs/stable/tensors.html

_TensorType = Union[
    BFLOAT16,
    BOOL,
    COMPLEX64,
    COMPLEX128,
    DOUBLE,
    FLOAT,
    FLOAT16,
    INT8,
    INT16,
    INT32,
    INT64,
    UINT8,
]
_FloatType = Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16]
IntType = Union[INT8, INT16, INT32, INT64]
RealType = Union[
    BFLOAT16,
    FLOAT16,
    FLOAT,
    DOUBLE,
    INT8,
    INT16,
    INT32,
    INT64,
]

TTensor = TypeVar("TTensor", bound=_TensorType)
# Duplicate TTensor for inputs/outputs that accept the same set of types as TTensor
# but do not constrain the type to be the same as the other inputs/outputs
TTensor2 = TypeVar("TTensor2", bound=_TensorType)
TTensorOrString = TypeVar("TTensorOrString", bound=Union[_TensorType, STRING])
TFloat = TypeVar("TFloat", bound=_FloatType)
TFloatOrUInt8 = TypeVar(
    "TFloatOrUInt8", bound=Union[FLOAT, FLOAT16, DOUBLE, INT8, UINT8]
)
TInt = TypeVar("TInt", bound=IntType)
TReal = TypeVar("TReal", bound=RealType)
TRealUnlessInt16OrInt8 = TypeVar(
    "TRealUnlessInt16OrInt8",
    bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64],
)
TRealUnlessFloat16OrInt8 = TypeVar(
    "TRealUnlessFloat16OrInt8", bound=Union[DOUBLE, FLOAT, INT16, INT32, INT64]
)
TRealOrUInt8 = TypeVar("TRealOrUInt8", bound=Union[RealType, UINT8])
TFloatHighPrecision = TypeVar("TFloatHighPrecision", bound=Union[FLOAT, DOUBLE])
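
To illustrate how these aliases are consumed (a sketch, not code from this PR): binding a function's input and output to the same TypeVar tells the type checker that they share one ONNX element type, while TTensor2 exists for a second tensor whose type may differ. The aten_relu below is hypothetical:

# Illustrative sketch only; aten_relu is hypothetical and not part of this PR.
from onnxscript.onnx_opset import opset18 as op

from torch.onnx._internal.exporter._torchlib._tensor_typing import TFloat


def aten_relu(self: TFloat) -> TFloat:
    """relu(Tensor self) -> Tensor"""
    # Input and output are constrained to the same floating-point type.
    return op.Relu(self)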
47 changes: 47 additions & 0 deletions torch/onnx/_internal/exporter/_torchlib/ops/core.py
@@ -0,0 +1,47 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa

from __future__ import annotations

import operator

from onnxscript.onnx_opset import opset18 as op

import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal, TRealOrUInt8
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl


aten = torch.ops.aten


@onnx_impl((aten.abs.default, operator.abs), trace_only=True)
def aten_abs(self: TRealOrUInt8) -> TRealOrUInt8:
    """abs(Tensor self) -> Tensor"""

    return op.Abs(self)


@onnx_impl(aten.abs.default, complex=True, trace_only=True)
def aten_abs_complex(self: TRealOrUInt8) -> TRealOrUInt8:
    """abs(Tensor self) -> Tensor"""

    # Complex inputs reach the exporter as real tensors with a trailing
    # dimension of size 2 holding (real, imaginary), so the magnitude is
    # the L2 norm over that last axis.
    return op.ReduceL2(self, [-1], keepdims=False)


@onnx_impl((aten.add.Tensor, aten.add.Scalar, operator.add), trace_only=True)
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
    if alpha != 1.0:
        # Scale `other` by alpha before adding; CastLike matches alpha's
        # dtype to `other` so the Mul is well-typed.
        alpha = op.CastLike(alpha, other)
        other = op.Mul(other, alpha)
    return op.Add(self, other)


@onnx_impl((aten.add.Tensor, aten.add.Scalar), trace_only=True, complex=True)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
    """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

    # Complex addition is elementwise on the (real, imaginary) representation,
    # so the real-valued implementation applies unchanged.
    return aten_add(self, other, alpha=alpha)
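
As a quick sanity check on the semantics aten_add mirrors (a sketch assuming only a standard PyTorch install): torch.add scales `other` by `alpha` before adding, which is exactly the CastLike, Mul, Add sequence emitted above when alpha != 1.0.

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])

# torch.add(a, b, alpha=2) computes a + 2 * b; the exported ONNX graph
# reproduces this by multiplying `other` by alpha before the Add.
assert torch.equal(torch.add(a, b, alpha=2.0), a + 2.0 * b)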