Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 0 additions & 104 deletions exir/_serialize/_flatbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
import importlib.resources
import os
import re
import shutil
import subprocess

import tempfile

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence

Expand Down Expand Up @@ -313,104 +310,3 @@ def _flatc_decompile(
bin_path,
]
)


def _program_json_to_flatbuffer(
    program_json: str,
    *,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _FlatbufferResult:
    """Serializes Program-compatible JSON into binary flatbuffer data.

    Args:
        program_json: The JSON text to convert. Must be compatible with the
            root table type of //executorch/schema/program.fbs.
        constant_tensor_alignment: Optional alignment for tensor data embedded
            in the output flatbuffer data; when omitted, the schema's own
            alignment is used.
        delegate_alignment: Optional alignment for delegate data embedded in
            the output flatbuffer data; when omitted, the schema's own
            alignment is used.

    Returns: The flatbuffer data and associated metadata.

    Raises:
        RuntimeError: If flatc fails to compile the JSON input.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        schema_info = _prepare_schema(
            out_dir=temp_dir,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        json_path = os.path.join(temp_dir, "data.json")
        output_path = os.path.join(temp_dir, "data.pte")

        with open(json_path, "wb") as json_file:
            json_file.write(program_json.encode("ascii"))

        try:
            _flatc_compile(temp_dir, schema_info.root_path, json_path)
        except Exception as err:
            # On failure, optionally preserve the inputs outside the
            # auto-deleted temp dir so they can be inspected. Off by default:
            # inputs can be many GB, and saved copies are never cleaned up.
            save_requested = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"}
            if not save_requested:
                extra_message = (
                    f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure."
                )
            else:
                try:
                    saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-")
                    for entry in os.listdir(temp_dir):
                        shutil.move(src=os.path.join(temp_dir, entry), dst=saved_dir)
                    extra_message = f" Moved input files to '{saved_dir}'."
                except Exception as err2:
                    extra_message = (
                        f" (Failed to save input files for debugging: {err2})"
                    )

            raise RuntimeError(
                f"Failed to compile {json_path} to {output_path}." + extra_message
            ) from err
        with open(output_path, "rb") as output_file:
            return _FlatbufferResult(
                data=output_file.read(), max_alignment=schema_info.max_alignment
            )


def _replace_infinity_in_json_file(content: bytes) -> bytes:
"""Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs
is used to convert from flatbuffer to JSON. +-inf float values are not
supported by JSON, so we replace them with the string equivalent. When
converting from JSON to python dataclasses, the string is read as a Union
of float and string (see schema.py).
"""
content = re.sub(
rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content
)
return content


def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
    """Converts binary flatbuffer data into Program-compatible JSON.

    The binary is parsed using the schema in //executorch/schema/program.fbs.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # "force_align" only matters during serialization, so the schema needs
        # no alignment patching for the read direction.
        schema_info = _prepare_schema(temp_dir)
        bin_path = os.path.join(temp_dir, "data.bin")
        json_path = os.path.join(temp_dir, "data.json")

        with open(bin_path, "wb") as bin_file:
            bin_file.write(program_flatbuffer)

        _flatc_decompile(temp_dir, schema_info.root_path, bin_path)
        with open(json_path, "rb") as json_file:
            raw_json = json_file.read()
        return _replace_infinity_in_json_file(raw_json)
141 changes: 135 additions & 6 deletions exir/_serialize/_flatbuffer_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import enum
import functools
import importlib
import pkgutil
import tempfile

from contextvars import ContextVar
from dataclasses import fields, is_dataclass
from functools import lru_cache
from typing import Any, Dict, Optional
from types import ModuleType
from typing import Any, Dict, get_args, get_origin, get_type_hints, Optional, Union

import flatbuffers # pyre-ignore[21]
from executorch.exir._serialize._flatbuffer import (
Expand All @@ -22,6 +24,7 @@
_prepare_schema,
_SchemaInfo,
)
from executorch.exir._serialize.generated import executorch_flatbuffer as _generated_fb
from executorch.exir._serialize.generated.executorch_flatbuffer import (
BackendDelegateInlineData as _BackendDelegateInlineData,
Buffer as _Buffer,
Expand All @@ -33,6 +36,7 @@

_T_CLASS_CACHE: Dict[type, type] = {}
_FIELD_NAME_CACHE: Dict[type, tuple[tuple[str, str], ...]] = {}
_TYPE_HINTS_CACHE: Dict[type, Dict[str, Any]] = {}
_BUFFER_ALIGNMENT: ContextVar[int] = ContextVar("_BUFFER_ALIGNMENT", default=1)
_DELEGATE_ALIGNMENT: ContextVar[int] = ContextVar("_DELEGATE_ALIGNMENT", default=1)

Expand Down Expand Up @@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]:
return mapping


def _dataclass_type_hints(dataclass_type: type) -> Dict[str, Any]:
    """Returns resolved type hints for `dataclass_type`, memoized in
    _TYPE_HINTS_CACHE so get_type_hints() runs once per type."""
    try:
        return _TYPE_HINTS_CACHE[dataclass_type]
    except KeyError:
        hints = get_type_hints(dataclass_type)
        _TYPE_HINTS_CACHE[dataclass_type] = hints
        return hints


def _create_aligned_byte_vector(builder: Any, data: bytes, alignment: int) -> int:
if not _is_valid_alignment(alignment):
raise ValueError(f"Bad alignment {alignment}")
Expand Down Expand Up @@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT:
return _convert_dataclass(val)


# The generated FlatBuffer Python modules import child tables/unions as modules
# (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the
# unpacking helpers later expect those globals to be the corresponding classes.
# Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the
# generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types.
def _patch_generated_module_aliases(module: ModuleType) -> None:
for name, maybe_module in vars(module).items():
if not isinstance(maybe_module, ModuleType):
continue
maybe_class = getattr(maybe_module, name, None)
if isinstance(maybe_class, type):
setattr(module, name, maybe_class)


@lru_cache(maxsize=1)
def _patch_generated_flatbuffer_aliases() -> None:
    """Applies _patch_generated_module_aliases to every submodule of the
    generated flatbuffer package. lru_cache ensures the (idempotent) patching
    runs at most once per process."""
    prefix = _generated_fb.__name__
    for submodule_info in pkgutil.iter_modules(_generated_fb.__path__):
        submodule = importlib.import_module(f"{prefix}.{submodule_info.name}")
        _patch_generated_module_aliases(submodule)


def _flatbuffer_dataclass_names(val: Any) -> tuple[str, Optional[str]]:
val_type_name = type(val).__name__
if val_type_name.endswith("T"):
return val_type_name, val_type_name[:-1]
return val_type_name, None


def _matches_dataclass_union_type(
union_type: Any, val_type_name: str, val_dataclass_name: Optional[str]
) -> bool:
if not is_dataclass(union_type):
return False
union_name = union_type.__name__
return union_name == val_type_name or (
val_dataclass_name is not None and union_name == val_dataclass_name
)


def _matches_non_dataclass_union_type(union_type: Any, val: Any) -> bool:
if union_type is Any:
return True
if union_type is str and isinstance(val, (bytes, bytearray, memoryview)):
return True
union_origin = get_origin(union_type)
if union_origin is list and hasattr(val, "__iter__"):
return True
return isinstance(union_type, type) and isinstance(val, union_type)


def _union_choice_from_value(union_types: tuple[Any, ...], val: Any) -> Any:
    """Picks the member of `union_types` that `val` should convert to.

    None matches only an Optional (NoneType) member; otherwise dataclass
    members are matched by name and remaining members by runtime type.
    Returns None when no member matches.
    """
    none_type = type(None)
    if val is None:
        return none_type if none_type in union_types else None

    val_type_name, val_dataclass_name = _flatbuffer_dataclass_names(val)

    for candidate in union_types:
        if candidate is none_type:
            continue
        if _matches_dataclass_union_type(candidate, val_type_name, val_dataclass_name):
            return candidate
        if _matches_non_dataclass_union_type(candidate, val):
            return candidate
    return None


def _convert_from_flatbuffer_value(val: Any, expected_type: Any) -> Any:
    """Recursively converts a value produced by the generated flatbuffer
    object API into the form declared by the schema.py annotation
    `expected_type`.

    Raises:
        TypeError: If `val` matches no member of a Union annotation.
    """
    if val is None:
        return None

    origin = get_origin(expected_type)
    if origin is list:
        # Convert every element against the list's item annotation.
        item_type = get_args(expected_type)[0]
        return [_convert_from_flatbuffer_value(item, item_type) for item in val]

    if origin is Union:
        # Pick the union member matching the runtime value, then convert the
        # value as that member.
        union_type = _union_choice_from_value(get_args(expected_type), val)
        if union_type is None:
            raise TypeError(
                f"Could not match value type {type(val)} to {expected_type}"
            )
        if union_type is type(None):
            return None
        return _convert_from_flatbuffer_value(val, union_type)

    if expected_type is bytes:
        return _coerce_bytes(val)
    if expected_type is str and isinstance(val, (bytes, bytearray, memoryview)):
        # str fields may arrive as bytes-like objects; decode them as UTF-8.
        return _coerce_bytes(val).decode("utf-8")
    if is_dataclass(expected_type):
        return _convert_from_flatbuffer_dataclass(val, expected_type)
    if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum):
        if isinstance(val, expected_type):
            return val
        # Map the raw flatbuffer value onto the enum member.
        return expected_type(val)
    if isinstance(expected_type, type):
        # Plain classes (e.g. int, float, bool): coerce via the constructor.
        return expected_type(val)
    return val


def _convert_from_flatbuffer_dataclass(val: Any, dataclass_type: type) -> Any:
    """Builds an instance of `dataclass_type` from the unpacked flatbuffer
    object `val`, converting each field per its type annotation."""
    hints = _dataclass_type_hints(dataclass_type)
    kwargs = {
        field_name: _convert_from_flatbuffer_value(
            getattr(val, attr_name), hints[field_name]
        )
        for field_name, attr_name in _dataclass_field_map(dataclass_type)
    }
    return dataclass_type(**kwargs)


def _flatbuffer_to_program(program_data: bytes) -> Program:
    """Deserializes binary flatbuffer data into a Program dataclass."""
    # Ensure the generated modules expose classes (not modules) before
    # unpacking, since InitFromPackedBuf instantiates nested types by name.
    _patch_generated_flatbuffer_aliases()
    unpacked = ProgramT.InitFromPackedBuf(program_data)
    return _convert_from_flatbuffer_dataclass(unpacked, Program)


@lru_cache(maxsize=1)
def _get_schema_info(
constant_tensor_alignment: Optional[int], delegate_alignment: Optional[int]
Expand All @@ -213,11 +346,7 @@ def _program_to_flatbuffer(
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
) -> _FlatbufferResult:
"""Converts a Program dataclass into binary flatbuffer data.

Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke
flatc to build the binary.
"""
"""Converts a Program dataclass into binary flatbuffer data."""
schema_info = _get_schema_info(constant_tensor_alignment, delegate_alignment)
_set_pack_alignments(schema_info.tensor_alignment, schema_info.delegate_alignment)
_install_fast_packers()
Expand Down
20 changes: 6 additions & 14 deletions exir/_serialize/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
from executorch.exir._serialize._flatbuffer import (
_FlatbufferResult,
_program_flatbuffer_to_json,
from executorch.exir._serialize._dataclass import _DataclassEncoder
from executorch.exir._serialize._flatbuffer import _FlatbufferResult
from executorch.exir._serialize._flatbuffer_program import (
_flatbuffer_to_program,
_program_to_flatbuffer,
)
from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer
from executorch.exir._serialize._named_data_store import (
NamedDataStore,
NamedDataStoreOutput,
Expand Down Expand Up @@ -86,12 +86,6 @@ def _program_to_json(program: Program) -> str:
return json.dumps(program, cls=_DataclassEncoder)


def _json_to_program(program_json: bytes) -> Program:
    """Returns a Program deserialized from the given JSON string."""
    # Parse to plain dicts/lists, then build the dataclass tree recursively.
    parsed = json.loads(program_json)
    return _json_to_dataclass(parsed, cls=Program)


def _insert_flatbuffer_header(
flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
Expand Down Expand Up @@ -757,9 +751,7 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile:
segment_base_offset = eh.segment_base_offset

# Parse the flatbuffer data.
program: Program = _json_to_program(
_program_flatbuffer_to_json(program_data[:program_size])
)
program: Program = _flatbuffer_to_program(program_data[:program_size])

if segment_base_offset != 0:
# Move segment data back into the Program.
Expand Down
Loading
Loading