Skip to content

Commit 1e8dc30

Browse files
authored
Serialize/flatbuffer to program (pytorch#18129)
exir: add flatbuffer-to-program reader This continues the work from pytorch#17333. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani --------- Signed-off-by: Chizkiyahu Raful <chizkiyahu.raful@arm.com>
1 parent 8be91e0 commit 1e8dc30

6 files changed

Lines changed: 228 additions & 208 deletions

File tree

exir/_serialize/_flatbuffer.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import importlib.resources
1313
import os
1414
import re
15-
import shutil
1615
import stat
1716
import subprocess
1817
import tempfile
@@ -384,72 +383,6 @@ def _flatc_decompile(
384383
)
385384

386385

387-
def _program_json_to_flatbuffer(
388-
program_json: str,
389-
*,
390-
constant_tensor_alignment: Optional[int] = None,
391-
delegate_alignment: Optional[int] = None,
392-
) -> _FlatbufferResult:
393-
"""Converts Program-compatible JSON into binary flatbuffer data.
394-
395-
Args:
396-
program_json: The JSON to convert. Must be compatible with the root
397-
table type of //executorch/schema/program.fbs.
398-
constant_tensor_alignment: If provided, the alignment to use for tensor
399-
data embedded in the output flatbuffer data. If not provided, uses
400-
the alignment in the schema.
401-
delegate_alignment: If provided, the alignment to use for delegate
402-
data embedded in the output flatbuffer data. If not provided, uses
403-
the alignment in the schema.
404-
405-
Returns: The flatbuffer data and associated metadata.
406-
"""
407-
with tempfile.TemporaryDirectory() as temp_dir:
408-
schema_info = _prepare_schema(
409-
out_dir=temp_dir,
410-
constant_tensor_alignment=constant_tensor_alignment,
411-
delegate_alignment=delegate_alignment,
412-
)
413-
file_stem = "data"
414-
json_path = os.path.join(temp_dir, file_stem + ".json")
415-
output_path = os.path.join(temp_dir, file_stem + ".pte")
416-
417-
with open(json_path, "wb") as json_file:
418-
json_file.write(program_json.encode("ascii"))
419-
420-
try:
421-
_flatc_compile(temp_dir, schema_info.root_path, json_path)
422-
except Exception as err:
423-
# It's helpful to save the breaking files for debugging. Optionally
424-
# move them out of the auto-deleting temporary directory. Don't do
425-
# this by default because some input files can be many GB in size,
426-
# and these copies won't be auto-deleted.
427-
should_save = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"}
428-
extra_message = ""
429-
if should_save:
430-
try:
431-
saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-")
432-
for f in os.listdir(temp_dir):
433-
shutil.move(src=os.path.join(temp_dir, f), dst=saved_dir)
434-
extra_message += f" Moved input files to '{saved_dir}'."
435-
except Exception as err2:
436-
extra_message += (
437-
f" (Failed to save input files for debugging: {err2})"
438-
)
439-
else:
440-
extra_message += (
441-
f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure."
442-
)
443-
444-
raise RuntimeError(
445-
f"Failed to compile {json_path} to {output_path}." + extra_message
446-
) from err
447-
with open(output_path, "rb") as output_file:
448-
return _FlatbufferResult(
449-
data=output_file.read(), max_alignment=schema_info.max_alignment
450-
)
451-
452-
453386
def _replace_infinity_in_json_file(content: bytes) -> bytes:
454387
"""Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs
455388
is used to convert from flatbuffer to JSON. +-inf float values are not

exir/_serialize/_flatbuffer_program.py

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import enum
99
import functools
1010
import importlib
11+
import pkgutil
1112
import tempfile
1213

1314
from contextvars import ContextVar
1415
from dataclasses import fields, is_dataclass
1516
from functools import lru_cache
16-
from typing import Any, Dict, Optional
17+
from types import ModuleType
18+
from typing import Any, Dict, get_args, get_origin, get_type_hints, Optional, Union
1719

1820
import flatbuffers # pyre-ignore[21]
1921
from executorch.exir._serialize._flatbuffer import (
@@ -22,6 +24,7 @@
2224
_prepare_schema,
2325
_SchemaInfo,
2426
)
27+
from executorch.exir._serialize.generated import executorch_flatbuffer as _generated_fb
2528
from executorch.exir._serialize.generated.executorch_flatbuffer import (
2629
BackendDelegateInlineData as _BackendDelegateInlineData,
2730
Buffer as _Buffer,
@@ -33,6 +36,7 @@
3336

3437
_T_CLASS_CACHE: Dict[type, type] = {}
3538
_FIELD_NAME_CACHE: Dict[type, tuple[tuple[str, str], ...]] = {}
39+
_TYPE_HINTS_CACHE: Dict[type, Dict[str, Any]] = {}
3640
_BUFFER_ALIGNMENT: ContextVar[int] = ContextVar("_BUFFER_ALIGNMENT", default=1)
3741
_DELEGATE_ALIGNMENT: ContextVar[int] = ContextVar("_DELEGATE_ALIGNMENT", default=1)
3842

@@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]:
6468
return mapping
6569

6670

71+
def _dataclass_type_hints(dataclass_type: type) -> Dict[str, Any]:
72+
cached = _TYPE_HINTS_CACHE.get(dataclass_type)
73+
if cached is not None:
74+
return cached
75+
type_hints = get_type_hints(dataclass_type)
76+
_TYPE_HINTS_CACHE[dataclass_type] = type_hints
77+
return type_hints
78+
79+
6780
def _create_aligned_byte_vector(builder: Any, data: bytes, alignment: int) -> int:
6881
if not _is_valid_alignment(alignment):
6982
raise ValueError(f"Bad alignment {alignment}")
@@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT:
194207
return _convert_dataclass(val)
195208

196209

210+
# The generated FlatBuffer Python modules import child tables/unions as modules
211+
# (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the
212+
# unpacking helpers later expect those globals to be the corresponding classes.
213+
# Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the
214+
# generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types.
215+
def _patch_generated_module_aliases(module: ModuleType) -> None:
216+
for name, maybe_module in vars(module).items():
217+
if not isinstance(maybe_module, ModuleType):
218+
continue
219+
maybe_class = getattr(maybe_module, name, None)
220+
if isinstance(maybe_class, type):
221+
setattr(module, name, maybe_class)
222+
223+
224+
@lru_cache(maxsize=1)
225+
def _patch_generated_flatbuffer_aliases() -> None:
226+
package_name = _generated_fb.__name__
227+
for module_info in pkgutil.iter_modules(_generated_fb.__path__):
228+
module = importlib.import_module(f"{package_name}.{module_info.name}")
229+
_patch_generated_module_aliases(module)
230+
231+
232+
def _flatbuffer_dataclass_names(val: Any) -> tuple[str, Optional[str]]:
233+
val_type_name = type(val).__name__
234+
if val_type_name.endswith("T"):
235+
return val_type_name, val_type_name[:-1]
236+
return val_type_name, None
237+
238+
239+
def _matches_dataclass_union_type(
240+
union_type: Any, val_type_name: str, val_dataclass_name: Optional[str]
241+
) -> bool:
242+
if not is_dataclass(union_type):
243+
return False
244+
union_name = union_type.__name__
245+
return union_name == val_type_name or (
246+
val_dataclass_name is not None and union_name == val_dataclass_name
247+
)
248+
249+
250+
def _matches_non_dataclass_union_type(union_type: Any, val: Any) -> bool:
251+
if union_type is Any:
252+
return True
253+
if union_type is str and isinstance(val, (bytes, bytearray, memoryview)):
254+
return True
255+
union_origin = get_origin(union_type)
256+
if union_origin is list and hasattr(val, "__iter__"):
257+
return True
258+
return isinstance(union_type, type) and isinstance(val, union_type)
259+
260+
261+
def _union_choice_from_value(union_types: tuple[Any, ...], val: Any) -> Any:
262+
if val is None:
263+
for union_type in union_types:
264+
if union_type is type(None):
265+
return union_type
266+
return None
267+
268+
val_type_name, val_dataclass_name = _flatbuffer_dataclass_names(val)
269+
270+
for union_type in union_types:
271+
if union_type is type(None):
272+
continue
273+
if _matches_dataclass_union_type(union_type, val_type_name, val_dataclass_name):
274+
return union_type
275+
if _matches_non_dataclass_union_type(union_type, val):
276+
return union_type
277+
return None
278+
279+
280+
def _convert_from_flatbuffer_value(val: Any, expected_type: Any) -> Any:
281+
if val is None:
282+
return None
283+
284+
origin = get_origin(expected_type)
285+
if origin is list:
286+
item_type = get_args(expected_type)[0]
287+
return [_convert_from_flatbuffer_value(item, item_type) for item in val]
288+
289+
if origin is Union:
290+
union_type = _union_choice_from_value(get_args(expected_type), val)
291+
if union_type is None:
292+
raise TypeError(
293+
f"Could not match value type {type(val)} to {expected_type}"
294+
)
295+
if union_type is type(None):
296+
return None
297+
return _convert_from_flatbuffer_value(val, union_type)
298+
299+
if expected_type is bytes:
300+
return _coerce_bytes(val)
301+
if expected_type is str and isinstance(val, (bytes, bytearray, memoryview)):
302+
return _coerce_bytes(val).decode("utf-8")
303+
if is_dataclass(expected_type):
304+
return _convert_from_flatbuffer_dataclass(val, expected_type)
305+
if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum):
306+
if isinstance(val, expected_type):
307+
return val
308+
return expected_type(val)
309+
if isinstance(expected_type, type):
310+
return expected_type(val)
311+
return val
312+
313+
314+
def _convert_from_flatbuffer_dataclass(val: Any, dataclass_type: type) -> Any:
315+
result = {}
316+
type_hints = _dataclass_type_hints(dataclass_type)
317+
for src_name, dst_name in _dataclass_field_map(dataclass_type):
318+
result[src_name] = _convert_from_flatbuffer_value(
319+
getattr(val, dst_name), type_hints[src_name]
320+
)
321+
return dataclass_type(**result)
322+
323+
324+
def _flatbuffer_to_program(program_data: bytes) -> Program:
325+
_patch_generated_flatbuffer_aliases()
326+
program_t = ProgramT.InitFromPackedBuf(program_data)
327+
return _convert_from_flatbuffer_dataclass(program_t, Program)
328+
329+
197330
@lru_cache(maxsize=1)
198331
def _get_schema_info(
199332
constant_tensor_alignment: Optional[int], delegate_alignment: Optional[int]
@@ -213,11 +346,7 @@ def _program_to_flatbuffer(
213346
constant_tensor_alignment: Optional[int] = None,
214347
delegate_alignment: Optional[int] = None,
215348
) -> _FlatbufferResult:
216-
"""Converts a Program dataclass into binary flatbuffer data.
217-
218-
Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke
219-
flatc to build the binary.
220-
"""
349+
"""Converts a Program dataclass into binary flatbuffer data."""
221350
schema_info = _get_schema_info(constant_tensor_alignment, delegate_alignment)
222351
_set_pack_alignments(schema_info.tensor_alignment, schema_info.delegate_alignment)
223352
_install_fast_packers()

exir/_serialize/_program.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple
1717

1818
from executorch.exir._serialize._cord import Cord
19-
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
20-
from executorch.exir._serialize._flatbuffer import (
21-
_FlatbufferResult,
22-
_program_flatbuffer_to_json,
19+
from executorch.exir._serialize._dataclass import _DataclassEncoder
20+
from executorch.exir._serialize._flatbuffer import _FlatbufferResult
21+
from executorch.exir._serialize._flatbuffer_program import (
22+
_flatbuffer_to_program,
23+
_program_to_flatbuffer,
2324
)
24-
from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer
2525
from executorch.exir._serialize._named_data_store import (
2626
NamedDataStore,
2727
NamedDataStoreOutput,
@@ -86,12 +86,6 @@ def _program_to_json(program: Program) -> str:
8686
return json.dumps(program, cls=_DataclassEncoder)
8787

8888

89-
def _json_to_program(program_json: bytes) -> Program:
90-
"""Returns a Program deserialized from the given JSON string."""
91-
# construct program class recursively from dict
92-
return _json_to_dataclass(json.loads(program_json), cls=Program)
93-
94-
9589
def _insert_flatbuffer_header(
9690
flatbuffer_data: bytes, magic_regex: str, header_data: bytes
9791
) -> bytes:
@@ -757,9 +751,7 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile:
757751
segment_base_offset = eh.segment_base_offset
758752

759753
# Parse the flatbuffer data.
760-
program: Program = _json_to_program(
761-
_program_flatbuffer_to_json(program_data[:program_size])
762-
)
754+
program: Program = _flatbuffer_to_program(program_data[:program_size])
763755

764756
if segment_base_offset != 0:
765757
# Move segment data back into the Program.
@@ -799,9 +791,7 @@ def _extract_delegate_payload(
799791
program_size = len(pte_data)
800792

801793
# Parse the program flatbuffer
802-
program: Program = _json_to_program(
803-
_program_flatbuffer_to_json(pte_data[:program_size])
804-
)
794+
program: Program = _flatbuffer_to_program(pte_data[:program_size])
805795

806796
# Search for the matching delegate
807797
match_count = 0

0 commit comments

Comments
 (0)