update · pytorch/pytorch@52c6def · GitHub

Commit 52c6def

committed
update
1 parent 88d02b1 commit 52c6def

File tree

3 files changed: +17 −39 lines changed


test/onnx/exporter/test_small_models_e2e.py

Lines changed: 9 additions & 36 deletions
@@ -150,12 +150,7 @@ def false_fn(x, z):
                 x = torch.cond(x.sum() > 0, true_fn, false_fn, (x, z))
                 return x, z
 
-        onnx_program = torch.onnx.export(
-            CondModel(),
-            (torch.tensor([1, 2]),),
-            dynamo=True,
-            fallback=False,
-        )
+        onnx_program = self.export(CondModel(), (torch.tensor([1, 2]),))
         onnx_testing.assert_onnx_program(onnx_program)
         onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([-1, -2]),))
 
@@ -194,56 +189,34 @@ def forward(self, x):
         _ = self.export(exported_program)
 
     @common_utils.parametrize(
-        "float8_type",
+        "float8_type, onnx_type",
         [
             common_utils.subtest(
-                torch.float8_e5m2,
+                (torch.float8_e5m2, ir.DataType.FLOAT8E5M2),
                 name="torch_float8_e5m2",
             ),
             common_utils.subtest(
-                torch.float8_e5m2fnuz,
+                (torch.float8_e5m2fnuz, ir.DataType.FLOAT8E5M2FNUZ),
                 name="torch_float8_e5m2fnuz",
             ),
             common_utils.subtest(
-                torch.float8_e4m3fn,
+                (torch.float8_e4m3fn, ir.DataType.FLOAT8E4M3FN),
                 name="torch_float8_e4m3fn",
             ),
             common_utils.subtest(
-                torch.float8_e4m3fnuz,
+                (torch.float8_e4m3fnuz, ir.DataType.FLOAT8E4M3FNUZ),
                 name="torch_float8_e4m3fnuz",
             ),
         ],
     )
-    def test_float8_support(self, float8_type):
+    def test_float8_support(self, float8_type: torch.dtype, onnx_type: ir.DataType):
         class Float8Module(torch.nn.Module):
             def forward(self, input: torch.Tensor):
                 input = input.to(float8_type)
                 return input
 
-        _ = self.export(Float8Module(), (torch.randn(1, 2),))
-
-    def test_bfloat16_support(self):
-        class BfloatModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                # Test parameters
-                self.param = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.bfloat16))
-
-            def forward(self, x):
-                # Test constant tensors are stored as bfloat16
-                const = torch.tensor(1.0, dtype=torch.bfloat16)
-                return x * const * self.param
-
-        input = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
-        onnx_program = self.export(BfloatModel(), (input,), optimize=False)
-        initializers = onnx_program.model.graph.initializers.values()
-        self.assertEqual(len(initializers), 2)
-        for initializer in initializers:
-            self.assertEqual(initializer.dtype, ir.DataType.BFLOAT16)
-        self.assertEqual(onnx_program.model.graph.inputs[0].dtype, ir.DataType.BFLOAT16)
-        self.assertEqual(
-            onnx_program.model.graph.outputs[0].dtype, ir.DataType.BFLOAT16
-        )
+        onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
+        self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)
 
     def test_export_with_logging_logger(self):
         logger = logging.getLogger(__name__)
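
For reference, a standalone sketch of what the reworked parametrized test now checks, assuming torch.onnx.export with dynamo=True and the onnxscript ir module (the test itself goes through the suite's self.export helper): export a module that casts its input to a float8 dtype and assert that the ONNX graph output carries the matching IR dtype.

# Hedged sketch, not the repository's test harness: torch.onnx.export(dynamo=True)
# returns an ONNXProgram whose model is an onnxscript IR model, so the graph output
# dtype can be compared against the expected ir.DataType member.
import torch
from onnxscript import ir


class Float8Module(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float8_e5m2)


onnx_program = torch.onnx.export(Float8Module(), (torch.randn(1, 2),), dynamo=True)
assert onnx_program.model.graph.outputs[0].dtype == ir.DataType.FLOAT8E5M2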

torch/onnx/_internal/exporter/_core.py

Lines changed: 5 additions & 3 deletions
@@ -116,15 +116,17 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None):
     def numpy(self) -> npt.NDArray:
         self.raw: torch.Tensor
         if self.dtype == ir.DataType.BFLOAT16:
-            return self.raw.view(torch.uint16).numpy(force=True)
+            return (
+                self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
+            )
         if self.dtype in {
             ir.DataType.FLOAT8E4M3FN,
             ir.DataType.FLOAT8E4M3FNUZ,
             ir.DataType.FLOAT8E5M2,
             ir.DataType.FLOAT8E5M2FNUZ,
         }:
-            # TODO: Use ml_dtypes
-            return self.raw.view(torch.uint8).numpy(force=True)
+            return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
+
         return self.raw.numpy(force=True)
 
     def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray:
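
A minimal sketch of the bit-reinterpretation the updated numpy() path relies on, assuming the ml_dtypes package provides the NumPy dtype that self.dtype.numpy() resolves to for BFLOAT16: the raw 16-bit payload is exposed as uint16, moved to NumPy, then viewed as bfloat16 without converting any values.

# Hedged sketch, assuming the ml_dtypes package; bytes are reinterpreted, not cast.
import ml_dtypes
import torch

t = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)
bits = t.view(torch.uint16).numpy(force=True)   # same 16-bit payload, uint16 container
arr = bits.view(ml_dtypes.bfloat16)             # reinterpret the bytes as bfloat16
print(arr)  # [1 2 3] with dtype bfloat16

The float8 branch uses the same idea with a uint8 view and the matching ml_dtypes float8 type.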

torch/onnx/_internal/exporter/_dispatching.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@
     torch.int64: ir.DataType.INT64,
     torch.int8: ir.DataType.INT8,
     torch.uint8: ir.DataType.UINT8,
+    torch.uint16: ir.DataType.UINT16,
+    torch.uint32: ir.DataType.UINT32,
+    torch.uint64: ir.DataType.UINT64,
 }
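As an illustration of how a table like this is typically consumed (the dict and function names below are hypothetical, not the exporter's private ones), the new entries let uint16/uint32/uint64 torch dtypes resolve to their ONNX IR counterparts instead of falling through to an error:

# Hypothetical lookup helper; only the torch.* and ir.DataType.* names come from the diff.
import torch
from onnxscript import ir

_TORCH_DTYPE_TO_ONNX = {
    torch.uint8: ir.DataType.UINT8,
    torch.uint16: ir.DataType.UINT16,
    torch.uint32: ir.DataType.UINT32,
    torch.uint64: ir.DataType.UINT64,
}


def to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
    try:
        return _TORCH_DTYPE_TO_ONNX[dtype]
    except KeyError:
        raise ValueError(f"Unsupported torch dtype for ONNX export: {dtype}") from None


print(to_onnx_dtype(torch.uint16))  # e.g. DataType.UINT16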

0 commit comments
