8000 Feat: added opset 21, 22 and 23 · pytorch/pytorch@5145626 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5145626

Browse files
committed
Feat: added opset 21, 22 and 23
1 parent e5e06d9 commit 5145626

File tree

7 files changed

+259
-5
lines changed

7 files changed

+259
-5
lines changed

torch/csrc/jit/passes/onnx/helper.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ Node* createONNXUnsqueeze(
169169
Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
170170
unsqueeze_node->addInput(input);
171171
unsqueeze_node->insertBefore(n_to_insert_before);
172-
if (opset_version >= OPSET_VERSION_13) {
172+
if (opset_version >= OPSET_VERSION_13)
173+
{
173174
// ONNX spec sets `axes` as input for opset >= 13.
174175
Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
175176
unsqueeze_axes->insertBefore(unsqueeze_node);

torch/onnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"symbolic_opset18",
2525
"symbolic_opset19",
2626
"symbolic_opset20",
27+
"symbolic_opset21",
2728
# Enums
2829
"OperatorExportTypes",
2930
"TrainingMode",
@@ -90,6 +91,7 @@
9091
symbolic_opset18,
9192
symbolic_opset19,
9293
symbolic_opset20,
94+
symbolic_opset21,
9395
utils,
9496
)
9597

torch/onnx/_constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
ONNX_BASE_OPSET = 9
66
ONNX_MIN_OPSET = 7
7-
ONNX_MAX_OPSET = 20
8-
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
7+
ONNX_MAX_OPSET = 23
8+
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 23
99
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
1010
ONNX_DEFAULT_OPSET = 17
1111
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

torch/onnx/symbolic_opset21.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# mypy: allow-untyped-defs
2+
"""This file exports ONNX ops for opset 20.
3+
4+
Note [ONNX Operators that are added/updated in opset 20]
5+
6+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
8+
New operators:
9+
AffineGrid
10+
ConstantOfShape
11+
DFT
12+
Gelu
13+
GridSample
14+
ImageDecoder
15+
IsInf
16+
IsNaN
17+
ReduceMax
18+
ReduceMin
19+
RegexFullMatch
20+
StringConcat
21+
StringSplit
22+
"""
23+
24+
import functools
25+
26+
import torch.nn.functional as F
27+
from torch import _C
28+
from torch.onnx import symbolic_helper
29+
from torch.onnx._internal import jit_utils, registration
30+
31+
32+
# EDITING THIS FILE? READ THIS FIRST!
33+
# see Note [Edit Symbolic Files] in symbolic_helper.py
34+
35+
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
36+
37+
38+
def convert_grid_sample_mode(mode_s):
39+
return (
40+
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
41+
)
42+
43+
44+
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=21)
45+
46+
47+
@_onnx_symbolic("aten::grid_sampler")
48+
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
49+
def _grid_sampler(
50+
g: jit_utils.GraphContext,
51+
input: _C.Value,
52+
grid: _C.Value,
53+
mode_enum: int,
54+
padding_mode_enum: int,
55+
align_corners: bool,
56+
):
57+
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
58+
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
59+
mode_s = convert_grid_sample_mode(mode_s)
60+
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
61+
padding_mode_enum # type: ignore[index]
62+
]
63+
return g.op(
64+
"GridSample",
65+
input,
66+
grid,
67+
align_corners_i=int(align_corners),
68+
mode_s=mode_s,
69+
padding_mode_s=padding_mode_s,
70+
)
71+
72+
73+
# @_onnx_symbolic("aten::affine_grid_generator")
74+
# @symbolic_helper.parse_args("v", "v", "b")
75+
# def _affine_grid_generator(
76+
# g: jit_utils.GraphContext,
77+
# theta: _C.Value,
78+
# size: _C.Value,
79+
# align_corners: bool,
80+
# ):
81+
# return g.op(
82+
# "AffineGrid",
83+
# theta,
84+
# size,
85+
# align_corners_i=int(align_corners),
86+
# )
87+
88+
89+
# @_onnx_symbolic("aten::gelu")
90+
# @symbolic_helper.parse_args("v", "s")
91+
# def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
92+
# return g.op("Gelu", self, approximate_s=approximate)

torch/onnx/symbolic_opset22.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# mypy: allow-untyped-defs
2+
"""This file exports ONNX ops for opset 20.
3+
4+
Note [ONNX Operators that are added/updated in opset 20]
5+
6+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
8+
New operators:
9+
"""
10+
11+
import functools
12+
13+
import torch.nn.functional as F
14+
from torch import _C
15+
from torch.onnx import symbolic_helper
16+
from torch.onnx._internal import jit_utils, registration
17+
18+
19+
# EDITING THIS FILE? READ THIS FIRST!
20+
# see Note [Edit Symbolic Files] in symbolic_helper.py
21+
22+
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
23+
24+
25+
def convert_grid_sample_mode(mode_s):
26+
return (
27+
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
28+
)
29+
30+
31+
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=21)
32+
33+
34+
@_onnx_symbolic("aten::grid_sampler")
35+
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
36+
def _grid_sampler(
37+
g: jit_utils.GraphContext,
38+
input: _C.Value,
39+
grid: _C.Value,
40+
mode_enum: int,
41+
padding_mode_enum: int,
42+
align_corners: bool,
43+
):
44+
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
45+
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
46+
mode_s = convert_grid_sample_mode(mode_s)
47+
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
48+
padding_mode_enum # type: ignore[index]
49+
]
50+
return g.op(
51+
"GridSample",
52+
input,
53+
grid,
54+
align_corners_i=int(align_corners),
55+
mode_s=mode_s,
56+
padding_mode_s=padding_mode_s,
57+
)
58+
59+
60+
# @_onnx_symbolic("aten::affine_grid_generator")
61+
# @symbolic_helper.parse_args("v", "v", "b")
62+
# def _affine_grid_generator(
63+
# g: jit_utils.GraphContext,
64+
# theta: _C.Value,
65+
# size: _C.Value,
66+
# align_corners: bool,
67+
# ):
68+
# return g.op(
69+
# "AffineGrid",
70+
# theta,
71+
# size,
72+
# align_corners_i=int(align_corners),
73+
# )
74+
75+
76+
# @_onnx_symbolic("aten::gelu")
77+
# @symbolic_helper.parse_args("v", "s")
78+
# def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
79+
# return g.op("Gelu", self, approximate_s=approximate)

torch/onnx/symbolic_opset23.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# mypy: allow-untyped-defs
2+
"""This file exports ONNX ops for opset 23.
3+
4+
Note [ONNX Operators that are added/updated in opset 23]
5+
6+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
8+
New operators:
9+
Attention
10+
"""
11+
12+
import functools
13+
14+
import torch.nn.functional as F
15+
from torch import _C
16+
from torch.onnx import symbolic_helper
17+
from torch.onnx._internal import jit_utils, registration
18+
19+
20+
# EDITING THIS FILE? READ THIS FIRST!
21+
# see Note [Edit Symbolic Files] in symbolic_helper.py
22+
23+
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
24+
25+
26+
def convert_grid_sample_mode(mode_s):
27+
return (
28+
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
29+
)
30+
31+
32+
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=21)
33+
34+
35+
@_onnx_symbolic("aten::grid_sampler")
36+
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
37+
def _grid_sampler(
38+
g: jit_utils.GraphContext,
39+
input: _C.Value,
40+
grid: _C.Value,
41+
mode_enum: int,
42+
padding_mode_enum: int,
43+
align_corners: bool,
44+
):
45+
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
46+
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
47+
mode_s = convert_grid_sample_mode(mode_s)
48+
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
49+
padding_mode_enum # type: ignore[index]
50+
]
51+
return g.op(
52+
"GridSample",
53+
input,
54+
grid,
55+
align_corners_i=int(align_corners),
56+
mode_s=mode_s,
57+
padding_mode_s=padding_mode_s,
58+
)
59+
60+
61+
# @_onnx_symbolic("aten::affine_grid_generator")
62+
# @symbolic_helper.parse_args("v", "v", "b")
63+
# def _affine_grid_generator(
64+
# g: jit_utils.GraphContext,
65+
# theta: _C.Value,
66+
# size: _C.Value,
67+
# align_corners: bool,
68+
# ):
69+
# return g.op(
70+
# "AffineGrid",
71+
# theta,
72+
# size,
73+
# align_corners_i=int(align_corners),
74+
# )
75+
76+
77+
# @_onnx_symbolic("aten::gelu")
78+
# @symbolic_helper.parse_args("v", "s")
79+
# def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
80+
# return g.op("Gelu", self, approximate_s=approximate)

torch/onnx/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def export(
353353
354354
Models exported this way are probably runnable only by Caffe2.
355355
356-
opset_version (int, default 17): The version of the
356+
opset_version (int, default 21): The version of the
357357
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
358358
to target. Must be >= 7 and <= 17.
359359
do_constant_folding: Apply the constant-folding optimization.
@@ -1393,7 +1393,7 @@ def _export(
13931393
if opset_version is None:
13941394
opset_version = _constants.ONNX_DEFAULT_OPSET
13951395

1396-
# torch.onnx.export does not support opset versions >=18
1396+
# torch.onnx.export does not support opset versions >=21
13971397
if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
13981398
# We do not want to fail because we should still allow users to create
13991399
# custom symbolic functions for opset>17

0 commit comments

Comments
 (0)
0