[ONNX] Migrate torchlib into PyTorch · pytorch/pytorch@d3904c9 · GitHub

Commit d3904c9
[ONNX] Migrate torchlib into PyTorch
1 parent e6c39d3 commit d3904c9

23 files changed: +20654 -140 lines changed

test/onnx/torchlib/README.md

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
# Test op correctness by comparing with PyTorch results using OpInfo

`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
Read more about them at https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362.

## Usage

```bash
# Run all tests
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py

# Run tests on a specific operator (e.g. torch.ceil):
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil

# Run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention
```

### Environment variables

1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults in onnxruntime by running the inference sessions in a separate process (see the example after this list).
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproducing errors, e.g.

   ```bash
   CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k div_mode_int
   ```
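
For example, to run a single test while catching onnxruntime segmentation faults in a separate process (item 1 above, paired with the same invocation used throughout this document):

```bash
CATCH_ORT_SEGFAULT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k div_mode_int
```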

## How to add a new operator test

See _usage_ in [`ops_test_data.py`](./ops_test_data.py).
## How to add custom OpInfo tests

Sometimes there is no existing OpInfo that fits our need for testing an operator. In that case, create a custom OpInfo for it.

Follow the steps below to create new OpInfo tests:

1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py):

   ```py
   opinfo_core.OpInfo(
       "ops.aten.slice_scatter",
       aten_name="slice_scatter",
       dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
       sample_inputs_func=sample_inputs_slice_scatter,
       supports_out=False,
   ),
   ```
   - The first argument should be the operator name under the `torch.ops` namespace. For example, to test the `prims.var` op, put `"ops.prims.var"`. It should almost always start with `ops.`.
   - Follow existing examples to specify the `dtypes` you want to test the op on.
   - Specify `op=` if the target operator is not the same as the OpInfo name (the first argument). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068:

     ```py
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p_deterministic",
         op=torch.ops.aten.bernoulli.p,
     ```

     Here the op is `torch.ops.aten.bernoulli.p`, which differs from the name `ops.aten.bernoulli.p_deterministic`; OpInfo names must be globally unique within a test suite. When `op` is not specified, the op is looked up under `torch.` by its name.
2. Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268)
   1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a `torch.Tensor`, or use `torch.tensor` to construct the tensor yourself; be sure to double-check the dtype and device. Finally, yield each test case with

      ```py
      yield opinfo_core.SampleInput(input, args=(...), kwargs={...})
      ```

      `input` is the first argument; the remaining positional arguments go in `args`. (A sketch of a complete `sample_inputs_func` appears after this list.)
3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py):
   1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example, https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116)

      ```py
      TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
      ```

      You can additionally specify a dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590); an illustrative sketch appears at the end of this section.
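
As a reference for step 2, here is a minimal sketch of what a complete `sample_inputs_func` can look like. The shapes and kwargs below are illustrative assumptions, not the actual `sample_inputs_slice_scatter` implementation; see the linked `extra_opinfo.py` lines for the real one.

```py
import functools

from torch.testing import make_tensor
from torch.testing._internal.opinfo import core as opinfo_core


def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
    # Bind the common tensor-construction options once.
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # (input_shape, src_shape, kwargs) triples -- illustrative values only.
    cases = (
        ((10,), (3,), {"dim": 0, "start": 2, "end": 5}),
        ((4, 8), (4, 2), {"dim": 1, "start": 0, "end": 4, "step": 2}),
    )
    for input_shape, src_shape, op_kwargs in cases:
        # `input` is the first argument; the remaining positional args go in `args`.
        yield opinfo_core.SampleInput(
            make_arg(input_shape), args=(make_arg(src_shape),), kwargs=op_kwargs
        )
```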

Now that the test is added, you can run it as described above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and to see the failing input combinations should any test case fail.
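
Finally, for the dtype tolerance and conditional skips mentioned in step 3, a hypothetical entry might look like the following. The `tolerance` keyword and the `.skip(...)` arguments are assumptions about the `TorchLibOpInfo` API made for illustration; consult the linked `ops_test_data.py` lines for the exact interface.

```py
TorchLibOpInfo(
    "ops.aten.slice_scatter",
    core_ops.aten_slice_scatter,
    # Assumed per-dtype (rtol, atol) overrides:
    tolerance={torch.float16: (1e-3, 1e-3)},
).skip(
    # Assumed conditional-skip arguments:
    dtypes=(torch.bool,),
    reason="bool inputs are not supported by this implementation",
)
```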
Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import difflib
import pathlib
import platform
import sys
import time
import traceback
from typing import Any, Mapping

import numpy as np

import onnx
import onnxruntime as ort

import torch

_REPRODUCTION_TEMPLATE = '''\
import google.protobuf.text_format
import numpy as np
from numpy import array, float16, float32, float64, int32, int64
import onnx
import onnxruntime as ort

# Run n times
N = 1

onnx_model_text = """
{onnx_model_text}
"""

ort_inputs = {ort_inputs}

# Set up the inference session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
onnx_model = onnx.ModelProto()
google.protobuf.text_format.Parse(onnx_model_text, onnx_model)

# Uncomment this line to save the model to a file for examination
# onnx.save_model(onnx_model, "{short_test_name}.onnx")

onnx.checker.check_model(onnx_model)
session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",))

# Run the model
for _ in range(N):
    ort_outputs = session.run(None, ort_inputs)
'''

_ISSUE_MARKDOWN_TEMPLATE = """
### Summary

ONNX Runtime raises `{error_text}` when executing test `{test_name}` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```

### To reproduce

```python
{reproduction_code}
```

### Full error stack

```
{error_stack}
```

### The ONNX model text for visualization

```
{onnx_model_textual_representation}
```

### Environment

```
{sys_info}
```
"""


_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```

### Inputs

Shapes: `{input_shapes}`

<details><summary>Details</summary>
<p>

```python
kwargs = {kwargs}
inputs = {inputs}
```

</p>
</details>

### Expected output

Shape: `{expected_shape}`

<details><summary>Details</summary>
<p>

```python
expected = {expected}
```

</p>
</details>

### Actual output

Shape: `{actual_shape}`

<details><summary>Details</summary>
<p>

```python
actual = {actual}
```

</p>
</details>

### Difference

<details><summary>Details</summary>
<p>

```diff
{diff}
```

</p>
</details>

### Full error stack

```
{error_stack}
```
"""


def create_reproduction_report(
    test_name: str,
    onnx_model: onnx.ModelProto,
    ort_inputs: Mapping[str, Any],
    error: Exception,
) -> None:
    # NOTE: We choose to embed the ONNX model as a string in the report instead of
    # saving it to a file because it is easier to share the report with others.
    onnx_model_text = str(onnx_model)
    with np.printoptions(threshold=sys.maxsize):
        ort_inputs = dict(ort_inputs.items())
        input_text = str(ort_inputs)
    error_text = str(error)
    error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
    sys_info = f"""\
OS: {platform.platform()}
Python version: {sys.version}
onnx=={onnx.__version__}
onnxruntime=={ort.__version__}
numpy=={np.__version__}
torch=={torch.__version__}"""
    short_test_name = test_name.split(".")[-1]
    reproduction_code = _REPRODUCTION_TEMPLATE.format(
        onnx_model_text=onnx_model_text,
        ort_inputs=input_text,
        short_test_name=short_test_name,
    )
    onnx_model_textual_representation = onnx.printer.to_text(onnx_model)

    markdown = _ISSUE_MARKDOWN_TEMPLATE.format(
        error_text=error_text,
        test_name=test_name,
        short_test_name=short_test_name,
        reproduction_code=reproduction_code,
        error_stack=error_stack,
        sys_info=sys_info,
        onnx_model_textual_representation=onnx_model_textual_representation,
    )

    # Turn test name into a valid file name
    markdown_file_name = f"{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md"
    markdown_file_path = save_error_report(markdown_file_name, markdown)
    print(f"Created reproduction report at {markdown_file_path}")


def create_mismatch_report(
    test_name: str,
    sample_num: int,
    inputs,
    kwargs,
    actual,
    expected,
    error: Exception,
) -> None:
    torch.set_printoptions(threshold=sys.maxsize)

    error_text = str(error)
    error_stack = error_text + "\n" + "".join(traceback.format_tb(error.__traceback__))
    short_test_name = test_name.split(".")[-1]
    diff = difflib.unified_diff(
        str(actual).splitlines(),
        str(expected).splitlines(),
        fromfile="actual",
        tofile="expected",
        lineterm="",
    )
    input_shapes = repr(
        [
            f"Tensor<{inp.shape}, dtype={inp.dtype}>"
            if isinstance(inp, torch.Tensor)
            else inp
            for inp in inputs
        ]
    )
    markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
        test_name=test_name,
        short_test_name=short_test_name,
        sample_num=sample_num,
        input_shapes=input_shapes,
        inputs=inputs,
        kwargs=kwargs,
        expected=expected,
        expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
        actual=actual,
        actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
        diff="\n".join(diff),
        error_stack=error_stack,
    )

    markdown_file_name = f"mismatch-{short_test_name.replace('/', '-').replace(':', '-')}-{str(time.time()).replace('.', '_')}.md"
    markdown_file_path = save_error_report(markdown_file_name, markdown)
    print(f"Created reproduction report at {markdown_file_path}")


def save_error_report(file_name: str, text: str):
    reports_dir = pathlib.Path("error_reports")
    reports_dir.mkdir(parents=True, exist_ok=True)
    file_path = reports_dir / file_name
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(text)

    return file_path
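
As an illustration, a test harness might call these helpers roughly as follows, gated on the `CREATE_REPRODUCTION_REPORT` flag described in the README. The wrapper function and the `os.getenv` gating are assumptions made for this sketch; only `create_reproduction_report` and its signature come from the module above.

```py
import os

import onnx
import onnxruntime as ort


def run_ort_with_report(test_name: str, onnx_model: onnx.ModelProto, ort_inputs: dict):
    """Hypothetical wrapper: run the model; on failure, emit a markdown report."""
    try:
        session = ort.InferenceSession(
            onnx_model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
        return session.run(None, ort_inputs)
    except Exception as e:
        if os.getenv("CREATE_REPRODUCTION_REPORT") == "1":
            create_reproduction_report(test_name, onnx_model, ort_inputs, e)
        raise
```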
