Skip to content

Commit 8debe93

Browse files
Add FP8 placeholder support to ExecuTorch serialization (pytorch#19043)
Change-Id: Ibb7ef4167ab96426133fce64e34366c365cd12ad Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 3b5d18d commit 8debe93

5 files changed

Lines changed: 22 additions & 1 deletion

File tree

exir/print_program.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -60,6 +61,8 @@ def _scalar_type_str(scalar_type: ScalarType) -> str:
6061
ScalarType.QUINT8: "qui8",
6162
ScalarType.QINT32: "qi32",
6263
ScalarType.BFLOAT16: "bf16",
64+
ScalarType.FLOAT8E5M2: "f8e5m2",
65+
ScalarType.FLOAT8E4M3FN: "f8e4m3fn",
6366
ScalarType.QUINT4x2: "qui4x2",
6467
ScalarType.QUINT2x4: "qui2x4",
6568
}

exir/serde/export_serialize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -149,6 +150,8 @@ def _reverse_map(d: Dict[Any, Enum]):
149150
torch.complex128: ScalarType.COMPLEXDOUBLE,
150151
torch.bool: ScalarType.BOOL,
151152
torch.bfloat16: ScalarType.BFLOAT16,
153+
torch.float8_e5m2: ScalarType.FLOAT8E5M2,
154+
torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN,
152155
torch.uint16: ScalarType.UINT16
153156
}
154157

exir/serde/schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -16,7 +17,7 @@
1617
from executorch.exir.serde.union import _Union
1718

1819
# NOTE: Please update this value if any modifications are made to the schema
19-
SCHEMA_VERSION = (5, 3)
20+
SCHEMA_VERSION = (5, 4)
2021
TREESPEC_VERSION = 1
2122

2223

@@ -36,6 +37,8 @@ class ScalarType(IntEnum):
3637
BOOL = 12
3738
BFLOAT16 = 13
3839
UINT16 = 14
40+
FLOAT8E5M2 = 15
41+
FLOAT8E4M3FN = 16
3942

4043
class Layout(IntEnum):
4144
Unknown = 0

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -292,6 +293,8 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
292293
torch.qint32: ScalarType.QINT32,
293294
torch.bfloat16: ScalarType.BFLOAT16,
294295
torch.quint4x2: ScalarType.QUINT4x2,
296+
torch.float8_e5m2: ScalarType.FLOAT8E5M2,
297+
torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN,
295298
torch.uint16: ScalarType.UINT16,
296299
torch.uint32: ScalarType.UINT32,
297300
}

exir/tests/test_tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -90,6 +91,14 @@ def test_normal_tensor_conversion(self) -> None:
9091
# whereas strides for torch.memory_format = torch.channels_last is
9192
# (3*4*5, 1, 5*3, 3))
9293

94+
def test_fp8_tensor_conversion(self) -> None:
95+
for dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
96+
normal_tensor = torch.randn(2, 2, 3, dtype=torch.float32).to(dtype)
97+
flatbuffer_tensor = make_tensor_value(
98+
1, 0, TensorSpec.from_tensor(normal_tensor)
99+
)
100+
self.compare_tensors(normal_tensor, flatbuffer_tensor)
101+
93102
def test_allocation_info_succeeds(self) -> None:
94103
test_cases = (
95104
(

0 commit comments

Comments
 (0)