8000 Updated ONNX Opset Version to Support Attention Operator #153611 by hamzaqureshi5 · Pull Request #153687 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Updated ONNX Opset Version to Support Attention Operator #153611 #153687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension 8000

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/googletest
Submodule googletest updated 61 files
+13 −27 BUILD.bazel
+2 −2 CMakeLists.txt
+21 −28 MODULE.bazel
+12 −3 README.md
+10 −39 WORKSPACE
+29 −47 ci/linux-presubmit.sh
+4 −7 ci/macos-presubmit.sh
+13 −25 ci/windows-presubmit.bat
+50 −70 docs/advanced.md
+13 −0 docs/faq.md
+20 −53 docs/gmock_cook_book.md
+2 −2 docs/primer.md
+30 −23 docs/quickstart-bazel.md
+1 −2 docs/reference/actions.md
+2 −9 docs/reference/assertions.md
+4 −7 docs/reference/matchers.md
+1 −98 docs/reference/testing.md
+4 −32 fake_fuchsia_sdk.bzl
+18 −101 googlemock/include/gmock/gmock-actions.h
+76 −343 googlemock/include/gmock/gmock-matchers.h
+8 −8 googlemock/include/gmock/gmock-more-actions.h
+6 −4 googlemock/include/gmock/gmock-spec-builders.h
+5 −0 googlemock/include/gmock/internal/gmock-internal-utils.h
+0 −1 googlemock/include/gmock/internal/gmock-port.h
+4 −4 googlemock/src/gmock-cardinalities.cc
+7 −71 googlemock/test/gmock-actions_test.cc
+2 −2 googlemock/test/gmock-function-mocker_test.cc
+27 −210 googlemock/test/gmock-matchers-arithmetic_test.cc
+5 −131 googlemock/test/gmock-matchers-comparisons_test.cc
+19 −317 googlemock/test/gmock-matchers-containers_test.cc
+23 −82 googlemock/test/gmock-matchers-misc_test.cc
+10 −39 googlemock/test/gmock-more-actions_test.cc
+1 −1 googlemock/test/gmock-pp_test.cc
+4 −3 googlemock/test/gmock-spec-builders_test.cc
+1 −1 googlemock/test/gmock_link_test.h
+3 −3 googletest/README.md
+1 −1 googletest/cmake/internal_utils.cmake
+0 −7 googletest/include/gtest/gtest-assertion-result.h
+3 −3 googletest/include/gtest/gtest-matchers.h
+39 −95 googletest/include/gtest/gtest-param-test.h
+0 −39 googletest/include/gtest/gtest-printers.h
+65 −61 googletest/include/gtest/gtest-typed-test.h
+2 −2 googletest/include/gtest/gtest.h
+44 −40 googletest/include/gtest/internal/gtest-internal.h
+14 −48 googletest/include/gtest/internal/gtest-param-util.h
+66 −16 googletest/include/gtest/internal/gtest-port.h
+0 −4 googletest/src/gtest-internal-inl.h
+34 −157 googletest/src/gtest.cc
+17 −51 googletest/test/BUILD.bazel
+0 −38 googletest/test/googletest-fail-if-no-test-linked-test-with-disabled-test_.cc
+0 −38 googletest/test/googletest-fail-if-no-test-linked-test-with-enabled-test_.cc
+0 −169 googletest/test/googletest-fail-if-no-test-linked-test.py
+0 −19 googletest/test/googletest-filter-unittest.py
+16 −95 googletest/test/googletest-json-output-unittest.py
+0 −70 googletest/test/googletest-param-test-test.cc
+1 −31 googletest/test/googletest-printers-test.cc
+2 −2 googletest/test/googletest-setuptestsuite-test_.cc
+4 −52 googletest/test/gtest_unittest.cc
+26 −67 googletest/test/gtest_xml_output_unittest.py
+1 −21 googletest/test/gtest_xml_output_unittest_.cc
+7 −7 googletest_deps.bzl
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/onnx/helper.cpp
10000
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ Node* createONNXUnsqueeze(
Node* unsqueeze_node = graph->create(onnx::Unsqueeze, 1);
unsqueeze_node->addInput(input);
unsqueeze_node->insertBefore(n_to_insert_before);
if (opset_version >= OPSET_VERSION_13) {
if (opset_version >= OPSET_VERSION_13)
{
// ONNX spec sets `axes` as input for opset >= 13.
Node* unsqueeze_axes = graph->create(onnx::Constant, 1);
unsqueeze_axes->insertBefore(unsqueeze_node);
Expand Down
4 changes: 4 additions & 0 deletions torch/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"symbolic_opset18",
"symbolic_opset19",
"symbolic_opset20",
"symbolic_opset21",
# Enums
"OperatorExportTypes",
"TrainingMode",
Expand Down Expand Up @@ -90,6 +91,9 @@
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
symbolic_opset21,
symbolic_opset22,
symbolic_opset23,
utils,
)

Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/_constants.py
F803
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 20
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
ONNX_MAX_OPSET = 23
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 23
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 17
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
Expand Down
37 changes: 37 additions & 0 deletions torch/onnx/symbolic_opset21.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 21.

Note [ONNX Operators that are added/updated in opset 21]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-21-of-the-default-onnx-operator-set
Operators registered here:
    - Gelu (introduced in ONNX opset 20; re-exported for opset 21 graphs)
"""

import functools

import torch.nn.functional as F
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = [
"gelu",
]

def convert_grid_sample_mode(mode_s):
    """Map a PyTorch ``grid_sample`` mode name to its ONNX ``GridSample`` name.

    "bilinear" -> "linear", "bicubic" -> "cubic"; every other mode string
    (e.g. "nearest") is returned unchanged.

    NOTE(review): nothing in this module references this helper and it is not
    listed in ``__all__`` — it appears copied from symbolic_opset20; confirm
    it is still needed here.
    """
    translation = {"bilinear": "linear", "bicubic": "cubic"}
    return translation.get(mode_s, mode_s)

# Registration decorator pre-bound to opset 21 for every symbolic in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=21)


@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
    """Export aten::gelu as the ONNX ``Gelu`` operator.

    The ``approximate`` string ("none" or "tanh" on the aten side) is
    forwarded verbatim as the ONNX ``approximate`` attribute, whose accepted
    values are identical.
    """
    gelu_attrs = {"approximate_s": approximate}
    return g.op("Gelu", self, **gelu_attrs)
31 changes: 31 additions & 0 deletions torch/onnx/symbolic_opset22.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 22.

Note [ONNX Operators that are added/updated in opset 22]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-22-of-the-default-onnx-operator-set
Updated operators:
    Selu (opset 22 extends its supported dtypes, e.g. bfloat16)
"""

import functools
import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration

__all__ = [
"selu",
]

# Registration decorator pre-bound to opset 22 for every symbolic in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=22)


@_onnx_symbolic("aten::selu")
@symbolic_helper.parse_args("v")
def selu(g: jit_utils.GraphContext, self: _C.Value):
    """Export aten::selu as the ONNX ``Selu`` operator.

    alpha/gamma are pinned to the full-precision SELU constants used by
    ``torch.nn.functional.selu`` (and matched by the ONNX Selu defaults).
    The previously truncated literals (1.6732632 / 1.0507) introduced up to
    ~1e-6 relative error versus PyTorch's eager results.
    """
    alpha = 1.6732632423543772
    gamma = 1.0507009873554805
    return g.op("Selu", self, alpha_f=alpha, gamma_f=gamma)

93 changes: 93 additions & 0 deletions torch/onnx/symbolic_opset23.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# mypy: allow-untyped-defs
"""This file exports ONNX ops for opset 23.

Note [ONNX Operators that are added/updated in opset 23]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-23-of-the-default-onnx-operator-set
New operators:
    Attention
    RMSNormalization
    RotaryEmbedding
Updated operators:
    Reshape
"""

import functools
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration

__all__ = ["attention", "rms_normalization", "reshape", "rotary_embedding"]

# Registration decorator pre-bound to opset 23 for every symbolic in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=23)


def _is_none_value(value: _C.Value) -> bool:
    # An omitted optional aten input arrives as a prim::Constant of NoneType.
    return (
        value.node().kind() == "prim::Constant"
        and value.type().kind() == "NoneType"
    )


@_onnx_symbolic("aten::attention")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "v", "i", "i", "i", "f", "f", "i")
def attention(
    g: jit_utils.GraphContext,
    q: _C.Value,
    k: _C.Value,
    v: _C.Value,
    attn_mask: _C.Value,
    past_key: _C.Value,
    past_value: _C.Value,
    q_num_heads: int,
    kv_num_heads: int,
    qk_matmul_output_mode: int,
    scale: float,
    softcap: float,
    softmax_precision: int,
):
    """Export attention as the ONNX ``Attention`` operator (new in opset 23).

    All four ONNX outputs are requested (Y, present_key, present_value,
    qk_matmul_output).

    NOTE(review): core PyTorch has no ``aten::attention`` op — confirm the
    intended source op (e.g. aten::scaled_dot_product_attention) and its
    schema. Also confirm ``scale``/``softcap`` are always concrete floats at
    the call site; a None would fail attribute conversion here.
    """
    optional_inputs = [attn_mask, past_key, past_value]
    # ONNX only permits omitting *trailing* optional inputs. Drop Nones from
    # the end; any None that precedes a present optional becomes an empty
    # prim::Constant placeholder (exported as the "" input) so the remaining
    # inputs keep their positions. The previous code appended only the
    # present inputs, which could shift past_key/past_value into the
    # attn_mask slot when attn_mask was None.
    while optional_inputs and _is_none_value(optional_inputs[-1]):
        optional_inputs.pop()

    inputs = [q, k, v]
    for value in optional_inputs:
        if _is_none_value(value):
            placeholder = g.op("prim::Constant")
            placeholder.setType(_C.OptionalType.ofTensor())
            inputs.append(placeholder)
        else:
            inputs.append(value)

    return g.op(
        "Attention",
        *inputs,
        q_num_heads_i=q_num_heads,
        kv_num_heads_i=kv_num_heads,
        qk_matmul_output_mode_i=qk_matmul_output_mode,
        scale_f=scale,
        softcap_f=softcap,
        softmax_precision_i=softmax_precision,
        outputs=4,
    )


@_onnx_symbolic("aten::rms_norm")
@symbolic_helper.parse_args("v", "v", "i", "f")
def rms_normalization(g, input, scale, axis, epsilon):
squared = g.op("Mul", input, input)
rank = symbolic_helper._get_tensor_rank(input)
axes = list(range(axis if axis >= 0 else rank + axis, rank)) # type: ignore
mean = g.op("ReduceMean", squared, axes_i=axes, keepdims_i=1)
mean_eps = g.op("Add", mean, g.op("Constant", value_t=symbolic_helper._scalar(epsilon)))
rms = g.op("Sqrt", mean_eps)
normalized = g.op("Div", input, rms)
return g.op("Mul", normalized, scale)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.parse_args("v", "v")
def reshape(g, input, shape):
return g.op("Reshape", input, shape)


@_onnx_symbolic("aten::rotary_embedding")
@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "i")
def rotary_embedding(g, input, position_ids, sin_cache, cos_cache, interleaved, rotary_embedding_dim, num_heads):
return g.op(
"RotaryEmbedding",
input,
position_ids,
sin_cache,
cos_cache,
interleaved_i=interleaved,
rotary_embedding_dim_i=rotary_embedding_dim,
num_heads_i=num_heads
)
4 changes: 2 additions & 2 deletions torch/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def export(

Models exported this way are probably runnable only by Caffe2.

opset_version (int, default 17): The version of the
opset_version (int, default 17): The version of the
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
to target. Must be >= 7 and <= 23.
do_constant_folding: Apply the constant-folding optimization.
Expand Down Expand Up @@ -1393,7 +1393,7 @@ def _export(
if opset_version is None:
opset_version = _constants.ONNX_DEFAULT_OPSET

# torch.onnx.export does not support opset versions >=18
# torch.onnx.export does not support opset versions >= 24
if opset_version > _constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET:
# We do not want to fail because we should still allow users to create
# custom symbolic functions for opset > 23
Expand Down
Loading
0