[ONNX] Produce correct dtypes for bf16/f8 in IR TorchTensor (#151259) · pytorch/pytorch@9917fef
Commit 9917fef

justinchuby authored and pytorchmergebot committed
[ONNX] Produce correct dtypes for bf16/f8 in IR TorchTensor (#151259)
Split out of #151069 to address microsoft/onnxscript#2187, where the output NumPy arrays did not carry the expected ml_dtypes dtypes. Pull Request resolved: #151259. Approved by: https://github.com/titaiwangms
1 parent 331423e commit 9917fef
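
For context, a minimal sketch (not the exporter code itself) of the bit-reinterpretation trick this commit applies: NumPy has no native bfloat16/float8 dtype, so the tensor's raw bits travel through a same-width integer view and are then reinterpreted as the matching ml_dtypes dtype. This is an illustration assuming torch, numpy, and ml_dtypes are installed.

import ml_dtypes
import numpy as np
import torch

t = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)

# Step 1: reinterpret the bf16 payload as uint16 so .numpy() can materialize it.
carrier = t.view(torch.uint16).numpy(force=True)
assert carrier.dtype == np.uint16  # what callers used to receive

# Step 2 (the fix): a zero-copy view restores the logical element type.
restored = carrier.view(ml_dtypes.bfloat16)
assert restored.dtype == ml_dtypes.bfloat16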

File tree: 4 files changed, +40 -39 lines changed

test/onnx/exporter/test_core.py (25 additions, 24 deletions)

@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import ml_dtypes
 import numpy as np
 
 import torch
@@ -15,17 +16,17 @@ class TorchTensorTest(common_utils.TestCase):
     @common_utils.parametrize(
         "dtype, np_dtype",
         [
-            (torch.bfloat16, np.uint16),
+            (torch.bfloat16, ml_dtypes.bfloat16),
             (torch.bool, np.bool_),
             (torch.complex128, np.complex128),
             (torch.complex64, np.complex64),
             (torch.float16, np.float16),
             (torch.float32, np.float32),
             (torch.float64, np.float64),
-            (torch.float8_e4m3fn, np.uint8),
-            (torch.float8_e4m3fnuz, np.uint8),
-            (torch.float8_e5m2, np.uint8),
-            (torch.float8_e5m2fnuz, np.uint8),
+            (torch.float8_e4m3fn, ml_dtypes.float8_e4m3fn),
+            (torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz),
+            (torch.float8_e5m2, ml_dtypes.float8_e5m2),
+            (torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz),
             (torch.int16, np.int16),
             (torch.int32, np.int32),
             (torch.int64, np.int64),
@@ -45,25 +46,25 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
     @common_utils.parametrize(
         "dtype",
         [
-            (torch.bfloat16),
-            (torch.bool),
-            (torch.complex128),
-            (torch.complex64),
-            (torch.float16),
-            (torch.float32),
-            (torch.float64),
-            (torch.float8_e4m3fn),
-            (torch.float8_e4m3fnuz),
-            (torch.float8_e5m2),
-            (torch.float8_e5m2fnuz),
-            (torch.int16),
-            (torch.int32),
-            (torch.int64),
-            (torch.int8),
-            (torch.uint16),
-            (torch.uint32),
-            (torch.uint64),
-            (torch.uint8),
+            torch.bfloat16,
+            torch.bool,
+            torch.complex128,
+            torch.complex64,
+            torch.float16,
+            torch.float32,
+            torch.float64,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2,
+            torch.float8_e5m2fnuz,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+            torch.int8,
+            torch.uint16,
+            torch.uint32,
+            torch.uint64,
+            torch.uint8,
         ],
     )
     def test_tobytes(self, dtype: torch.dtype):
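
As a usage sketch of what the updated parametrization asserts, a standalone check against the private _core module (hypothetical illustration; the module layout may change):

import ml_dtypes
import torch
from torch.onnx._internal.exporter import _core

tensor = _core.TorchTensor(torch.tensor([1.0], dtype=torch.bfloat16))
# Previously this returned a uint16 array; now the ml_dtypes dtype round-trips.
assert tensor.numpy().dtype == ml_dtypes.bfloat16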

test/onnx/exporter/test_small_models_e2e.py (7 additions, 12 deletions)

@@ -150,12 +150,7 @@ def false_fn(x, z):
                 x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
                 return x, z
 
-        onnx_program = torch.onnx.export(
-            CondModel(),
-            (torch.tensor([1, 2]),),
-            dynamo=True,
-            fallback=False,
-        )
+        onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
         onnx_testing.assert_onnx_program(onnx_program)
         onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
 
@@ -194,27 +189,27 @@ def forward(self, x):
         _ = self.export(exported_program)
 
     @common_utils.parametrize(
-        "float8_type",
+        "float8_type, onnx_type",
         [
             common_utils.subtest(
-                torch.float8_e5m2,
+                (torch.float8_e5m2, ir.DataType.FLOAT8E5M2),
                 name="torch_float8_e5m2",
             ),
             common_utils.subtest(
-                torch.float8_e5m2fnuz,
+                (torch.float8_e5m2fnuz, ir.DataType.FLOAT8E5M2FNUZ),
                 name="torch_float8_e5m2fnuz",
            ),
             common_utils.subtest(
-                torch.float8_e4m3fn,
+                (torch.float8_e4m3fn, ir.DataType.FLOAT8E4M3FN),
                 name="torch_float8_e4m3fn",
             ),
             common_utils.subtest(
-                torch.float8_e4m3fnuz,
+                (torch.float8_e4m3fnuz, ir.DataType.FLOAT8E4M3FNUZ),
                 name="torch_float8_e4m3fnuz",
             ),
         ],
     )
-    def test_float8_support(self, float8_type):
+    def test_float8_support(self, float8_type: torch.dtype, onnx_type: ir.DataType):
         class Float8Module(torch.nn.Module):
             def forward(self, input: torch.Tensor):
                 input = input.to(float8_type)
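
The new onnx_type argument is presumably consumed below this hunk to check the dtype recorded in the exported graph. A hedged sketch of such a check, using the test-scope names above (the real assertion is outside the diff):

# Hypothetical, in the spirit of the new parametrization: the exported
# graph's output Value should carry the IR dtype paired with float8_type.
onnx_program = self.export(Float8Module(), (torch.randn(2, 3),))
assert onnx_program.model.graph.outputs[0].dtype == onnx_type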

torch/onnx/_internal/exporter/_core.py (5 additions, 3 deletions)

@@ -116,15 +116,17 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
     def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return (
+                self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
+            )
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True)
+            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+
         return self.raw.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
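
Both chained views are zero-copy reinterpretations: the torch-side view picks an integer carrier NumPy can materialize, and the NumPy-side view restores the element type that self.dtype.numpy() resolves to. A sketch of that resolution, assuming onnxscript's IR is bound as ir and ml_dtypes is installed:

import ml_dtypes
from onnxscript import ir

# DataType.numpy() maps IR dtypes to ml_dtypes-backed NumPy dtypes.
assert ir.DataType.BFLOAT16.numpy() == ml_dtypes.bfloat16
assert ir.DataType.FLOAT8E4M3FN.numpy() == ml_dtypes.float8_e4m3fn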

torch/onnx/_internal/exporter/_dispatching.py (3 additions, 0 deletions)

@@ -32,6 +32,9 @@
     torch.int64: ir.DataType.INT64,
     torch.int8: ir.DataType.INT8,
     torch.uint8: ir.DataType.UINT8,
+    torch.uint16: ir.DataType.UINT16,
+    torch.uint32: ir.DataType.UINT32,
+    torch.uint64: ir.DataType.UINT64,
 }
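Without these entries the dispatcher had no IR mapping for unsigned torch dtypes wider than uint8. A hypothetical lookup against the completed table (the dict's real name sits above the hunk, so _TORCH_DTYPE_TO_ONNX here is a stand-in):

import torch
from onnxscript import ir

_TORCH_DTYPE_TO_ONNX = {  # stand-in name for the dict patched above
    torch.uint8: ir.DataType.UINT8,
    torch.uint16: ir.DataType.UINT16,
    torch.uint32: ir.DataType.UINT32,
    torch.uint64: ir.DataType.UINT64,
}

assert _TORCH_DTYPE_TO_ONNX[torch.uint64] == ir.DataType.UINT64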