[ONNX] Support float4 by justinchuby · Pull Request #151069 · pytorch/pytorch · GitHub
[ONNX] Support float4 #151069

Closed · wants to merge 8 commits
12 changes: 11 additions & 1 deletion test/onnx/exporter/test_core.py
@@ -35,10 +35,14 @@ class TorchTensorTest(common_utils.TestCase):
(torch.uint32, np.uint32),
(torch.uint64, np.uint64),
(torch.uint8, np.uint8),
(torch.float4_e2m1fn_x2, ml_dtypes.float4_e2m1fn),
],
)
def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
if dtype == torch.float4_e2m1fn_x2:
tensor = _core.TorchTensor(torch.tensor([1], dtype=torch.uint8).view(dtype))
else:
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
self.assertEqual(tensor.numpy().dtype, np_dtype)
self.assertEqual(tensor.__array__().dtype, np_dtype)
self.assertEqual(np.array(tensor).dtype, np_dtype)
@@ -71,6 +75,12 @@ def test_tobytes(self, dtype: torch.dtype):
tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype))
self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())

def test_tobytes_float4(self):
tensor = _core.TorchTensor(
torch.tensor([1], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
)
self.assertEqual(tensor.tobytes(), b"\x01")


if __name__ == "__main__":
common_utils.run_tests()
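
For context (not part of the diff): the b"\x01" expectation in test_tobytes_float4 above comes from tobytes() keeping the packed layout, where each byte carries two fp4 nibbles. A minimal sketch of that nibble split, mirroring the masking done in _type_casting.unpack_float4x2_as_uint8 added below:

# Illustrative only: how one packed byte splits into its two 4-bit E2M1 codes.
packed_byte = 0x01
low = packed_byte & 0x0F          # first fp4 element (bit pattern 0b0001)
high = (packed_byte & 0xF0) >> 4  # second fp4 element (bit pattern 0b0000)
assert (low, high) == (1, 0)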
14 changes: 13 additions & 1 deletion test/onnx/exporter/test_small_models_e2e.py
@@ -216,7 +216,19 @@ def forward(self, input: torch.Tensor):
input = input.to(float8_type)
return input

_ = self.export(Float8Module(), (torch.randn(1, 2),))
onnx_program = self.export(Float8Module(), (torch.randn(1, 2),))
self.assertEqual(onnx_program.model.graph.outputs[0].dtype, onnx_type)

def test_float4_support(self):
class Float4Module(torch.nn.Module):
def forward(self):
return torch.empty([1], dtype=torch.float4_e2m1fn_x2)

onnx_program = self.export(Float4Module())
output = onnx_program.model.graph.outputs[0]
self.assertEqual(output.dtype, ir.DataType.FLOAT4E2M1)
# The shape is [*shape, 2] because ONNX stores the shape of the unpacked tensor
self.assertEqual(output.shape.dims, [1, 2])

def test_bfloat16_support(self):
class BfloatModel(torch.nn.Module):
25 changes: 23 additions & 2 deletions torch/onnx/_internal/exporter/_core.py
@@ -38,6 +38,7 @@
_registration,
_reporting,
_tensors,
_type_casting,
_verification,
)

@@ -61,6 +62,7 @@
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
@@ -109,8 +111,17 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:
class TorchTensor(ir.Tensor):
def __init__(self, tensor: torch.Tensor, name: str | None = None):
# Pass the tensor as the raw data to ir.Tensor's constructor
if tensor.dtype == torch.float4_e2m1fn_x2:
# Change the shape to the unpacked shape
shape = ir.Shape(_type_casting.get_float4_shape(tensor), frozen=True)
else:
# The base class will set the shape to the tensor's shape
shape = None
super().__init__(
tensor, dtype=torch_dtype_to_onnx_dtype(tensor.dtype), name=name
tensor,
dtype=torch_dtype_to_onnx_dtype(tensor.dtype),
shape=shape,
name=name,
)

def numpy(self) -> npt.NDArray:
Expand All @@ -132,6 +143,10 @@ def numpy(self) -> npt.NDArray:
ir.DataType.FLOAT8E5M2FNUZ,
}:
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
if self.dtype == ir.DataType.FLOAT4E2M1:
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
self.dtype.numpy()
)

return self.raw.numpy(force=True)

@@ -213,7 +228,13 @@ def _set_shape_type(
logger.warning("Setting shape and type of tensors is not supported yet")
if isinstance(meta_val, torch.Tensor):
dims = []
for dim in meta_val.shape:
shape: tuple[int, ...]
if meta_val.dtype == torch.float4_e2m1fn_x2:
# Change the shape to the unpacked shape
shape = _type_casting.get_float4_shape(meta_val)
else:
shape = meta_val.shape
for dim in shape:
if isinstance(dim, int):
dims.append(dim)
else:
3 changes: 3 additions & 0 deletions torch/onnx/_internal/exporter/_dispatching.py
@@ -27,6 +27,7 @@
torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
torch.float4_e2m1fn_x2: ir.DataType.FLOAT4E2M1,
torch.int16: ir.DataType.INT16,
torch.int32: ir.DataType.INT32,
torch.int64: ir.DataType.INT64,
@@ -95,6 +96,7 @@ def _param_type_compatible_with_arg(
ir.TensorType(ir.DataType.INT32),
ir.TensorType(ir.DataType.INT64),
# Int inputs can be casted to a float too
ir.TensorType(ir.DataType.FLOAT4E2M1),
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),
@@ -105,6 +107,7 @@
}:
return True
if isinstance(value, float) and param.type_constraint.allowed_types & {
ir.TensorType(ir.DataType.FLOAT4E2M1),
ir.TensorType(ir.DataType.FLOAT8E4M3FN),
ir.TensorType(ir.DataType.FLOAT8E4M3FNUZ),
ir.TensorType(ir.DataType.FLOAT8E5M2),
32 changes: 32 additions & 0 deletions torch/onnx/_internal/exporter/_type_casting.py
@@ -0,0 +1,32 @@
import numpy as np

import torch


def unpack_float4x2_as_uint8(tensor: torch.Tensor) -> np.ndarray:
"""Convert a float4x2 tensor to unpacked uint8 np array."""
assert tensor.dtype == torch.float4_e2m1fn_x2
data = tensor.view(torch.uint8).numpy(force=True).flatten()
result_size = tensor.numel() * 2
result = np.empty([result_size], dtype=np.uint8)
array_low = data & np.uint8(0x0F)
array_high = data & np.uint8(0xF0)
array_high >>= np.uint8(4)
result[0::2] = array_low
result[1::2] = array_high
result.resize(get_float4_shape(tensor), refcheck=False)
return result


def get_float4_shape(tensor: torch.Tensor) -> tuple[int, ...]:
"""Get the shape of an unpacked float4 tensor.

The float4_e2m1fn_x2 type is a shell type described in
https://github.com/pytorch/pytorch/issues/146414.

The shell dtype takes up 1 byte per element and packs two fp4 values into that byte,
so a tensor of this dtype semantically represents (*tensor.shape, 2) fp4 elements.
"""
assert tensor.dtype == torch.float4_e2m1fn_x2
return (*tensor.shape, 2)
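
A hedged usage sketch of the two helpers above, assuming a PyTorch build that exposes torch.float4_e2m1fn_x2 and this private _type_casting module: one packed byte becomes two uint8 nibble values, and the reported shape gains a trailing dimension of 2.

import torch
from torch.onnx._internal.exporter import _type_casting

# One packed byte (0x4C) semantically holds two fp4 values: low nibble 0xC, high nibble 0x4.
packed = torch.tensor([0x4C], dtype=torch.uint8).view(torch.float4_e2m1fn_x2)

print(_type_casting.get_float4_shape(packed))         # (1, 2)
print(_type_casting.unpack_float4x2_as_uint8(packed))  # [[12  4]] -- low nibble first, then high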
3 changes: 3 additions & 0 deletions torch/onnx/ops/__init__.py
@@ -38,6 +38,9 @@
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
torch.float8_e5m2: 19, # FLOAT8E5M2
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
# 21 = UINT4
# 22 = INT4
torch.float4_e2m1fn_x2: 23, # FLOAT4E2M1
}


2 changes: 1 addition & 1 deletion torch/onnx/ops/_symbolic_impl.py
@@ -40,7 +40,7 @@
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
21: torch.uint8, # UINT4
22: torch.uint8, # INT4
23: torch.uint8, # FLOAT4E2M1
23: torch.float4_e2m1fn_x2, # FLOAT4E2M1
}

_INT_TYPE = "i"