8000 Feat: added opset 21, 22 and 23 · pytorch/pytorch@6a7d1c5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a7d1c5

Browse files
committed
Feat: added opset 21, 22 and 23
1 parent d87fb8e commit 6a7d1c5

File tree

3 files changed

+137
-81
lines changed

3 files changed

+137
-81
lines changed

torch/onnx/symbolic_opset21.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# mypy: allow-untyped-defs
2-
"""This file exports ONNX ops for opset 20.
2+
"""This file exports ONNX ops for opset 21.
33
4-
Note [ONNX Operators that are added/updated in opset 20]
4+
Note [ONNX Operators that are added/updated in opset 21]
55
66
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7-
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set
7+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-21-of-the-default-onnx-operator-set
88
New operators:
9-
AffineGrid
9+
- Cosh
10+
- Erf
11+
- GridSample (updated)
12+
- Trunc
1013
"""
1114

1215
import functools
@@ -16,19 +19,23 @@
1619
from torch.onnx import symbolic_helper
1720
from torch.onnx._internal import jit_utils, registration
1821

19-
2022
# EDITING THIS FILE? READ THIS FIRST!
2123
# see Note [Edit Symbolic Files] in symbolic_helper.py
2224

23-
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
24-
25+
__all__ = [
26+
"_grid_sampler",
27+
"_affine_grid_generator",
28+
"gelu",
29+
"cosh",
30+
"erf",
31+
"trunc"
32+
]
2533

2634
def convert_grid_sample_mode(mode_s):
2735
return (
2836
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
2937
)
3038

31-
3239
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=21)
3340

3441

@@ -42,11 +49,12 @@ def _grid_sampler(
4249
padding_mode_enum: int,
4350
align_corners: bool,
4451
):
45-
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
46-
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
52+
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[
53+
mode_enum
54+
]
4755
mode_s = convert_grid_sample_mode(mode_s)
48-
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
49-
padding_mode_enum # type: ignore[index]
56+
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[
57+
padding_mode_enum
5058
]
5159
return g.op(
5260
"GridSample",
@@ -58,3 +66,41 @@ def _grid_sampler(
5866
)
5967

6068

69+
@_onnx_symbolic("aten::affine_grid_generator")
70+
@symbolic_helper.parse_args("v", "v", "b")
71+
def _affine_grid_generator(
72+
g: jit_utils.GraphContext,
73+
theta: _C.Value,
74+
size: _C.Value,
75+
align_corners: bool,
76+
):
77+
return g.op(
78+
"AffineGrid",
79+
theta,
80+
size,
81+
align_corners_i=int(align_corners),
82+
)
83+
84+
85+
@_onnx_symbolic("aten::gelu")
86+
@symbolic_helper.parse_args("v", "s")
87+
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
88+
return g.op("Gelu", self, approximate_s=approximate)
89+
90+
91+
@_onnx_symbolic("aten::cosh")
92+
@symbolic_helper.parse_args("v")
93+
def cosh(g: jit_utils.GraphContext, self: _C.Value):
94+
return g.op("Cosh", self)
95+
96+
97+
@_onnx_symbolic("aten::erf")
98+
@symbolic_helper.parse_args("v")
99+
def erf(g: jit_utils.GraphContext, self: _C.Value):
100+
return g.op("Erf", self)
101+
102+
103+
@_onnx_symbolic("aten::trunc")
104+
@symbolic_helper.parse_args("v")
105+
def trunc(g: jit_utils.GraphContext, self: _C.Value):
106+
return g.op("Trunc", self)

torch/onnx/symbolic_opset22.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,54 +6,69 @@
66
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-22-of-the-default-onnx-operator-set
88
New operators:
9+
- DFT
10+
- IDFT
11+
- HammingWindow
12+
- HannWindow
13+
- BlackmanWindow
914
"""
1015

1116
import functools
12-
13-
import torch.nn.functional as F
17+
import torch
1418
from torch import _C
1519
from torch.onnx import symbolic_helper
1620
from torch.onnx._internal import jit_utils, registration
1721

22+
__all__ = [
23+
"dft",
24+
"idft",
25+
"hamming_window",
26+
"hann_window",
27+
"blackman_window",
28+
]
1829

19-
# EDITING THIS FILE? READ THIS FIRST!
20-
# see Note [Edit Symbolic Files] in symbolic_helper.py
30+
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=22)
2131

22-
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
2332

33+
@_onnx_symbolic("aten::fft_fft")
34+
@symbolic_helper.parse_args("v", "i", "s", "i")
35+
def dft(g, input, dim, norm, lastdim):
36+
return g.op("DFT", input, axis_i=dim, inverse_i=0, norm_s=norm)
2437

25-
def convert_grid_sample_mode(mode_s):
26-
return (
27-
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
28-
)
2938

39+
@_onnx_symbolic("aten::fft_ifft")
40+
@symbolic_helper.parse_args("v", "i", "s", "i")
41+
def idft(g, input, dim, norm, lastdim):
42+
return g.op("DFT", input, axis_i=dim, inverse_i=1, norm_s=norm)
3043

31-
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=22)
44+
45+
@_onnx_symbolic("aten::hamming_window")
46+
@symbolic_helper.parse_args("i", "b", "f", "i")
47+
def hamming_window(g, window_length, periodic, alpha, dtype):
48+
return g.op(
49+
"HammingWindow",
50+
g.op("Constant", value_t=torch.tensor(window_length)),
51+
alpha_f=alpha,
52+
periodic_i=int(periodic)
53+
)
3254

3355

34-
@_onnx_symbolic("aten::grid_sampler")
35-
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
36-
def _grid_sampler(
37-
g: jit_utils.GraphContext,
38-
input: _C.Value,
39-
grid: _C.Value,
40-
mode_enum: int,
41-
padding_mode_enum: int,
42-
align_corners: bool,
43-
):
44-
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
45-
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
46-
mode_s = convert_grid_sample_mode(mode_s)
47-
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
48-
padding_mode_enum # type: ignore[index]
49-
]
56+
@_onnx_symbolic("aten::hann_window")
57+
@symbolic_helper.parse_args("i", "b", "i")
58+
def hann_window(g, window_length, periodic, dtype):
5059
return g.op(
51-
"GridSample",
52-
input,
53-
grid,
54-
align_corners_i=int(align_corners),
55-
mode_s=mode_s,
56-
padding_mode_s=padding_mode_s,
60+
"HannWindow",
61+
g.op("Constant", value_t=torch.tensor(window_length)),
62+
periodic_i=int(periodic)
5763
)
5864

5965

66+
@_onnx_symbolic("aten::blackman_window")
67+
@symbolic_helper.parse_args("i", "b", "f", "i")
68+
def blackman_window(g, window_length, periodic, beta, dtype):
69+
return g.op(
70+
"BlackmanWindow",
71+
g.op("Constant", value_t=torch.tensor(window_length)),
72+
beta_f=beta,
73+
periodic_i=int(periodic)
74+
)

torch/onnx/symbolic_opset23.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,54 +6,49 @@
66
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-23-of-the-default-onnx-operator-set
88
New operators:
9-
Attention
9+
- Attention
1010
"""
1111

1212
import functools
13-
14-
import torch.nn.functional as F
1513
from torch import _C
1614
from torch.onnx import symbolic_helper
1715
from torch.onnx._internal import jit_utils, registration
1816

19-
20-
# EDITING THIS FILE? READ THIS FIRST!
21-
# see Note [Edit Symbolic Files] in symbolic_helper.py
22-
23-
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
24-
25-
26-
def convert_grid_sample_mode(mode_s):
27-
return (
28-
"linear" if mode_s == "bilinear" else "cubic" if mode_s == "bicubic" else mode_s
29-
)
30-
17+
__all__ = ["attention"]
3118

3219
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=23)
3320

34-
35-
@_onnx_symbolic("aten 1CF5 ::grid_sampler")
36-
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
37-
def _grid_sampler(
21+
@_onnx_symbolic("aten::attention")
22+
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "v", "v", "v", "v", "v", "v", "v", "v")
23+
def attention(
3824
g: jit_utils.GraphContext,
39-
input: _C.Value,
40-
grid: _C.Value,
41-
mode_enum: int,
42-
padding_mode_enum: int,
43-
align_corners: bool,
25+
query: _C.Value,
26+
key: _C.Value,
27+
value: _C.Value,
28+
bias: _C.Value,
29+
mask_index: _C.Value,
30+
past_key: _C.Value,
31+
past_value: _C.Value,
32+
static_kv: _C.Value,
33+
use_past: _C.Value,
34+
unidirectional: _C.Value,
35+
num_heads: _C.Value,
36+
scale: _C.Value,
37+
dropout: _C.Value,
4438
):
45-
mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg, index]
46-
# mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
47-
mode_s = convert_grid_sample_mode(mode_s)
48-
padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[ # type: ignore[call-arg, index]
49-
padding_mode_enum # type: ignore[index]
50-
]
5139
return g.op(
52-
"GridSample",
53-
input,
54-
grid,
55-
align_corners_i=int(align_corners),
56-
mode_s=mode_s,
57-
padding_mode_s=padding_mode_s,
40+
"Attention",
41+
query,
42+
key,
43+
value,
44+
bias,
45+
mask_index,
46+
past_key,
47+
past_value,
48+
static_kv,
49+
use_past,
50+
unidirectional,
51+
num_heads,
52+
scale,
53+
dropout,
5854
)
59-

0 commit comments

Comments
 (0)
0