From a8d6492cb8de8875262d221b1b5ad953923ad7aa Mon Sep 17 00:00:00 2001 From: Chizkiyahu Raful Date: Mon, 9 Mar 2026 11:31:22 +0200 Subject: [PATCH 1/2] exir: add flatbuffer-to-program reader This continues the work from https://github.com/pytorch/executorch/pull/17333. Change-Id: I35ac4cd5f6430ea89939453344c13e056b5c746c Signed-off-by: Chizkiyahu Raful --- exir/_serialize/_flatbuffer_program.py | 141 +++++++++++++++++- exir/_serialize/_program.py | 12 +- .../test/test_flatbuffer_program.py | 12 +- exir/_serialize/test/test_program.py | 88 ++++++++++- 4 files changed, 231 insertions(+), 22 deletions(-) diff --git a/exir/_serialize/_flatbuffer_program.py b/exir/_serialize/_flatbuffer_program.py index 4c1c315347a..cd742c8361d 100644 --- a/exir/_serialize/_flatbuffer_program.py +++ b/exir/_serialize/_flatbuffer_program.py @@ -8,12 +8,14 @@ import enum import functools import importlib +import pkgutil import tempfile from contextvars import ContextVar from dataclasses import fields, is_dataclass from functools import lru_cache -from typing import Any, Dict, Optional +from types import ModuleType +from typing import Any, Dict, get_args, get_origin, get_type_hints, Optional, Union import flatbuffers # pyre-ignore[21] from executorch.exir._serialize._flatbuffer import ( @@ -22,6 +24,7 @@ _prepare_schema, _SchemaInfo, ) +from executorch.exir._serialize.generated import executorch_flatbuffer as _generated_fb from executorch.exir._serialize.generated.executorch_flatbuffer import ( BackendDelegateInlineData as _BackendDelegateInlineData, Buffer as _Buffer, @@ -33,6 +36,7 @@ _T_CLASS_CACHE: Dict[type, type] = {} _FIELD_NAME_CACHE: Dict[type, tuple[tuple[str, str], ...]] = {} +_TYPE_HINTS_CACHE: Dict[type, Dict[str, Any]] = {} _BUFFER_ALIGNMENT: ContextVar[int] = ContextVar("_BUFFER_ALIGNMENT", default=1) _DELEGATE_ALIGNMENT: ContextVar[int] = ContextVar("_DELEGATE_ALIGNMENT", default=1) @@ -64,6 +68,15 @@ def _dataclass_field_map(dataclass_type: type) -> tuple[tuple[str, str], ...]: return mapping +def _dataclass_type_hints(dataclass_type: type) -> Dict[str, Any]: + cached = _TYPE_HINTS_CACHE.get(dataclass_type) + if cached is not None: + return cached + type_hints = get_type_hints(dataclass_type) + _TYPE_HINTS_CACHE[dataclass_type] = type_hints + return type_hints + + def _create_aligned_byte_vector(builder: Any, data: bytes, alignment: int) -> int: if not _is_valid_alignment(alignment): raise ValueError(f"Bad alignment {alignment}") @@ -194,6 +207,126 @@ def convert_program(val: Program) -> ProgramT: return _convert_dataclass(val) +# The generated FlatBuffer Python modules import child tables/unions as modules +# (for example, Program.ExecutionPlan becomes the ExecutionPlan module), but the +# unpacking helpers later expect those globals to be the corresponding classes. +# Rebind module globals like ExecutionPlan -> ExecutionPlan.ExecutionPlan so the +# generated InitFromObj()/InitFromPackedBuf() code can instantiate nested types. +def _patch_generated_module_aliases(module: ModuleType) -> None: + for name, maybe_module in vars(module).items(): + if not isinstance(maybe_module, ModuleType): + continue + maybe_class = getattr(maybe_module, name, None) + if isinstance(maybe_class, type): + setattr(module, name, maybe_class) + + +@lru_cache(maxsize=1) +def _patch_generated_flatbuffer_aliases() -> None: + package_name = _generated_fb.__name__ + for module_info in pkgutil.iter_modules(_generated_fb.__path__): + module = importlib.import_module(f"{package_name}.{module_info.name}") + _patch_generated_module_aliases(module) + + +def _flatbuffer_dataclass_names(val: Any) -> tuple[str, Optional[str]]: + val_type_name = type(val).__name__ + if val_type_name.endswith("T"): + return val_type_name, val_type_name[:-1] + return val_type_name, None + + +def _matches_dataclass_union_type( + union_type: Any, val_type_name: str, val_dataclass_name: Optional[str] +) -> bool: + if not is_dataclass(union_type): + return False + union_name = union_type.__name__ + return union_name == val_type_name or ( + val_dataclass_name is not None and union_name == val_dataclass_name + ) + + +def _matches_non_dataclass_union_type(union_type: Any, val: Any) -> bool: + if union_type is Any: + return True + if union_type is str and isinstance(val, (bytes, bytearray, memoryview)): + return True + union_origin = get_origin(union_type) + if union_origin is list and hasattr(val, "__iter__"): + return True + return isinstance(union_type, type) and isinstance(val, union_type) + + +def _union_choice_from_value(union_types: tuple[Any, ...], val: Any) -> Any: + if val is None: + for union_type in union_types: + if union_type is type(None): + return union_type + return None + + val_type_name, val_dataclass_name = _flatbuffer_dataclass_names(val) + + for union_type in union_types: + if union_type is type(None): + continue + if _matches_dataclass_union_type(union_type, val_type_name, val_dataclass_name): + return union_type + if _matches_non_dataclass_union_type(union_type, val): + return union_type + return None + + +def _convert_from_flatbuffer_value(val: Any, expected_type: Any) -> Any: + if val is None: + return None + + origin = get_origin(expected_type) + if origin is list: + item_type = get_args(expected_type)[0] + return [_convert_from_flatbuffer_value(item, item_type) for item in val] + + if origin is Union: + union_type = _union_choice_from_value(get_args(expected_type), val) + if union_type is None: + raise TypeError( + f"Could not match value type {type(val)} to {expected_type}" + ) + if union_type is type(None): + return None + return _convert_from_flatbuffer_value(val, union_type) + + if expected_type is bytes: + return _coerce_bytes(val) + if expected_type is str and isinstance(val, (bytes, bytearray, memoryview)): + return _coerce_bytes(val).decode("utf-8") + if is_dataclass(expected_type): + return _convert_from_flatbuffer_dataclass(val, expected_type) + if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum): + if isinstance(val, expected_type): + return val + return expected_type(val) + if isinstance(expected_type, type): + return expected_type(val) + return val + + +def _convert_from_flatbuffer_dataclass(val: Any, dataclass_type: type) -> Any: + result = {} + type_hints = _dataclass_type_hints(dataclass_type) + for src_name, dst_name in _dataclass_field_map(dataclass_type): + result[src_name] = _convert_from_flatbuffer_value( + getattr(val, dst_name), type_hints[src_name] + ) + return dataclass_type(**result) + + +def _flatbuffer_to_program(program_data: bytes) -> Program: + _patch_generated_flatbuffer_aliases() + program_t = ProgramT.InitFromPackedBuf(program_data) + return _convert_from_flatbuffer_dataclass(program_t, Program) + + @lru_cache(maxsize=1) def _get_schema_info( constant_tensor_alignment: Optional[int], delegate_alignment: Optional[int] @@ -213,11 +346,7 @@ def _program_to_flatbuffer( constant_tensor_alignment: Optional[int] = None, delegate_alignment: Optional[int] = None, ) -> _FlatbufferResult: - """Converts a Program dataclass into binary flatbuffer data. - - Unlike _program_json_to_flatbuffer(), this does not use JSON or invoke - flatc to build the binary. - """ + """Converts a Program dataclass into binary flatbuffer data.""" schema_info = _get_schema_info(constant_tensor_alignment, delegate_alignment) _set_pack_alignments(schema_info.tensor_alignment, schema_info.delegate_alignment) _install_fast_packers() diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index c0a4f3b795a..964b56998e6 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -17,11 +17,11 @@ from executorch.exir._serialize._cord import Cord from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass -from executorch.exir._serialize._flatbuffer import ( - _FlatbufferResult, - _program_flatbuffer_to_json, +from executorch.exir._serialize._flatbuffer import _FlatbufferResult +from executorch.exir._serialize._flatbuffer_program import ( + _flatbuffer_to_program, + _program_to_flatbuffer, ) -from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer from executorch.exir._serialize._named_data_store import ( NamedDataStore, NamedDataStoreOutput, @@ -757,9 +757,7 @@ def deserialize_pte_binary(program_data: bytes) -> PTEFile: segment_base_offset = eh.segment_base_offset # Parse the flatbuffer data. - program: Program = _json_to_program( - _program_flatbuffer_to_json(program_data[:program_size]) - ) + program: Program = _flatbuffer_to_program(program_data[:program_size]) if segment_base_offset != 0: # Move segment data back into the Program. diff --git a/exir/_serialize/test/test_flatbuffer_program.py b/exir/_serialize/test/test_flatbuffer_program.py index 05e05d4e610..0ba13842a62 100644 --- a/exir/_serialize/test/test_flatbuffer_program.py +++ b/exir/_serialize/test/test_flatbuffer_program.py @@ -11,7 +11,10 @@ _program_flatbuffer_to_json, _program_json_to_flatbuffer, ) -from executorch.exir._serialize._flatbuffer_program import _program_to_flatbuffer +from executorch.exir._serialize._flatbuffer_program import ( + _flatbuffer_to_program, + _program_to_flatbuffer, +) from executorch.exir._serialize._program import _json_to_program, _program_to_json from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.schema import ( @@ -172,6 +175,13 @@ def test_roundtrip_via_json(self) -> None: program2 = _json_to_program(_program_flatbuffer_to_json(result.data)) self.assertEqual(program2, program) + def test_roundtrip_via_direct_python(self) -> None: + program = self._make_program() + result = _program_to_flatbuffer( + program, constant_tensor_alignment=32, delegate_alignment=64 + ) + self.assertEqual(_flatbuffer_to_program(result.data), program) + def test_flatbuffer_paths_match(self) -> None: program = self._make_program() cases = [ diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index 46e8f020a0b..bb897476ac9 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -1,6 +1,7 @@ #!/usr/bin/env fbpython # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -16,12 +17,11 @@ from typing import Dict, List, Sequence -from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json +from executorch.exir._serialize._flatbuffer_program import _flatbuffer_to_program from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir._serialize._program import ( _ExtendedHeader, _get_extended_header, - _json_to_program, _program_to_json, deserialize_pte_binary, PTEFile, @@ -30,6 +30,8 @@ from executorch.exir._serialize.data_serializer import DataEntry from executorch.exir._serialize.padding import aligned_size +from executorch.exir.backend.compile_spec_schema import CompileSpec + from executorch.exir.schema import ( BackendDelegate, BackendDelegateDataReference, @@ -38,7 +40,15 @@ ContainerMetadata, DataLocation, DataSegment, + Double, + EValue, ExecutionPlan, + Frame, + FrameList, + FreeCall, + Instruction, + JumpFalseCall, + MoveCall, Program, SubsegmentOffsets, ) @@ -195,7 +205,7 @@ def constant_segment_with_tensor_alignment( self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # The constant tensor data should appear as the only segment. self.assertEqual(len(program_with_segments.segments), 1) @@ -465,6 +475,68 @@ def test_round_trip_no_header_no_segments(self) -> None: self.assertEqual(deserialized.mutable_data, None) self.assertEqual(deserialized.named_data, None) + def test_deserialize_pte_binary_with_rich_flatbuffer_types(self) -> None: + program = get_test_program() + plan = program.execution_plan[0] + plan.values.append(EValue(Double(float("inf")))) + plan.delegates.append( + BackendDelegate( + id="delegate0", + processed=BackendDelegateDataReference( + location=DataLocation.INLINE, + index=0, + ), + compile_specs=[CompileSpec(key="k", value=b"v")], + ) + ) + plan.chains[0].instructions.extend( + [ + Instruction(MoveCall(move_from=0, move_to=1)), + Instruction( + JumpFalseCall(cond_value_index=1, destination_instruction=0) + ), + Instruction(FreeCall(value_index=0)), + ] + ) + plan.chains[0].stacktrace = [ + FrameList( + items=[ + Frame( + filename="file.py", + lineno=idx + 1, + name="fn", + context="ctx", + ) + ] + ) + for idx, _ in enumerate(plan.chains[0].instructions) + ] + program.constant_buffer.append(Buffer(storage=b"abcd")) + program.backend_delegate_data.append( + BackendDelegateInlineData(data=b"delegate-data") + ) + + deserialized = deserialize_pte_binary( + bytes(serialize_pte_binary(PTEFile(program=program))) + ) + + self.assert_programs_equal(program, deserialized.program) + self.assertEqual(deserialized.mutable_data, None) + self.assertEqual(deserialized.named_data, None) + self.assertIsInstance(plan.values[-1].val, Double) + self.assertIsInstance( + deserialized.program.execution_plan[0].values[-1].val, + Double, + ) + self.assertEqual( + deserialized.program.execution_plan[0].values[-1].val.double_val, + "inf", + ) + self.assertEqual( + deserialized.program.execution_plan[0].delegates[0].compile_specs[0].value, + b"v", + ) + def test_round_trip_large_buffer_sizes(self) -> None: """Tests that when the non_const_buffer_sizes contains integers overflowing a signed/unsigned 32 bit integer, we can still serialize the @@ -499,7 +571,7 @@ def test_round_trip_no_segments_and_no_header(self) -> None: self.assertIsNone(eh) # Peek inside the flatbuffer data to confirm that there are no segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) self.assertEqual(program_with_segments.segments, []) # Convert back. @@ -565,7 +637,7 @@ def test_round_trip_with_segments(self) -> None: # this also implicity tests the case where we try parsing the entire # file with segment data following it, demonstrating that the extra data # doesn't upset the flatbuffer parsing path. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # The delegate blobs we added to the program should appear as segments. # The one empty blob should have been ignored, hence the `- 1`. @@ -662,7 +734,7 @@ def test_no_constants(self) -> None: self.assertEqual(program.segments, []) # Peek inside the actual flatbuffer data to see the segments. - flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data)) + flatbuffer_program = _flatbuffer_to_program(pte_data) # Constant buffer should be empty. self.assertEqual(len(flatbuffer_program.constant_buffer), 0) @@ -782,7 +854,7 @@ def test_constant_delegate_and_named_data_segments(self) -> None: self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # Segment table should contain a constant segment, the delegate blobs # and a named data segment. @@ -985,7 +1057,7 @@ def test_named_data_segments(self) -> None: self.assertGreater(eh.segment_data_size, 0) # Peek inside the actual flatbuffer data to see the named data segments. - program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) + program_with_segments = _flatbuffer_to_program(pte_data) # pyre-ignore Incompatible parameter type [6] self.assertEqual(len(program_with_segments.named_data), len(pte_named_data)) From 1a41600d023b9c3cda8155be46e9a5e92dd4e452 Mon Sep 17 00:00:00 2001 From: Chizkiyahu Raful Date: Mon, 9 Mar 2026 11:32:54 +0200 Subject: [PATCH 2/2] exir: remove JSON program conversion path Change-Id: Iebb6ff9151b76b352ef5dbb4d9bd23e2e622c326 Signed-off-by: Chizkiyahu Raful --- exir/_serialize/_flatbuffer.py | 104 ------------------ exir/_serialize/_program.py | 8 +- exir/_serialize/test/test_flatbuffer.py | 65 +---------- .../test/test_flatbuffer_program.py | 51 --------- 4 files changed, 2 insertions(+), 226 deletions(-) diff --git a/exir/_serialize/_flatbuffer.py b/exir/_serialize/_flatbuffer.py index 77d0d073907..1f11c5c37c3 100644 --- a/exir/_serialize/_flatbuffer.py +++ b/exir/_serialize/_flatbuffer.py @@ -10,11 +10,8 @@ import importlib.resources import os import re -import shutil import subprocess -import tempfile - from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Sequence @@ -313,104 +310,3 @@ def _flatc_decompile( bin_path, ] ) - - -def _program_json_to_flatbuffer( - program_json: str, - *, - constant_tensor_alignment: Optional[int] = None, - delegate_alignment: Optional[int] = None, -) -> _FlatbufferResult: - """Converts Program-compatible JSON into binary flatbuffer data. - - Args: - program_json: The JSON to convert. Must be compatible with the root - table type of //executorch/schema/program.fbs. - constant_tensor_alignment: If provided, the alignment to use for tensor - data embedded in the output flatbuffer data. If not provided, uses - the alignment in the schema. - delegate_alignment: If provided, the alignment to use for delegate - data embedded in the output flatbuffer data. If not provided, uses - the alignment in the schema. - - Returns: The flatbuffer data and associated metadata. - """ - with tempfile.TemporaryDirectory() as temp_dir: - schema_info = _prepare_schema( - out_dir=temp_dir, - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - file_stem = "data" - json_path = os.path.join(temp_dir, file_stem + ".json") - output_path = os.path.join(temp_dir, file_stem + ".pte") - - with open(json_path, "wb") as json_file: - json_file.write(program_json.encode("ascii")) - - try: - _flatc_compile(temp_dir, schema_info.root_path, json_path) - except Exception as err: - # It's helpful to save the breaking files for debugging. Optionally - # move them out of the auto-deleting temporary directory. Don't do - # this by default because some input files can be many GB in size, - # and these copies won't be auto-deleted. - should_save = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"} - extra_message = "" - if should_save: - try: - saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-") - for f in os.listdir(temp_dir): - shutil.move(src=os.path.join(temp_dir, f), dst=saved_dir) - extra_message += f" Moved input files to '{saved_dir}'." - except Exception as err2: - extra_message += ( - f" (Failed to save input files for debugging: {err2})" - ) - else: - extra_message += ( - f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure." - ) - - raise RuntimeError( - f"Failed to compile {json_path} to {output_path}." + extra_message - ) from err - with open(output_path, "rb") as output_file: - return _FlatbufferResult( - data=output_file.read(), max_alignment=schema_info.max_alignment - ) - - -def _replace_infinity_in_json_file(content: bytes) -> bytes: - """Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs - is used to convert from flatbuffer to JSON. +-inf float values are not - supported by JSON, so we replace them with the string equivalent. When - converting from JSON to python dataclasses, the string is read as a Union - of float and string (see schema.py). - """ - content = re.sub( - rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content - ) - return content - - -def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes: - """Converts binary flatbuffer data into Program-compatible JSON. - - The binary is parsed using the schema in //executorch/schema/program.fbs. - """ - with tempfile.TemporaryDirectory() as temp_dir: - # No need to patch the alignment when reading. "force_align" is only - # used during serialization. - schema_info = _prepare_schema(temp_dir) - file_stem = "data" - bin_path = os.path.join(temp_dir, file_stem + ".bin") - json_path = os.path.join(temp_dir, file_stem + ".json") - - with open(bin_path, "wb") as bin_file: - bin_file.write(program_flatbuffer) - - _flatc_decompile(temp_dir, schema_info.root_path, bin_path) - with open(json_path, "rb") as output_file: - json_data = output_file.read() - return _replace_infinity_in_json_file(json_data) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 964b56998e6..235d27864ca 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -16,7 +16,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Tuple from executorch.exir._serialize._cord import Cord -from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass +from executorch.exir._serialize._dataclass import _DataclassEncoder from executorch.exir._serialize._flatbuffer import _FlatbufferResult from executorch.exir._serialize._flatbuffer_program import ( _flatbuffer_to_program, @@ -86,12 +86,6 @@ def _program_to_json(program: Program) -> str: return json.dumps(program, cls=_DataclassEncoder) -def _json_to_program(program_json: bytes) -> Program: - """Returns a Program deserialized from the given JSON string.""" - # construct program class recursively from dict - return _json_to_dataclass(json.loads(program_json), cls=Program) - - def _insert_flatbuffer_header( flatbuffer_data: bytes, magic_regex: str, header_data: bytes ) -> bytes: diff --git a/exir/_serialize/test/test_flatbuffer.py b/exir/_serialize/test/test_flatbuffer.py index 801ddca112d..e623da55cd2 100644 --- a/exir/_serialize/test/test_flatbuffer.py +++ b/exir/_serialize/test/test_flatbuffer.py @@ -7,19 +7,13 @@ # LICENSE file in the root directory of this source tree. import os -import re -import shutil import tempfile import unittest from typing import Dict, Optional, Sequence from unittest.mock import patch from executorch.exir._serialize import _flatbuffer -from executorch.exir._serialize._flatbuffer import ( - _program_json_to_flatbuffer, - _ResourceFiles, - _SchemaInfo, -) +from executorch.exir._serialize._flatbuffer import _ResourceFiles, _SchemaInfo def read_file(dir: str, filename: str) -> bytes: @@ -277,60 +271,3 @@ def test_bad_delegate_alignment_fails(self) -> None: out_dir, delegate_alignment=bad_alignment, ) - - -class TestProgramJsonToFlatbuffer(unittest.TestCase): - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) - def test_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - match = re.search(r"Moved input files to '(.*?)'", err_msg) - self.assertTrue(match, msg=f"Unexpected error message: {err_msg}") - path = match.group(1) - - files = frozenset(os.listdir(path)) - # Delete the files otherwise they'll accumulate every time the - # test is run. - shutil.rmtree(path) - # Check for a couple of the files that should be there. - self.assertIn("data.json", files) - self.assertIn("program.fbs", files) - - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) - def test_unable_to_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - with patch.object( - _flatbuffer.shutil, - "move", - side_effect=Exception("shutil.move mock failure"), - ): - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - self.assertIn("Failed to save input files", err_msg) - - @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: ""}) - def test_no_save_json_on_failure(self) -> None: - err_msg: Optional[str] = None - try: - _program_json_to_flatbuffer("} some bad json {") - self.fail("Should have raised an exception") - except RuntimeError as err: - err_msg = err.args[0] - - self.assertIsNotNone(err_msg) - self.assertIn( - f"Set {_flatbuffer._SAVE_FLATC_ENV}=1 to save input files", err_msg - ) - self.assertNotIn("Moved input files", err_msg) - self.assertNotIn("Failed to save input files", err_msg) diff --git a/exir/_serialize/test/test_flatbuffer_program.py b/exir/_serialize/test/test_flatbuffer_program.py index 0ba13842a62..4910f9b431f 100644 --- a/exir/_serialize/test/test_flatbuffer_program.py +++ b/exir/_serialize/test/test_flatbuffer_program.py @@ -4,18 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import json import unittest -from executorch.exir._serialize._flatbuffer import ( - _program_flatbuffer_to_json, - _program_json_to_flatbuffer, -) from executorch.exir._serialize._flatbuffer_program import ( _flatbuffer_to_program, _program_to_flatbuffer, ) -from executorch.exir._serialize._program import _json_to_program, _program_to_json from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.schema import ( AllocationDetails, @@ -160,21 +154,6 @@ def _make_program(self) -> Program: named_data=[], ) - def _flatbuffer_to_dict(self, flatbuffer_data: bytes) -> dict: - return json.loads(_program_flatbuffer_to_json(flatbuffer_data)) - - def test_roundtrip_via_json(self) -> None: - program = self._make_program() - result = _program_to_flatbuffer( - program, constant_tensor_alignment=32, delegate_alignment=64 - ) - self.assertGreater(len(result.data), 8) - self.assertEqual(result.data[4:6], b"ET") - self.assertGreaterEqual(result.max_alignment, 64) - - program2 = _json_to_program(_program_flatbuffer_to_json(result.data)) - self.assertEqual(program2, program) - def test_roundtrip_via_direct_python(self) -> None: program = self._make_program() result = _program_to_flatbuffer( @@ -182,36 +161,6 @@ def test_roundtrip_via_direct_python(self) -> None: ) self.assertEqual(_flatbuffer_to_program(result.data), program) - def test_flatbuffer_paths_match(self) -> None: - program = self._make_program() - cases = [ - (None, None), - (32, 64), - ] - for constant_tensor_alignment, delegate_alignment in cases: - with self.subTest( - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ): - result = _program_to_flatbuffer( - program, - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - result2 = _program_json_to_flatbuffer( - _program_to_json(program), - constant_tensor_alignment=constant_tensor_alignment, - delegate_alignment=delegate_alignment, - ) - direct_dict = self._flatbuffer_to_dict(result.data) - json_path_dict = self._flatbuffer_to_dict(result2.data) - self.assertEqual( - direct_dict, - json_path_dict, - "Flatbuffer JSON differs between direct and JSON paths", - ) - self.assertEqual(result.max_alignment, result2.max_alignment) - def test_bad_alignment_fails(self) -> None: program = Program( version=0,