diff --git a/backends/apple/coreml/BUCK b/backends/apple/coreml/BUCK index e97d63e9998..de0b58a4c39 100644 --- a/backends/apple/coreml/BUCK +++ b/backends/apple/coreml/BUCK @@ -64,6 +64,9 @@ runtime.cxx_library( "-Wno-receiver-expr", "-Wno-error", ], + preprocessor_flags = [ + "-DJSON_NOEXCEPTION", + ], define_static_target = True, header_namespace = "backends/apple/coreml", exported_headers = ["runtime/delegate/executorch_operations.h", "runtime/include/coreml_backend/delegate.h"], @@ -89,6 +92,7 @@ runtime.cxx_library( platforms = [APPLE], visibility = ["PUBLIC"], deps = [ + "fbsource//third-party/nlohmann-json:nlohmann-json", "//executorch/runtime/backend:interface", "//executorch/runtime/core:core", "//executorch/runtime/kernel:kernel_includes", diff --git a/backends/apple/coreml/CMakeLists.txt b/backends/apple/coreml/CMakeLists.txt index 17e2d94e336..60767ff8091 100644 --- a/backends/apple/coreml/CMakeLists.txt +++ b/backends/apple/coreml/CMakeLists.txt @@ -220,6 +220,14 @@ if(APPLE) target_link_libraries(coremldelegate PRIVATE libprotobuf-lite) endif() + # Add nlohmann_json include directory (header-only library) Define + # JSON_NOEXCEPTION since coremldelegate is compiled with -fno-exceptions + target_include_directories( + coremldelegate + PRIVATE ${PROJECT_SOURCE_DIR}/third-party/json/single_include + ) + target_compile_definitions(coremldelegate PRIVATE JSON_NOEXCEPTION) + target_link_libraries( coremldelegate PUBLIC coreml_util coreml_inmemoryfs diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index 32cd0df67a2..92a21c32d1c 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -13,7 +13,7 @@ from pathlib import Path -from typing import Any, Dict, final, List, Optional, Tuple +from typing import Any, Dict, final, List, Optional, Tuple, Union import coremltools as ct import coremltools.optimize as cto @@ -44,6 +44,15 @@ class COMPILE_SPEC_KEYS(Enum): OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config" ENUMERATED_SHAPES = "enumerated_shapes" PASS_PIPELINE = "pass_pipeline" + MULTIMETHOD_WEIGHT_SHARING_STRATEGY = "multimethod_weight_sharing_strategy" + + +class MULTIMETHOD_WEIGHT_SHARING_STRATEGY(Enum): + # Methods are processed independently with no weight sharing. + DISABLED = "disabled" + # Partitions must align positionally across methods; enables weight sharing + # via NamedDataStore. Raises an error if partition counts don't match. + POSITIONAL = "positional" class MODEL_PATHS(Enum): @@ -53,16 +62,34 @@ class MODEL_PATHS(Enum): DEBUG_INFO = "debug_info.json" +@dataclass +class MethodMetadata: + # The method input names. + inputNames: List[str] + # The method output names. + outputNames: List[str] + + @dataclass class ModelMetadata: - # The model input names. + # The model input names (for single-method models). inputNames: List[str] - # The model output names. + # The model output names (for single-method models). outputNames: List[str] # The model identifier. identifier: str +@dataclass +class MultifunctionModelMetadata: + # The model identifier. + identifier: str + # Per-method metadata (method name -> MethodMetadata). + methods: Dict[str, MethodMetadata] + # The default method name. + defaultMethod: str + + @dataclass class ModelDebugInfo: # Version info. 
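For reference, save_model_metadata() (further down in this file) serializes these dataclasses with json.dumps(asdict(...)), so the consolidated metadata.json of a multifunction package nests one MethodMetadata entry per ExecuTorch method. A minimal sketch of that shape, assuming the dataclasses can be imported from coreml_preprocess and using illustrative method and tensor names:

    from dataclasses import asdict
    import json

    from executorch.backends.apple.coreml.compiler.coreml_preprocess import (
        MethodMetadata,
        MultifunctionModelMetadata,
    )

    metadata = MultifunctionModelMetadata(
        identifier="executorch_<uuid>",  # "executorch_" + str(uuid.uuid4()) in practice
        methods={
            "forward": asdict(MethodMetadata(inputNames=["tokens"], outputNames=["logits"])),
            "prefill": asdict(MethodMetadata(inputNames=["tokens"], outputNames=["logits"])),
        },
        defaultMethod="forward",
    )
    print(json.dumps(asdict(metadata)))
    # {"identifier": "executorch_<uuid>",
    #  "methods": {"forward": {"inputNames": ["tokens"], "outputNames": ["logits"]},
    #              "prefill": {"inputNames": ["tokens"], "outputNames": ["logits"]}},
    #  "defaultMethod": "forward"}

The same keys (identifier, methods, defaultMethod, inputNames, outputNames) are the ones the runtime-side serde_json.mm parser looks for.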
@@ -248,6 +275,43 @@ def pass_pipeline_from_compile_specs( return ct.PassPipeline.DEFAULT + @staticmethod + def generate_multimethod_weight_sharing_strategy_compile_spec( + strategy: "MULTIMETHOD_WEIGHT_SHARING_STRATEGY", + ) -> CompileSpec: + """ + Returns the compile spec representing the multimethod weight sharing strategy. + + Args: + strategy: The weight sharing strategy to use when combining methods. + POSITIONAL: Partitions must align positionally across methods; enables + weight sharing via NamedDataStore. Raises error if partitions don't align. + DISABLED: Methods are processed independently with no weight sharing. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.MULTIMETHOD_WEIGHT_SHARING_STRATEGY.value, + strategy.value.encode("utf-8"), + ) + + @staticmethod + def multimethod_weight_sharing_strategy_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> "MULTIMETHOD_WEIGHT_SHARING_STRATEGY": + """ + Returns the multimethod weight sharing strategy by parsing the list of compile specs. + Defaults to POSITIONAL if not specified. + """ + for compile_spec in compile_specs: + if ( + compile_spec.key + == COMPILE_SPEC_KEYS.MULTIMETHOD_WEIGHT_SHARING_STRATEGY.value + ): + return MULTIMETHOD_WEIGHT_SHARING_STRATEGY( + compile_spec.value.decode("utf-8") + ) + + return MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + @staticmethod def generate_enumerated_shapes_compile_spec( ep: ExportedProgram, @@ -429,7 +493,10 @@ def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]: ) @staticmethod - def save_model_metadata(model_metadata: ModelMetadata, model_dir_path: Path): + def save_model_metadata( + model_metadata: Union[ModelMetadata, MultifunctionModelMetadata], + model_dir_path: Path, + ): # Store model metadata. model_metadata_path = Path(model_dir_path) / MODEL_PATHS.METADATA.value model_metadata_json = json.dumps(asdict(model_metadata)) @@ -444,6 +511,84 @@ def save_model_debug_info(model_debug_info: ModelDebugInfo, model_dir_path: Path with open(model_debug_info_path, "w") as outfile: outfile.write(model_debug_info_json) + @staticmethod + def _convert_to_mlmodel( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + skip_model_load: bool = True, + ) -> ct.models.MLModel: + """ + Convert an ExportedProgram to a CoreML MLModel. 
+ + Args: + edge_program: The edge program to convert + compile_specs: Compile specs for this conversion + skip_model_load: Whether to skip loading the model (for efficiency) + + Returns: + The converted MLModel + """ + model_compute_precision = ( + CoreMLBackend.model_compute_precision_from_compile_specs(compile_specs) + ) + minimum_deployment_target = ( + CoreMLBackend.min_deployment_target_from_compile_specs(compile_specs) + ) + compute_units = CoreMLBackend.compute_unit_from_compile_specs(compile_specs) + pass_pipeline = CoreMLBackend.pass_pipeline_from_compile_specs(compile_specs) + enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs( + compile_specs + ) + + # If using enumerated shapes, pass inputs explicitly to CoreML's convert() + ct_inputs = None + if enumerated_shapes is not None: + ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes) + + # Check there are not multiple enumerated inputs if iOS is below 18 + if (minimum_deployment_target is None) or ( + minimum_deployment_target < ct.target.iOS18 + ): + n_enumerated_inputs = sum( + 1 + for ct_in in ct_inputs + if isinstance(ct_in.shape, ct.EnumeratedShapes) + ) + if n_enumerated_inputs > 1: + raise ValueError( + f"Your program has {n_enumerated_inputs} enumerated inputs, " + f"but minimum_deployment_target is {minimum_deployment_target}. " + "Multiple enumerated inputs requires iOS18 or later." + ) + + mlmodel = ct.convert( + model=edge_program, + source="pytorch", + convert_to="mlprogram", + pass_pipeline=pass_pipeline, + skip_model_load=skip_model_load, + compute_precision=model_compute_precision, + minimum_deployment_target=minimum_deployment_target, + compute_units=compute_units, + inputs=ct_inputs, + ) + + # Apply quantization if specified + op_linear_quantizer_config = ( + CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs) + ) + if op_linear_quantizer_config is not None: + logger.warning( + "Core ML Backend op_linear_quantizer_config API is experimental" + ) + config = cto.coreml.OptimizationConfig( + global_config=op_linear_quantizer_config, + op_type_configs={"gather": None}, + ) + mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config=config) + + return mlmodel + @staticmethod def preprocess_model( mlmodel: ct.models.MLModel, model_type: MODEL_TYPE @@ -517,72 +662,287 @@ def preprocess( ) -> PreprocessResult: logger.info(f"Edge program: {edge_program}") model_type: CoreMLBackend.MODEL_TYPE = ( - CoreMLBackend.model_type_from_compile_specs( - compile_specs, - ) - ) - model_compute_precision: ct.precision = ( - CoreMLBackend.model_compute_precision_from_compile_specs(compile_specs) + CoreMLBackend.model_type_from_compile_specs(compile_specs) ) - minimum_deployment_target: Optional[ct.target] = ( - CoreMLBackend.min_deployment_target_from_compile_specs(compile_specs) - ) - compute_units: ct.ComputeUnit = CoreMLBackend.compute_unit_from_compile_specs( - compile_specs - ) - op_linear_quantizer_config = ( - CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs) - ) - enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs( - compile_specs - ) - pass_pipeline: ct.PassPipeline = CoreMLBackend.pass_pipeline_from_compile_specs( - compile_specs - ) - - # If using enumerated shapes, we need to pass the inputs to CoreML's convert() function - # explicitly - ct_inputs = None - if enumerated_shapes is not None: - ct_inputs = _get_ct_inputs(edge_program, enumerated_shapes) - - # Check there are not multiple enumerated inputs if iOS is below 18 - if 
(minimum_deployment_target is None) or ( - minimum_deployment_target < ct.target.iOS18 - ): - n_enumerated_inputs = 0 - for ct_in in ct_inputs: - if isinstance(ct_in.shape, ct.EnumeratedShapes): - n_enumerated_inputs += 1 - if n_enumerated_inputs > 1: - raise ValueError( - f"You're program has {n_enumerated_inputs}, but the minimum_deployment_target is set to {minimum_deployment_target}. Multiple enumerated inputs requires iOS18 or later." - ) # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because # get_compiled_model_path() requires a loaded model. skip_model_load = model_type != CoreMLBackend.MODEL_TYPE.COMPILED_MODEL - mlmodel = ct.convert( - model=edge_program, - source="pytorch", - convert_to="mlprogram", - pass_pipeline=pass_pipeline, - skip_model_load=skip_model_load, - compute_precision=model_compute_precision, - minimum_deployment_target=minimum_deployment_target, - compute_units=compute_units, - inputs=ct_inputs, + + mlmodel = CoreMLBackend._convert_to_mlmodel( + edge_program, compile_specs, skip_model_load=skip_model_load ) - if op_linear_quantizer_config is not None: - logger.warning( - "Core ML Backend op_linear_quantizer_config API is experimental" + return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type) + + @classmethod + def preprocess_multimethod( # noqa: C901 + cls, + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, List[PreprocessResult]]: + """ + Preprocess multiple methods, optionally combining them into CoreML multifunction models. + + The behavior is controlled by the MULTIMETHOD_WEIGHT_SHARING_STRATEGY compile spec: + + POSITIONAL (default): + Converts each method's ExportedPrograms to mlpackages, then combines + corresponding partitions across methods using CoreML's multifunction API + (ct.utils.save_multifunction). This enables weight sharing on disk between + methods (e.g., decode and prefill for LLMs). + + For each partition index, we create one multifunction model that combines + that partition from all methods. This requires all methods to have the same + number of partitions. Raises ValueError if partition counts don't match. + + To avoid duplication, we store the combined model ONCE in NamedDataStore + with a unique key. Each method's processed_bytes contains a JSON reference + to the model in NamedDataStore. + + DISABLED: + Each method is processed independently with no weight sharing. Falls back + to the default BackendDetails.preprocess_multimethod() implementation. + + Args: + edge_programs: Dictionary mapping method name to list of partitioned ExportedPrograms + compile_specs: Dictionary mapping method name to list of CompileSpecs for each partition. + The MULTIMETHOD_WEIGHT_SHARING_STRATEGY is read from the first method's first + partition compile specs. + + Returns: + Dictionary mapping method name to list of PreprocessResults. When using POSITIONAL + strategy, each method's processed_bytes contains a JSON reference to the shared + model in NamedDataStore. 
+ """ + from executorch.exir._serialize._named_data_store import NamedDataStore + + method_names = list(edge_programs.keys()) + + if len(method_names) <= 1: + # Fall back to default implementation for single method + return super().preprocess_multimethod(edge_programs, compile_specs) + + # Get compile specs from the first method's first partition + first_method = method_names[0] + first_compile_specs = compile_specs[first_method][0] + + # Check the weight sharing strategy + weight_sharing_strategy = ( + cls.multimethod_weight_sharing_strategy_from_compile_specs( + first_compile_specs ) - config = cto.coreml.OptimizationConfig( - global_config=op_linear_quantizer_config, - # skip embedding - op_type_configs={"gather": None}, + ) + + if weight_sharing_strategy == MULTIMETHOD_WEIGHT_SHARING_STRATEGY.DISABLED: + # Process each method independently with no weight sharing + logger.info( + "Multimethod weight sharing is DISABLED. Processing methods independently." ) - mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config=config) + return super().preprocess_multimethod(edge_programs, compile_specs) + + assert weight_sharing_strategy == MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + + # POSITIONAL strategy: verify all methods have the same number of partitions + num_partitions = len(edge_programs[first_method]) + for method_name, programs in edge_programs.items(): + if len(programs) != num_partitions: + raise ValueError( + f"Method '{method_name}' has {len(programs)} partitions, but " + f"'{first_method}' has {num_partitions}. POSITIONAL weight sharing " + "strategy requires all methods to have the same number of partitions. " + "Use MULTIMETHOD_WEIGHT_SHARING_STRATEGY.DISABLED if methods should " + "be processed independently." + ) - return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type) + model_type: CoreMLBackend.MODEL_TYPE = cls.model_type_from_compile_specs( + first_compile_specs + ) + + # Create a temporary directory for all the mlpackages + temp_dir = Path(tempfile.mkdtemp()) + + # Structure: method_mlpackage_paths[method_name][partition_idx] = path + method_mlpackage_paths: Dict[str, List[Path]] = { + method_name: [] for method_name in method_names + } + + # Create a NamedDataStore to hold the shared multifunction models + named_data_store = NamedDataStore() + + try: + # Convert each method's partitions to mlpackages + for method_name in method_names: + for partition_idx, edge_program in enumerate( + edge_programs[method_name] + ): + method_compile_specs = compile_specs[method_name][partition_idx] + + logger.info( + f"Converting method '{method_name}' partition {partition_idx} to mlpackage..." + ) + + # Convert to CoreML using shared helper + mlmodel = cls._convert_to_mlmodel( + edge_program, method_compile_specs, skip_model_load=True + ) + + # Save the mlpackage + mlpackage_path = ( + temp_dir / f"{method_name}_partition_{partition_idx}.mlpackage" + ) + mlmodel.save(str(mlpackage_path)) + method_mlpackage_paths[method_name].append(mlpackage_path) + + logger.info( + f"Saved method '{method_name}' partition {partition_idx} to {mlpackage_path}" + ) + + # For each partition index, combine that partition from all methods + # into a single multifunction model. 
+ # Store combined_processed_bytes[partition_idx] = bytes (for first method) + combined_processed_bytes: List[bytes] = [] + debug_handle_maps: List[Optional[Dict[str, Tuple[int]]]] = [] + model_keys: List[str] = [] # Keys for NamedDataStore lookup + + for partition_idx in range(num_partitions): + logger.info( + f"Combining partition {partition_idx} from all methods into multifunction model..." + ) + + desc = ct.utils.MultiFunctionDescriptor() + for method_name in method_names: + mlpackage_path = method_mlpackage_paths[method_name][partition_idx] + desc.add_function( + str(mlpackage_path), + src_function_name="main", + target_function_name=method_name, + ) + + # Set the first method as default + desc.default_function_name = first_method + + # Save the combined multifunction model for this partition + combined_path = ( + temp_dir / f"combined_partition_{partition_idx}.mlpackage" + ) + ct.utils.save_multifunction(desc, str(combined_path)) + + logger.info( + f"Saved combined multifunction model for partition {partition_idx} to {combined_path}" + ) + + # Create output directory for this partition's combined model + model_dir_path = temp_dir / f"lowered_module_partition_{partition_idx}" + model_dir_path.mkdir(exist_ok=True) + + # Handle model type (compiled vs mlpackage) + if model_type == CoreMLBackend.MODEL_TYPE.COMPILED_MODEL: + output_model_path = ( + model_dir_path / MODEL_PATHS.COMPILED_MODEL.value + ) + combined_model_loaded = ct.models.MLModel(str(combined_path)) + compiled_path = combined_model_loaded.get_compiled_model_path() + shutil.move(compiled_path, str(output_model_path)) + else: + output_model_path = model_dir_path / MODEL_PATHS.MODEL.value + shutil.copytree(str(combined_path), str(output_model_path)) + + # For multifunction models, we store all method metadata in a single file + # with method names as keys. Each method can have different input/output + # names (e.g., masks_1023 vs masks_992). + identifier = "executorch_" + str(uuid.uuid4()) + + # Extract metadata for each method + methods_metadata: Dict[str, MethodMetadata] = {} + for method_name in method_names: + method_mlpackage_path = method_mlpackage_paths[method_name][ + partition_idx + ] + method_model = ct.models.MLModel( + str(method_mlpackage_path), skip_model_load=True + ) + method_spec = method_model.get_spec() + input_names = [inp.name for inp in method_spec.description.input] + output_names = [out.name for out in method_spec.description.output] + methods_metadata[method_name] = MethodMetadata( + inputNames=input_names, + outputNames=output_names, + ) + logger.info( + f"Extracted metadata for method '{method_name}' partition {partition_idx}: " + f"{len(input_names)} inputs, {len(output_names)} outputs" + ) + + # Create consolidated multifunction metadata + multifunction_metadata = MultifunctionModelMetadata( + identifier=identifier, + methods={k: asdict(v) for k, v in methods_metadata.items()}, + defaultMethod=first_method, + ) + + # Save consolidated metadata + cls.save_model_metadata(multifunction_metadata, model_dir_path) + + # Note: Debug info is not supported for multifunction models. + # The combined model's debug mapping doesn't accurately map back to + # individual methods, so we skip it rather than provide incorrect info. 
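                # Downstream of the flatten step below, the flattened bytes are stored once in the
                # NamedDataStore under the key f"coreml_{identifier}", and each method's
                # processed_bytes carries only a small JSON reference of the form
                # (uuid and method name illustrative):
                #   {"version": 1, "key": "coreml_executorch_<uuid>", "method": "<method name>"}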
+ + # Flatten the model directory (with model + metadata) to bytes + processed_bytes = ( + executorchcoreml.flatten_directory_contents( + str(model_dir_path.resolve()) + ) + or b"" + ) + combined_processed_bytes.append(processed_bytes) + + # Store in NamedDataStore and save the key for later reference + model_key = f"coreml_{identifier}" + model_keys.append(model_key) + named_data_store.add_named_data(model_key, processed_bytes) + + logger.info( + f"Created combined processed bytes for partition {partition_idx} ({len(processed_bytes)} bytes)" + ) + logger.info(f"Stored in NamedDataStore with key '{model_key}'") + + # Debug handle map is not supported for multifunction models + debug_handle_maps.append(None) + + # Get the NamedDataStoreOutput to share across PreprocessResults + named_data_store_output = named_data_store.get_named_data_store_output() + + # Build PreprocessResults for each method and partition. + # All methods get a JSON reference to the model in NamedDataStore. + # The model is stored ONLY in NamedDataStore to avoid duplication. + # Runtime will detect the JSON reference and load from NamedDataMap. + preprocess_results: Dict[str, List[PreprocessResult]] = { + method_name: [] for method_name in method_names + } + + for partition_idx in range(num_partitions): + debug_handle_map = debug_handle_maps[partition_idx] + + for method_name in method_names: + # Create JSON reference for runtime to look up model in NamedDataStore + reference = { + "version": 1, + "key": model_keys[partition_idx], + "method": method_name, + } + reference_bytes = json.dumps(reference).encode("utf-8") + + preprocess_results[method_name].append( + PreprocessResult( + processed_bytes=reference_bytes, + debug_handle_map=debug_handle_map, + data_store_output=named_data_store_output, + ) + ) + + return preprocess_results + + finally: + # Clean up temporary directory + shutil.rmtree(str(temp_dir)) diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModel.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModel.mm index 4201293d1c5..5d06817d1ec 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModel.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModel.mm @@ -204,8 +204,26 @@ - (nullable instancetype)initWithAsset:(ETCoreMLAsset *)asset if (self) { _mlModel = mlModel; _asset = asset; - _orderedInputNames = [orderedInputNames copy]; - _orderedOutputNames = [orderedOutputNames copy]; + + // Use provided ordered names, or derive from model description as fallback + if (orderedInputNames != nil) { + _orderedInputNames = [orderedInputNames copy]; + } else { + // Derive input names from the model's description in sorted order for determinism + NSArray *inputKeys = mlModel.modelDescription.inputDescriptionsByName.allKeys; + NSArray *sortedInputKeys = [inputKeys sortedArrayUsingSelector:@selector(compare:)]; + _orderedInputNames = [NSMutableOrderedSet orderedSetWithArray:sortedInputKeys]; + } + + if (orderedOutputNames != nil) { + _orderedOutputNames = [orderedOutputNames copy]; + } else { + // Derive output names from the model's description in sorted order for determinism + NSArray *outputKeys = mlModel.modelDescription.outputDescriptionsByName.allKeys; + NSArray *sortedOutputKeys = [outputKeys sortedArrayUsingSelector:@selector(compare:)]; + _orderedOutputNames = [NSMutableOrderedSet orderedSetWithArray:sortedOutputKeys]; + } + _cache = [[NSCache alloc] init]; _inputConstraintsByName = get_multi_array_input_constraints_by_name(mlModel.modelDescription); _outputConstraintsByName = 
get_multi_array_output_constraints_by_name(mlModel.modelDescription); @@ -234,6 +252,15 @@ - (NSString *)identifier { BOOL lCopyData = copyData; NSString *argName = [nameEnumerator nextObject]; MLMultiArrayConstraint *constraint = argConstraintsByName[argName]; + + if (constraint == nil) { + ETCoreMLLogErrorAndSetNSError(error, + ETCoreMLErrorCorruptedModel, + "No constraint found for arg '%@'. Model may have mismatched input/output names.", + argName); + return nil; + } + const auto& layout = arg.layout(); auto dataType = to_ml_multiarray_data_type(layout.dataType()); MLMultiArray *multiArrayArg = nil; diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm index 731b8506f31..3dee02e9a9a 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelLoader.mm @@ -26,20 +26,6 @@ return result; } - - ETCoreMLModel * _Nullable get_model_from_asset(ETCoreMLAsset *asset, - MLModelConfiguration *configuration, - const executorchcoreml::ModelMetadata& metadata, - NSError * __autoreleasing *error) { - NSOrderedSet *orderedInputNames = ::get_ordered_set(metadata.input_names); - NSOrderedSet *orderedOutputNames = ::get_ordered_set(metadata.output_names); - ETCoreMLModel *model = [[ETCoreMLModel alloc] initWithAsset:asset - configuration:configuration - orderedInputNames:orderedInputNames - orderedOutputNames:orderedOutputNames - error:error]; - return model; - } } // namespace @implementation ETCoreMLModelLoader @@ -48,15 +34,22 @@ + (nullable ETCoreMLModel *)loadModelWithCompiledAsset:(ETCoreMLAsset *)compiled configuration:(MLModelConfiguration *)configuration metadata:(const executorchcoreml::ModelMetadata&)metadata error:(NSError * __autoreleasing *)error { - NSError *localError = nil; - ETCoreMLModel *model = (compiledAsset != nil) ? get_model_from_asset(compiledAsset, configuration, metadata, &localError) : nil; - if (model) { - return model; + if (compiledAsset == nil) { + return nil; } - if (error) { - *error = localError; - } - return nil; + + // Use the metadata's ordered input/output names. + // For multifunction models, the caller should load the per-method metadata + // which contains the correct input/output names for that method. + NSOrderedSet *orderedInputNames = ::get_ordered_set(metadata.input_names); + NSOrderedSet *orderedOutputNames = ::get_ordered_set(metadata.output_names); + + ETCoreMLModel *model = [[ETCoreMLModel alloc] initWithAsset:compiledAsset + configuration:configuration + orderedInputNames:orderedInputNames + orderedOutputNames:orderedOutputNames + error:error]; + return model; } diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.h b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.h index 9a9d45a037a..080ea7bb6e1 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.h +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.h @@ -50,6 +50,23 @@ __attribute__((objc_subclassing_restricted)) configuration:(MLModelConfiguration*)configuration error:(NSError* __autoreleasing*)error; +/// Loads the model from the AOT data with an optional method name for cache differentiation. +/// +/// The data is the AOT blob stored in the executorch Program. The method first parses the model +/// metadata stored in the blob and extracts the identifier. 
If a methodName is provided, it is +/// appended to the identifier to create separate cache entries for different ExecuTorch methods +/// that may share the same underlying partition but have different input shapes. +/// +/// @param data The AOT blob data. +/// @param configuration The model configuration that will be used to load the model. +/// @param methodName Optional method name (e.g., "forward", "prefill") for cache key differentiation. +/// @param error On failure, error is filled with the failure information. +/// @retval An opaque handle that points to the loaded model. +- (ModelHandle*)loadModelFromAOTData:(NSData*)data + configuration:(MLModelConfiguration*)configuration + methodName:(nullable NSString*)methodName + error:(NSError* __autoreleasing*)error; + /// Executes the loaded model. /// /// @param handle The handle to the loaded model. diff --git a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm index d59890ee00f..14077e8594e 100644 --- a/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm +++ b/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm @@ -211,6 +211,39 @@ void set_outputs(std::vector& outputs, return std::nullopt; } +std::optional get_model_metadata_for_method(const inmemoryfs::InMemoryFileSystem *inMemoryFS, + NSString *methodName) { + // Load the metadata.json file + auto metadata_opt = get_model_metadata(inMemoryFS); + if (!metadata_opt.has_value()) { + return std::nullopt; + } + + ModelMetadata& metadata = metadata_opt.value(); + + // If this is a multifunction model and a method name is provided, + // populate the top-level input_names/output_names from the method's metadata + if (metadata.is_multifunction() && methodName != nil && methodName.length > 0) { + std::string method_name_str = [methodName UTF8String]; + const MethodMetadata* method_metadata = metadata.get_method_metadata(method_name_str); + if (method_metadata != nullptr) { + metadata.input_names = method_metadata->input_names; + metadata.output_names = method_metadata->output_names; + } else { + // Method not found - fall back to default method if available + if (!metadata.default_method.empty()) { + const MethodMetadata* default_metadata = metadata.get_method_metadata(metadata.default_method); + if (default_metadata != nullptr) { + metadata.input_names = default_metadata->input_names; + metadata.output_names = default_metadata->output_names; + } + } + } + } + + return metadata; +} + NSOrderedSet *get_ordered_set(const std::vector& values) { NSMutableOrderedSet *result = [NSMutableOrderedSet orderedSetWithCapacity:values.size()]; for (const auto& value : values) { @@ -285,8 +318,13 @@ void set_outputs(std::vector& outputs, ETCoreMLModel * _Nullable get_model_from_asset(ETCoreMLAsset *asset, MLModelConfiguration *configuration, - const ModelMetadata& metadata, + const executorchcoreml::ModelMetadata& metadata, NSError * __autoreleasing *error) { + // Always use the metadata's ordered input/output names for consistency. + // The pytree flatten order during export determines the correct input order, + // and metadata captures this order. + // For multifunction models, all functions share the same input/output names + // (they differ only in shapes, which are handled by multiArrayConstraint). 
NSOrderedSet *orderedInputNames = ::get_ordered_set(metadata.input_names); NSOrderedSet *orderedOutputNames = ::get_ordered_set(metadata.output_names); ETCoreMLModel *model = [[ETCoreMLModel alloc] initWithAsset:asset @@ -322,6 +360,29 @@ void add_compute_unit(std::string& identifier, MLComputeUnits compute_units) { identifier.append(to_string(compute_units)); } +void add_function_name(std::string& identifier, MLModelConfiguration *configuration) { + // NOTE: For multifunction CoreML models, we intentionally do NOT include the + // function name in the cache key. The multifunction model should be compiled + // only once since it contains ALL functions. The functionName setting on + // MLModelConfiguration determines which function is invoked at runtime when + // creating the MLModel from the cached compiled files. + // + // Previously this added "_func_{name}" to the identifier, which caused + // redundant compilations (once per function). Now we compile once and reuse. + (void)identifier; + (void)configuration; +} + +void add_method_name(std::string& identifier, NSString *methodName) { + // NOTE: For multifunction CoreML models, we intentionally do NOT include the + // method name in the cache key. The multifunction model should be compiled + // only once and shared across all methods/functions. The functionName setting + // on MLModelConfiguration determines which function is invoked at runtime, + // but the compiled model is the same for all functions. + (void)identifier; + (void)methodName; +} + #if ET_EVENT_TRACER_ENABLED ETCoreMLAsset * _Nullable make_asset(NSURL *url, NSString *identifier, @@ -612,8 +673,9 @@ - (nullable ETCoreMLAsset *)modelAssetWithMetadata:(const ModelMetadata&)metadat - (nullable id)_modelExecutorWithAOTData:(NSData *)data - configuration:(MLModelConfiguration *)configuration - error:(NSError * __autoreleasing *)error { + configuration:(MLModelConfiguration *)configuration + methodName:(nullable NSString *)methodName + error:(NSError * __autoreleasing *)error { using namespace inmemoryfs; auto buffer = MemoryBuffer::make_unowned(const_cast(data.bytes), data.length); @@ -625,7 +687,9 @@ - (nullable ETCoreMLAsset *)modelAssetWithMetadata:(const ModelMetadata&)metadat return nil; } - std::optional metadata = ::get_model_metadata(inMemoryFS.get()); + // For multifunction models, try to load method-specific metadata first. + // This ensures we get the correct input/output names for this method. + std::optional metadata = ::get_model_metadata_for_method(inMemoryFS.get(), methodName); if (!metadata) { ETCoreMLLogErrorAndSetNSError(error, ETCoreMLErrorCorruptedMetadata, @@ -634,7 +698,32 @@ - (nullable ETCoreMLAsset *)modelAssetWithMetadata:(const ModelMetadata&)metadat } auto metadataValue = metadata.value(); + + // For multifunction CoreML models (ML Programs with multiple functions), + // we need to set functionName to select the correct function within the model. + // However, legacy single-function models require functionName to be nil. + // The metadata's "methods" field indicates if this is a multifunction model. 
+ if (metadataValue.is_multifunction() && methodName != nil) { +#if defined(__IPHONE_18_0) || defined(__MAC_15_0) || defined(__TVOS_18_0) || defined(__WATCHOS_11_0) + if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) { + configuration.functionName = methodName; + } else { + ETCoreMLLogErrorAndSetNSError(error, + ETCoreMLErrorCorruptedModel, + "Multifunction CoreML models require iOS 18.0+ / macOS 15.0+."); + return nil; + } +#else + ETCoreMLLogErrorAndSetNSError(error, + ETCoreMLErrorCorruptedModel, + "Multifunction CoreML models require iOS 18.0+ / macOS 15.0+ SDK to build."); + return nil; +#endif + } + add_compute_unit(metadataValue.identifier, configuration.computeUnits); + add_function_name(metadataValue.identifier, configuration); + add_method_name(metadataValue.identifier, methodName); NSString *identifier = @(metadataValue.identifier.c_str()); // If there are multiple calls to load the same model, we only want to compile it once. __block id executor = nil; @@ -665,8 +754,19 @@ - (dispatch_queue_t)queueForLoadingModelWithIdentifier:(NSString *)identifier { - (ModelHandle *)loadModelFromAOTData:(NSData*)data configuration:(MLModelConfiguration*)configuration error:(NSError* __autoreleasing*)error { + return [self loadModelFromAOTData:data + configuration:configuration + methodName:nil + error:error]; +} + +- (ModelHandle *)loadModelFromAOTData:(NSData*)data + configuration:(MLModelConfiguration*)configuration + methodName:(nullable NSString*)methodName + error:(NSError* __autoreleasing*)error { id executor = [self _modelExecutorWithAOTData:data configuration:configuration + methodName:methodName error:error]; { os_unfair_lock_lock(&_lock); diff --git a/backends/apple/coreml/runtime/delegate/backend_delegate.h b/backends/apple/coreml/runtime/delegate/backend_delegate.h index 93c420e11d2..bd1c3da52a0 100644 --- a/backends/apple/coreml/runtime/delegate/backend_delegate.h +++ b/backends/apple/coreml/runtime/delegate/backend_delegate.h @@ -72,9 +72,12 @@ class BackendDelegate { /// /// @param processed The AOT blob. /// @param specs The specs at the time of compilation. + /// @param method_name The method name for multifunction model support (optional, may be nullptr). /// @retval An opaque handle to the initialized blob or `nullptr` if the /// initialization failed. - virtual Handle* init(Buffer processed, const std::unordered_map& specs) const noexcept = 0; + virtual Handle* init(Buffer processed, + const std::unordered_map& specs, + const char* method_name = nullptr) const noexcept = 0; /// Must execute the CoreML model with the specified handle. 
/// diff --git a/backends/apple/coreml/runtime/delegate/backend_delegate.mm b/backends/apple/coreml/runtime/delegate/backend_delegate.mm index 680c5c63143..598e4c96a75 100644 --- a/backends/apple/coreml/runtime/delegate/backend_delegate.mm +++ b/backends/apple/coreml/runtime/delegate/backend_delegate.mm @@ -33,10 +33,11 @@ MLComputeUnits get_compute_units(const Buffer& buffer) { } MLModelConfiguration *get_model_configuration(const std::unordered_map& specs) { - std::string key_name(ETCoreMLStrings.computeUnitsKeyName.UTF8String); + std::string compute_units_key(ETCoreMLStrings.computeUnitsKeyName.UTF8String); MLModelConfiguration *configuration = [[MLModelConfiguration alloc] init]; + for (const auto& [key, buffer] : specs) { - if (key == key_name) { + if (key == compute_units_key) { configuration.computeUnits = get_compute_units(buffer); break; } @@ -76,6 +77,11 @@ - (BOOL)loadAndReturnError:(NSError * _Nullable __autoreleasing *)error; - (void)loadAsynchronously; +- (ModelHandle*)loadModelFromAOTData:(NSData*)data + configuration:(MLModelConfiguration*)configuration + methodName:(nullable NSString*)methodName + error:(NSError* __autoreleasing*)error; + - (ModelHandle*)loadModelFromAOTData:(NSData*)data configuration:(MLModelConfiguration*)configuration error:(NSError* __autoreleasing*)error; @@ -161,12 +167,23 @@ - (void)loadAsynchronously { - (ModelHandle*)loadModelFromAOTData:(NSData*)data configuration:(MLModelConfiguration*)configuration error:(NSError* __autoreleasing*)error { + return [self loadModelFromAOTData:data + configuration:configuration + methodName:nil + error:error]; +} + +- (ModelHandle*)loadModelFromAOTData:(NSData*)data + configuration:(MLModelConfiguration*)configuration + methodName:(nullable NSString*)methodName + error:(NSError* __autoreleasing*)error { if (![self loadAndReturnError:error]) { return nil; } auto handle = [self.impl loadModelFromAOTData:data configuration:configuration + methodName:methodName error:error]; if ((handle != NULL) && self.config.should_prewarm_model) { [self.impl prewarmModelWithHandle:handle error:nil]; @@ -250,14 +267,24 @@ explicit BackendDelegateImpl(const Config& config) noexcept BackendDelegateImpl(BackendDelegateImpl const&) = delete; BackendDelegateImpl& operator=(BackendDelegateImpl const&) = delete; - Handle *init(Buffer processed,const std::unordered_map& specs) const noexcept override { +Handle *init(Buffer processed, + const std::unordered_map& specs, + const char* method_name = nullptr) const noexcept override { NSError *localError = nil; MLModelConfiguration *configuration = get_model_configuration(specs); + + NSString *methodNameStr = method_name ? @(method_name) : nil; + + // Note: For multifunction CoreML models, functionName is set in + // ETCoreMLModelManager::loadModelFromAOTData based on metadata.is_multifunction(). + // Legacy single-function models require functionName to remain nil. 
+ NSData *data = [NSData dataWithBytesNoCopy:const_cast(processed.data()) length:processed.size() freeWhenDone:NO]; ModelHandle *modelHandle = [model_manager_ loadModelFromAOTData:data configuration:configuration + methodName:methodNameStr error:&localError]; if (localError != nil) { ETCoreMLLogError(localError, "Model init failed"); diff --git a/backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm b/backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm index 04a95e8a5a3..e143ded75c6 100644 --- a/backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm +++ b/backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm @@ -25,6 +25,8 @@ #import #import +#include + #ifdef ET_EVENT_TRACER_ENABLED #import #endif @@ -50,6 +52,57 @@ using executorch::aten::Tensor; using executorch::runtime::kTensorDimensionLimit; +using json = nlohmann::json; + +// Format identifier for JSON reference to NamedDataStore +constexpr const char* kNamedDataReferenceKeyField = "key"; + +/// Checks if the processed bytes represent a JSON reference to NamedDataStore. +/// The JSON format is: {"version": 1, "key": "...", "method": "..."} +/// +/// @param data Pointer to the processed bytes. +/// @param size Size of the processed bytes. +/// @return true if the bytes appear to be a JSON reference, false otherwise. +bool isNamedDataReference(const void* data, size_t size) { + // Quick check: JSON starts with '{' and should be small (< 512 bytes) + if (size < 2 || size > 512 || static_cast(data)[0] != '{') { + return false; + } + + // Try to parse as JSON and check for required fields + std::string_view content(static_cast(data), size); + json j = json::parse(content, nullptr, false); // false = don't throw on error + + if (j.is_discarded()) { + return false; + } + + // Check for required fields: "version" and "key" + return j.contains("version") && j.contains("key") && j["key"].is_string(); +} + +/// Parses the JSON reference and extracts the NamedDataStore key. +/// Expected format: {"version": 1, "key": "...", "method": "..."} +/// +/// @param data Pointer to the JSON bytes. +/// @param size Size of the JSON bytes. +/// @return The extracted key, or empty string if parsing fails. 
+std::string parseNamedDataKey(const void* data, size_t size) { + std::string_view content(static_cast(data), size); + json j = json::parse(content, nullptr, false); // false = don't throw on error + + if (j.is_discarded()) { + ET_LOG(Error, "Failed to parse JSON reference"); + return ""; + } + + if (j.contains("key") && j["key"].is_string()) { + return j["key"].get(); + } + + return ""; +} + std::optional get_data_type(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Bool: @@ -186,9 +239,54 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) { specs_map.emplace(spec.key, std::move(buffer)); } - auto buffer = Buffer(processed->data(), processed->size()); + Buffer buffer(nullptr, 0); // Will be set below + + // Check if processed bytes is a JSON reference to NamedDataStore + if (isNamedDataReference(processed->data(), processed->size())) { + // Parse the key from the JSON reference + std::string key = parseNamedDataKey(processed->data(), processed->size()); + ET_CHECK_OR_RETURN_ERROR(!key.empty(), + InvalidProgram, + "%s: Failed to parse NamedDataStore key from JSON reference.", + ETCoreMLStrings.delegateIdentifier.UTF8String); + + ET_LOG(Debug, "%s: Loading model from NamedDataStore with key: %s", + ETCoreMLStrings.delegateIdentifier.UTF8String, key.c_str()); + + // Get the NamedDataMap from context + const auto* named_data_map = context.get_named_data_map(); + ET_CHECK_OR_RETURN_ERROR(named_data_map != nullptr, + InvalidProgram, + "%s: NamedDataMap is null but processed bytes is a JSON reference.", + ETCoreMLStrings.delegateIdentifier.UTF8String); + + // Load the model data from NamedDataMap + auto result = named_data_map->get_data(key.c_str()); + ET_CHECK_OR_RETURN_ERROR(result.ok(), + InvalidProgram, + "%s: Failed to load model data from NamedDataStore with key: %s", + ETCoreMLStrings.delegateIdentifier.UTF8String, key.c_str()); + + // Move the result into the incoming FreeableBuffer so its lifetime matches `processed` + processed->~FreeableBuffer(); + new (processed) FreeableBuffer(std::move(result.get())); + buffer = Buffer(processed->data(), processed->size()); + + ET_LOG(Debug, "%s: Loaded %zu bytes from NamedDataStore", + ETCoreMLStrings.delegateIdentifier.UTF8String, processed->size()); + } else { + // Legacy path: use processed bytes directly + buffer = Buffer(processed->data(), processed->size()); + } + + // Get method name for multifunction model support + const char* method_name = context.get_method_name(); + if (method_name != nullptr) { + ET_LOG(Debug, "%s: Method name: %s", ETCoreMLStrings.delegateIdentifier.UTF8String, method_name); + } + std::error_code error; - auto handle = impl_->init(std::move(buffer), specs_map); + auto handle = impl_->init(std::move(buffer), specs_map, method_name); ET_CHECK_OR_RETURN_ERROR(handle != nullptr, InvalidProgram, "%s: Failed to init the model.", ETCoreMLStrings.delegateIdentifier.UTF8String); diff --git a/backends/apple/coreml/runtime/delegate/model_metadata.h b/backends/apple/coreml/runtime/delegate/model_metadata.h index 6b0f0807f9c..fec26f46648 100644 --- a/backends/apple/coreml/runtime/delegate/model_metadata.h +++ b/backends/apple/coreml/runtime/delegate/model_metadata.h @@ -7,6 +7,7 @@ #pragma once +#import #import #import @@ -14,9 +15,25 @@ namespace executorchcoreml { +/// A struct representing per-method metadata (for multifunction models). +struct MethodMetadata { + /// Constructs a `MethodMetadata` instance. + /// @param input_names The input names for the method. 
+ /// @param output_names The output names for the method. + inline MethodMetadata(std::vector input_names, std::vector output_names) noexcept + : input_names(std::move(input_names)), output_names(std::move(output_names)) { } + + inline MethodMetadata() noexcept { } + + /// Input names of the method. + std::vector input_names; + /// Output names of the method. + std::vector output_names; +}; + /// A struct representing a model's metadata. struct ModelMetadata { - /// Constructs a `ModelMetada` instance. + /// Constructs a `ModelMetada` instance (for single-method models). /// @param identifier The unique identifier. /// @param input_names The input names for the model. /// @param output_names The output names for the model. @@ -29,7 +46,20 @@ struct ModelMetadata { inline ModelMetadata() noexcept { } /// Returns `true` if the metadata is valid otherwise `false`. - inline bool is_valid() const noexcept { return !identifier.empty() && !output_names.empty(); } + inline bool is_valid() const noexcept { return !identifier.empty() && (!output_names.empty() || !methods.empty()); } + + /// Returns `true` if this is multifunction metadata (has methods). + inline bool is_multifunction() const noexcept { return !methods.empty(); } + + /// Get metadata for a specific method. Returns nullptr if method not found. + /// For single-method models, returns nullptr (use input_names/output_names directly). + inline const MethodMetadata* get_method_metadata(const std::string& method_name) const noexcept { + auto it = methods.find(method_name); + if (it != methods.end()) { + return &it->second; + } + return nullptr; + } inline std::string to_json_string() const noexcept { return executorchcoreml::serde::json::to_json_string(*this); } @@ -39,9 +69,13 @@ struct ModelMetadata { /// Unique identifier. std::string identifier; - /// Input names of the model. + /// Input names of the model (for single-method models). std::vector input_names; - /// Output names of the model. + /// Output names of the model (for single-method models). std::vector output_names; + /// Per-method metadata (for multifunction models). + std::map methods; + /// Default method name (for multifunction models). 
+ std::string default_method; }; } // namespace executorchcoreml diff --git a/backends/apple/coreml/runtime/delegate/serde_json.mm b/backends/apple/coreml/runtime/delegate/serde_json.mm index e39df4d734e..9e8e2f44a27 100644 --- a/backends/apple/coreml/runtime/delegate/serde_json.mm +++ b/backends/apple/coreml/runtime/delegate/serde_json.mm @@ -33,6 +33,8 @@ constexpr static std::string_view kIdentifierKey = "identifier"; constexpr static std::string_view kInputNamesKey = "inputNames"; constexpr static std::string_view kOutputNamesKey = "outputNames"; + constexpr static std::string_view kMethodsKey = "methods"; + constexpr static std::string_view kDefaultMethodKey = "defaultMethod"; }; } @@ -104,13 +106,46 @@ static void from_json(id json, executorchcoreml::Asset& asset) { } }; +template <> +struct Converter { + static id to_json(const executorchcoreml::MethodMetadata& method_metadata) { + return @{ + to_string(ModelMetadataKeys::kInputNamesKey) : to_json_value(method_metadata.input_names), + to_string(ModelMetadataKeys::kOutputNamesKey) : to_json_value(method_metadata.output_names) + }; + } + + static void from_json(id json, executorchcoreml::MethodMetadata& method_metadata) { + NSDictionary *json_dict = SAFE_CAST(json, NSDictionary); + if (!json_dict) { + return; + } + + from_json_value(json_dict[to_string(ModelMetadataKeys::kInputNamesKey)], method_metadata.input_names); + from_json_value(json_dict[to_string(ModelMetadataKeys::kOutputNamesKey)], method_metadata.output_names); + } +}; + template <> struct Converter { static id to_json(const executorchcoreml::ModelMetadata& metadata) { + // For multifunction models with methods, serialize the new format + if (!metadata.methods.empty()) { + NSMutableDictionary *methods_dict = [NSMutableDictionary dictionary]; + for (const auto& [method_name, method_metadata] : metadata.methods) { + methods_dict[to_json_value(method_name)] = Converter::to_json(method_metadata); + } + return @{ + to_string(ModelMetadataKeys::kIdentifierKey) : to_json_value(metadata.identifier), + to_string(ModelMetadataKeys::kMethodsKey) : methods_dict, + to_string(ModelMetadataKeys::kDefaultMethodKey) : to_json_value(metadata.default_method) + }; + } + // For single-method models, serialize the old format for backwards compatibility return @{ to_string(ModelMetadataKeys::kIdentifierKey) : to_json_value(metadata.identifier), to_string(ModelMetadataKeys::kInputNamesKey) : to_json_value(metadata.input_names), - to_string(ModelMetadataKeys::kOutputNamesKey) :to_json_value(metadata.output_names) + to_string(ModelMetadataKeys::kOutputNamesKey) : to_json_value(metadata.output_names) }; } @@ -121,8 +156,24 @@ static void from_json(id json, executorchcoreml::ModelMetadata& metadata) { } from_json_value(json_dict[to_string(ModelMetadataKeys::kIdentifierKey)], metadata.identifier); - from_json_value(json_dict[to_string(ModelMetadataKeys::kInputNamesKey)], metadata.input_names); - from_json_value(json_dict[to_string(ModelMetadataKeys::kOutputNamesKey)], metadata.output_names); + + // Check if this is a multifunction model (has "methods" key) + NSDictionary *methods_dict = SAFE_CAST(json_dict[to_string(ModelMetadataKeys::kMethodsKey)], NSDictionary); + if (methods_dict) { + // New multifunction format + from_json_value(json_dict[to_string(ModelMetadataKeys::kDefaultMethodKey)], metadata.default_method); + for (NSString *method_name in methods_dict) { + executorchcoreml::MethodMetadata method_metadata; + Converter::from_json(methods_dict[method_name], method_metadata); + std::string 
method_name_str; + from_json_value(method_name, method_name_str); + metadata.methods[method_name_str] = std::move(method_metadata); + } + } else { + // Old single-method format + from_json_value(json_dict[to_string(ModelMetadataKeys::kInputNamesKey)], metadata.input_names); + from_json_value(json_dict[to_string(ModelMetadataKeys::kOutputNamesKey)], metadata.output_names); + } } }; diff --git a/backends/apple/coreml/test/test_coreml_multifunction.py b/backends/apple/coreml/test/test_coreml_multifunction.py new file mode 100644 index 00000000000..f00d9f079cd --- /dev/null +++ b/backends/apple/coreml/test/test_coreml_multifunction.py @@ -0,0 +1,314 @@ +# Copyright © 2024 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +import sys +import unittest + +import coremltools as ct +import torch + +from executorch.backends.apple.coreml.compiler.coreml_preprocess import ( + CoreMLBackend, + MULTIMETHOD_WEIGHT_SHARING_STRATEGY, +) +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + + +def is_fbcode(): + return not hasattr(torch.version, "git_version") + + +_TEST_RUNTIME = (sys.platform == "darwin") and not is_fbcode() +if _TEST_RUNTIME: + from executorch.runtime import Runtime + + +class TestCoreMLMultifunction(unittest.TestCase): + """Tests for multifunction (multi-method) CoreML model export.""" + + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + + def _get_compile_specs(self, weight_sharing: bool = True): + """Get compile specs for multifunction models.""" + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + compute_unit=ct.ComputeUnit.CPU_ONLY, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + if weight_sharing: + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + ) + ) + return compile_specs + + def test_multifunction_simple_model(self): + """Test exporting a simple model with multiple methods (forward and prefill).""" + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x) + + model = SimpleModel() + model.eval() + + # Create example inputs for two different sequence lengths + decode_inputs = (torch.randn(1, 1, 16),) # seqlen=1 + prefill_inputs = (torch.randn(1, 8, 16),) # seqlen=8 + + # Export both methods + decode_ep = torch.export.export(model, decode_inputs) + prefill_ep = torch.export.export(model, prefill_inputs) + + # Create partitioner with multifunction support + partitioner = CoreMLPartitioner( + compile_specs=self._get_compile_specs(weight_sharing=True), + ) + + # Lower to edge with multiple methods + edge_manager = to_edge_transform_and_lower( + {"forward": decode_ep, "prefill": prefill_ep}, + partitioner=[partitioner], + compile_config=self.edge_compile_config, + ) + + # Verify both methods exist + method_names = edge_manager.methods + self.assertIn("forward", method_names) + self.assertIn("prefill", method_names) + + # Convert to ExecuTorch + et_program = edge_manager.to_executorch() + + if _TEST_RUNTIME: + # Test runtime execution + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + + # Check both methods are available + available_methods = 
program.method_names + self.assertIn("forward", available_methods) + self.assertIn("prefill", available_methods) + + # Test forward (decode) method + forward_method = program.load_method("forward") + decode_output = forward_method.execute(decode_inputs) + expected_decode = model(*decode_inputs) + self.assertTrue( + torch.allclose(decode_output[0], expected_decode, atol=1e-4, rtol=1e-4) + ) + + # Test prefill method + prefill_method = program.load_method("prefill") + prefill_output = prefill_method.execute(prefill_inputs) + expected_prefill = model(*prefill_inputs) + self.assertTrue( + torch.allclose( + prefill_output[0], expected_prefill, atol=1e-4, rtol=1e-4 + ) + ) + + def test_multifunction_with_kv_cache(self): + """Test multifunction export with KV cache-like buffers.""" + + class ModelWithCache(torch.nn.Module): + def __init__(self, hidden_dim: int, cache_len: int): + super().__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.hidden_dim = hidden_dim + self.cache_len = cache_len + + def forward(self, x, cache): + # x: [batch, seqlen, hidden_dim] + # cache: [batch, cache_len, hidden_dim] + out = self.linear(x) + # Simple cache update simulation + new_cache = torch.cat([cache[:, 1:, :], out[:, -1:, :]], dim=1) + return out, new_cache + + hidden_dim = 16 + cache_len = 32 + model = ModelWithCache(hidden_dim, cache_len) + model.eval() + + # Decode: seqlen=1 + decode_inputs = ( + torch.randn(1, 1, hidden_dim), + torch.randn(1, cache_len, hidden_dim), + ) + + # Prefill: seqlen=8 + prefill_inputs = ( + torch.randn(1, 8, hidden_dim), + torch.randn(1, cache_len, hidden_dim), + ) + + decode_ep = torch.export.export(model, decode_inputs) + prefill_ep = torch.export.export(model, prefill_inputs) + + partitioner = CoreMLPartitioner( + compile_specs=self._get_compile_specs(weight_sharing=True), + ) + + edge_manager = to_edge_transform_and_lower( + {"forward": decode_ep, "prefill": prefill_ep}, + partitioner=[partitioner], + compile_config=self.edge_compile_config, + ) + + et_program = edge_manager.to_executorch() + + if _TEST_RUNTIME: + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + + # Test decode + forward_method = program.load_method("forward") + decode_output = forward_method.execute(decode_inputs) + expected_out, expected_cache = model(*decode_inputs) + self.assertTrue( + torch.allclose(decode_output[0], expected_out, atol=1e-4, rtol=1e-4) + ) + self.assertTrue( + torch.allclose(decode_output[1], expected_cache, atol=1e-4, rtol=1e-4) + ) + + # Test prefill + prefill_method = program.load_method("prefill") + prefill_output = prefill_method.execute(prefill_inputs) + expected_out, expected_cache = model(*prefill_inputs) + self.assertTrue( + torch.allclose(prefill_output[0], expected_out, atol=1e-4, rtol=1e-4) + ) + self.assertTrue( + torch.allclose(prefill_output[1], expected_cache, atol=1e-4, rtol=1e-4) + ) + + def test_multifunction_without_weight_sharing(self): + """Test multifunction export without weight sharing.""" + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x) + + model = SimpleModel() + model.eval() + + decode_inputs = (torch.randn(1, 1, 16),) + prefill_inputs = (torch.randn(1, 8, 16),) + + decode_ep = torch.export.export(model, decode_inputs) + prefill_ep = torch.export.export(model, prefill_inputs) + + # Create partitioner without weight sharing + compile_specs = CoreMLBackend.generate_compile_specs( + 
minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT32, + compute_unit=ct.ComputeUnit.CPU_ONLY, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.DISABLED + ) + ) + + partitioner = CoreMLPartitioner(compile_specs=compile_specs) + + edge_manager = to_edge_transform_and_lower( + {"forward": decode_ep, "prefill": prefill_ep}, + partitioner=[partitioner], + compile_config=self.edge_compile_config, + ) + + method_names = edge_manager.methods + self.assertIn("forward", method_names) + self.assertIn("prefill", method_names) + + et_program = edge_manager.to_executorch() + + if _TEST_RUNTIME: + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + + forward_method = program.load_method("forward") + decode_output = forward_method.execute(decode_inputs) + expected_decode = model(*decode_inputs) + self.assertTrue( + torch.allclose(decode_output[0], expected_decode, atol=1e-4, rtol=1e-4) + ) + + def test_multifunction_with_constant_methods(self): + """Test multifunction export with constant methods (metadata).""" + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x) + + model = SimpleModel() + model.eval() + + decode_inputs = (torch.randn(1, 1, 16),) + prefill_inputs = (torch.randn(1, 8, 16),) + + decode_ep = torch.export.export(model, decode_inputs) + prefill_ep = torch.export.export(model, prefill_inputs) + + partitioner = CoreMLPartitioner( + compile_specs=self._get_compile_specs(weight_sharing=True), + ) + + # Add constant methods (metadata) + constant_methods = { + "vocab_size": 32000, + "hidden_dim": 16, + "decode_seqlen": 1, + "prefill_seqlen": 8, + } + + edge_manager = to_edge_transform_and_lower( + {"forward": decode_ep, "prefill": prefill_ep}, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=self.edge_compile_config, + ) + + et_program = edge_manager.to_executorch() + + if _TEST_RUNTIME: + runtime = Runtime.get() + program = runtime.load_program(et_program.buffer) + + # Check all methods are available (executable + constant) + available_methods = program.method_names + self.assertIn("forward", available_methods) + self.assertIn("prefill", available_methods) + self.assertIn("vocab_size", available_methods) + self.assertIn("hidden_dim", available_methods) + self.assertIn("decode_seqlen", available_methods) + self.assertIn("prefill_seqlen", available_methods) + + +if __name__ == "__main__": + test_runner = TestCoreMLMultifunction() + test_runner.test_multifunction_simple_model() + test_runner.test_multifunction_with_kv_cache() + test_runner.test_multifunction_without_weight_sharing() + test_runner.test_multifunction_with_constant_methods() + print("All tests passed!") diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index a3fd8201414..e6a43325fa9 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -27,7 +27,10 @@ import torch.nn as nn import torch.utils._pytree as pytree -from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.compiler.coreml_preprocess import ( + CoreMLBackend, + MULTIMETHOD_WEIGHT_SHARING_STRATEGY, +) from 
executorch.backends.apple.coreml.partition import CoreMLPartitioner from executorch.examples.apple.coreml.llama.utils import ( replace_linear_with_split_linear, @@ -90,26 +93,41 @@ def forward(self, *args, **kwargs): def remove_graph_break_(edge_manager): + """Remove graph break ops from all methods in the edge manager.""" from executorch.exir.dialects._ops import ops as exir_ops - for n in edge_manager.exported_program().graph_module.graph.nodes: - if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: - n.replace_all_uses_with(n.args[0]) - edge_manager.exported_program().graph_module.graph.eliminate_dead_code() - - -def load_model(checkpoint_path: str, params_path: str, max_context_len: int): - """Load the model from checkpoint with static_mha attention type.""" + # Get all method names + method_names = edge_manager.methods + for method_name in method_names: + ep = edge_manager.exported_program(method_name) + for n in ep.graph_module.graph.nodes: + if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: + n.replace_all_uses_with(n.args[0]) + ep.graph_module.graph.eliminate_dead_code() + + +def load_model( + checkpoint_path: str, + params_path: str, + max_context_len: int, + generate_full_logits: bool = True, +): + """Load the model from checkpoint with static_mha attention type. + + Args: + checkpoint_path: Path to the model checkpoint (.pth) + params_path: Path to params.json + max_context_len: Maximum context length + generate_full_logits: If True, output logits for all tokens (needed for + lookahead decoding). If False, only output logits for the last token + (more efficient for standard autoregressive generation). + """ with open(params_path, "r") as f: params = json.loads(f.read()) - # TODO: to support lookahead decoding, the static model outputs - # full logits, but if we are not using lookahead decoding, we can have a - # more efficient model by setting generate_full_logits=False and supplying the last - # valid token args = ModelArgs( max_context_len=max_context_len, - generate_full_logits=True, + generate_full_logits=generate_full_logits, **params, ) args.attention_type = "static_mha" @@ -152,6 +170,55 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int): return model, args +def _create_example_inputs(model_args, input_len, max_context_len, float_dtype): + """ + Create example inputs for a given input length. + + Args: + model_args: Model configuration arguments + input_len: Sequence length for this forward pass + max_context_len: Maximum context length + float_dtype: Float dtype (torch.float16 or torch.float32) + + Returns: + Tuple of (example_inputs, cache_len) where example_inputs is the tuple + expected by the model's forward method. + """ + cache_len = max_context_len - input_len + + mgr = StaticAttentionIOManager( + model_args, + input_len=input_len, + cache_lens=cache_len, + batch_size=1, + dtype=float_dtype, + style="smart_mask", + mask_val=float("-inf"), + ) + + options = { + "masks": mgr.masks, + "freqs_cos_override": mgr.freqs_cos[:input_len], + "freqs_sin_override": mgr.freqs_sin[:input_len], + "in_cache_state": (mgr.k_caches, mgr.v_caches), + } + + # When generate_full_logits=False, we need to pass last_valid_token_pos + # to tell the model which position's logits to output. + # This is the index of the last real token (before any padding). 
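+    # For example, with input_len=32 the last real token is at index 31, so last_valid_token_pos is tensor([31]).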
+ if not model_args.generate_full_logits: + options["last_valid_token_pos"] = torch.tensor( + [input_len - 1], dtype=torch.long + ) + + example_inputs = ( + torch.zeros(1, input_len, dtype=torch.int32), + options, + ) + + return example_inputs, cache_len + + def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype): """ Generate metadata methods for the C++ runner. @@ -320,11 +387,29 @@ def main(): help="Disable graph breaks between transformer blocks", ) + # Export mode options + parser.add_argument( + "--multifunction", + action="store_true", + help="Export as multifunction model with separate prefill (seqlen=input_len) " + "and decode (seqlen=1) methods. Weight sharing is enabled across methods. " + "When disabled, exports a single-method model with fixed seqlen=input_len " + "and generate_full_logits=True for lookahead decoding support.", + ) + args = parser.parse_args() # Compute cache length - print("Quantization and datatype:") + print("Export mode:") + if args.multifunction: + print( + "\tMultifunction: separate prefill/decode graphs, generate_full_logits=False" + ) + else: + print("\tSingle method: fixed seqlen, generate_full_logits=True (lookahead)") + + print("\nQuantization and datatype:") print(f"\tEmbedding quantize: {args.embedding_quantize}") print(f"\tLinear quantize: {args.linear_quantize}") print(f"\tDtype: {args.dtype}") @@ -340,11 +425,15 @@ def main(): print(f"\tMax splits: {args.max_splits}") # Load model + # For multifunction: generate_full_logits=False (efficient, only last token) + # For single method: generate_full_logits=True (needed for lookahead decoding) + generate_full_logits = not args.multifunction print(f"\nLoading model from {args.checkpoint}...") model, model_args = load_model( args.checkpoint, args.params, args.max_context_len, + generate_full_logits=generate_full_logits, ) print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") @@ -413,73 +502,177 @@ def main(): model.layers[n_layers - 1], break_before=False ) - # Create IO manager and example inputs - mgr = StaticAttentionIOManager( - model_args, - input_len=args.input_len, - cache_lens=cache_len, - batch_size=1, - dtype=float_dtype, - style="smart_mask", # Use smart_mask to match C++ StaticTransformerRunner - mask_val=float("-inf"), - ) - example_inputs = ( - torch.zeros(1, args.input_len, dtype=torch.int32), - { - "masks": mgr.masks, - "freqs_cos_override": mgr.freqs_cos[: args.input_len], - "freqs_sin_override": mgr.freqs_sin[: args.input_len], - "in_cache_state": (mgr.k_caches, mgr.v_caches), - }, - ) + if args.multifunction: + # Multifunction mode: separate prefill and decode graphs with weight sharing + decode_input_len = 1 + prefill_input_len = args.input_len # default 32 - # Test eager execution - print("\nTesting eager execution...") - with torch.no_grad(): - model(*example_inputs) - print("Eager execution successful!") - - # Export the model - print("\nExporting model...") - ep = torch.export.export(model, example_inputs) - print("Export successful!") - print(ep) - - # Generate metadata for C++ runner - print("\nGenerating metadata for C++ runner...") - constant_methods = _get_metadata( - model_args, example_inputs, args.input_len, cache_len, float_dtype - ) + print(f"\nCreating example inputs for decode (seqlen={decode_input_len})...") + decode_inputs, decode_cache_len = _create_example_inputs( + model_args, decode_input_len, args.max_context_len, float_dtype + ) - # Setup CoreML partitioner - print("\nSetting up CoreML partitioner...") - 
compile_specs = CoreMLBackend.generate_compile_specs( - minimum_deployment_target=ct.target.iOS18, - compute_precision={ - torch.float16: ct.precision.FLOAT16, - torch.float32: ct.precision.FLOAT32, - }[float_dtype], - compute_unit=ct.ComputeUnit.CPU_AND_NE, - model_type=CoreMLBackend.MODEL_TYPE.MODEL, - ) - partitioner = CoreMLPartitioner( - compile_specs=compile_specs, - take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[], - ) + print(f"Creating example inputs for prefill (seqlen={prefill_input_len})...") + prefill_inputs, prefill_cache_len = _create_example_inputs( + model_args, prefill_input_len, args.max_context_len, float_dtype + ) - # Lower to edge with constant methods for C++ runner - print("\nLowering to edge...") - edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) - edge_manager = to_edge_transform_and_lower( - ep, - partitioner=[partitioner], - constant_methods=constant_methods, - compile_config=edge_compile_config, - ) + # Test eager execution for both + print("\nTesting eager execution (decode, seqlen=1)...") + with torch.no_grad(): + model(*decode_inputs) + print("Decode eager execution successful!") + + print(f"\nTesting eager execution (prefill, seqlen={prefill_input_len})...") + with torch.no_grad(): + model(*prefill_inputs) + print("Prefill eager execution successful!") + + # Export both graphs + print("\nExporting decode model (seqlen=1)...") + decode_ep = torch.export.export(model, decode_inputs) + print("Decode export successful!") + print(decode_ep) + + print(f"\nExporting prefill model (seqlen={prefill_input_len})...") + prefill_ep = torch.export.export(model, prefill_inputs) + print("Prefill export successful!") + print(prefill_ep) + + # Generate metadata for C++ runner + # constant_methods are shared across all methods, so we prefix method-specific + # metadata with the method name + print("\nGenerating metadata for C++ runner...") + decode_metadata = _get_metadata( + model_args, decode_inputs, decode_input_len, decode_cache_len, float_dtype + ) + prefill_metadata = _get_metadata( + model_args, + prefill_inputs, + prefill_input_len, + prefill_cache_len, + float_dtype, + ) + + # Combine metadata - shared values go without prefix, method-specific values get prefixed + constant_methods = { + # Shared metadata (same for both methods) + "vocab_size": decode_metadata["vocab_size"], + "head_dim": decode_metadata["head_dim"], + "n_heads_per_cache": decode_metadata["n_heads_per_cache"], + "freqs_cos": decode_metadata["freqs_cos"], + "freqs_sin": decode_metadata["freqs_sin"], + # Decode-specific metadata (forward method) + "decode_input_len": decode_metadata["forward_input_len"], + "decode_freqs_cos_input_index": decode_metadata["freqs_cos_input_index"], + "decode_freqs_sin_input_index": decode_metadata["freqs_sin_input_index"], + "decode_mask_specs": decode_metadata["mask_specs"], + "decode_kv_cache_specs": decode_metadata["kv_cache_specs"], + # Prefill-specific metadata + "prefill_input_len": prefill_metadata["forward_input_len"], + "prefill_freqs_cos_input_index": prefill_metadata["freqs_cos_input_index"], + "prefill_freqs_sin_input_index": prefill_metadata["freqs_sin_input_index"], + "prefill_mask_specs": prefill_metadata["mask_specs"], + "prefill_kv_cache_specs": prefill_metadata["kv_cache_specs"], + } + + # Setup CoreML partitioner with multimethod weight sharing + print("\nSetting up CoreML partitioner (multifunction with weight sharing)...") + compile_specs = CoreMLBackend.generate_compile_specs( + 
minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + compile_specs.append( + CoreMLBackend.generate_multimethod_weight_sharing_strategy_compile_spec( + MULTIMETHOD_WEIGHT_SHARING_STRATEGY.POSITIONAL + ) + ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) + + # Lower to edge with both decode and prefill methods + print("\nLowering to edge (multi-method: decode + prefill)...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + + # Create multi-method edge manager with decode as "forward" and prefill as "prefill" + edge_manager = to_edge_transform_and_lower( + {"forward": decode_ep, "prefill": prefill_ep}, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) + + print("\nDelegated program (decode/forward):") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) + + print("\nDelegated program (prefill):") + print( + format_delegated_graph( + edge_manager.exported_program("prefill").graph_module + ) + ) + else: + # Single method mode: fixed seqlen with generate_full_logits=True for lookahead + print(f"\nCreating example inputs (seqlen={args.input_len})...") + example_inputs, example_cache_len = _create_example_inputs( + model_args, args.input_len, args.max_context_len, float_dtype + ) + + # Test eager execution + print("\nTesting eager execution...") + with torch.no_grad(): + model(*example_inputs) + print("Eager execution successful!") + + # Export the model + print("\nExporting model...") + ep = torch.export.export(model, example_inputs) + print("Export successful!") + print(ep) + + # Generate metadata for C++ runner + print("\nGenerating metadata for C++ runner...") + constant_methods = _get_metadata( + model_args, example_inputs, args.input_len, example_cache_len, float_dtype + ) + + # Setup CoreML partitioner + print("\nSetting up CoreML partitioner...") + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision={ + torch.float16: ct.precision.FLOAT16, + torch.float32: ct.precision.FLOAT32, + }[float_dtype], + compute_unit=ct.ComputeUnit.CPU_AND_NE, + model_type=CoreMLBackend.MODEL_TYPE.MODEL, + ) + partitioner = CoreMLPartitioner( + compile_specs=compile_specs, + take_over_mutable_buffer=False, + skip_ops_for_coreml_delegation=[], + ) + + # Lower to edge with constant methods for C++ runner + print("\nLowering to edge...") + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_manager = to_edge_transform_and_lower( + ep, + partitioner=[partitioner], + constant_methods=constant_methods, + compile_config=edge_compile_config, + ) - print("\nDelegated program:") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + print("\nDelegated program:") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) # Convert to ExecuTorch print("\nConverting to ExecuTorch...") diff --git a/examples/apple/coreml/llama/readme.md b/examples/apple/coreml/llama/readme.md index 46e9043a5fc..ae3852a7828 100644 --- a/examples/apple/coreml/llama/readme.md +++ b/examples/apple/coreml/llama/readme.md @@ -1,6 +1,10 @@ # ANE-friendly Llama models -To export a static, ANE-friendly model use: +The export script 
supports two modes: **single method** (default) and **multifunction**. + +## Single Method Export (Default) + +Exports a model with a fixed sequence length and `generate_full_logits=True`, which is required for lookahead decoding: + ``` python export_static_llm_coreml.py \ @@ -23,11 +27,73 @@ python run_static_llm.py \ (Enabling lookahead decoding is optional, but does improve performance.) +## Multifunction Export + +Exports a model with separate prefill (seqlen=input_len) and decode (seqlen=1) methods. This mode enables weight sharing across methods and uses `generate_full_logits=False` for more efficient autoregressive generation: + +``` +python export_static_llm_coreml.py \ + --checkpoint /path/to/model.pth \ + --params /path/to/params.json \ + --output static_llm_coreml_multifunction.pte \ + --multifunction +``` + +To test the multifunction model in Python: + +``` +python run_static_llm_multifunction.py \ + --model static_llm_coreml_multifunction.pte \ + --params /path/to/params.json \ + --tokenizer /path/to/tokenizer.model \ + --prompt "Once upon a time" \ + --max_new_tokens 100 +``` + +Key differences between the two modes: +* **Single method**: Uses fixed seqlen for both prefill and decode, outputs full logits (supports lookahead decoding) +* **Multifunction**: Separate optimized graphs for prefill (seqlen=input_len) and decode (seqlen=1), outputs only last token logits (more efficient for standard generation), enables CoreML multifunction weight sharing + +## Export Options + +### Model Paths +| Option | Default | Description | +|--------|---------|-------------| +| `-c`, `--checkpoint` | (required) | Path to model checkpoint (.pth) | +| `-p`, `--params` | (required) | Path to params.json | +| `-o`, `--output` | `model.pte` | Output filename for the .pte model | + +### Model Configuration +| Option | Default | Description | +|--------|---------|-------------| +| `--max_context_len` | 1024 | Maximum context length | +| `--input_len` | 32 | Input sequence length per forward pass. In multifunction mode, this is the prefill sequence length. | +| `--dtype` | `fp16` | Model dtype (`fp16` or `fp32`). The ANE requires fp16. | + +### Quantization Options +| Option | Default | Description | +|--------|---------|-------------| +| `-E`, `--embedding_quantize` | `8,0` | Embedding quantization in `bitwidth,group_size` format, e.g., `4,32` for 4-bit with group size 32, or `8,0` for 8-bit per-channel | +| `--linear_quantize` | `c4w` | CoreML linear quantization: `b4w` (blockwise 4-bit) or `c4w` (channelwise 4-bit). The ANE requires channelwise. | + +### Linear Splitting Options +| Option | Default | Description | +|--------|---------|-------------| +| `--target_split_size` | 1024 | Split linear layers into chunks of this size (helps with ANE performance) | +| `--max_splits` | 8 | Maximum number of splits for linear layers | + +### Export Mode Options +| Option | Default | Description | +|--------|---------|-------------| +| `--multifunction` | disabled | Export as multifunction model with separate prefill (seqlen=input_len) and decode (seqlen=1) methods. Enables weight sharing across methods. | +| `--no_graph_breaks` | disabled | Disable graph breaks between transformer blocks. Graph breaks help improve ANE performance by keeping model pieces smaller. 
| + +## ANE Optimizations + The static model has several ANE optimizations, including: * Splitting linear layers for improved performance (controlled by target_split_size and max_splits args) * Splitting the pte into multiple Core ML pieces for improved performance (can be disabled with no_graph_breaks) -* Re-writing SDPA to avoid 5-D tensors to imporve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) - +* Re-writing SDPA to avoid 5-D tensors to improve performance. This also fixes an accuracy bug that was introduced in iOS 26 (addresses this: https://github.com/pytorch/executorch/issues/15833) We are working on adding a C++ runner as well. diff --git a/examples/apple/coreml/llama/run_static_llm_multifunction.py b/examples/apple/coreml/llama/run_static_llm_multifunction.py new file mode 100644 index 00000000000..feecd387a91 --- /dev/null +++ b/examples/apple/coreml/llama/run_static_llm_multifunction.py @@ -0,0 +1,562 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run script for multifunction static attention Llama models exported with export_static_llm_coreml.py. + +This script tests multifunction CoreML models that have separate "forward" (decode) and +"prefill" methods sharing weights. + +Usage: + python run_static_llm_multifunction.py \ + --model $HOME/Desktop/multifunction_test.pte \ + --params $HOME/models/llama1b/params.json \ + --tokenizer $HOME/models/llama1b/tokenizer.model \ + --prompt "Once upon a time" \ + --max_new_tokens 100 +""" + +import argparse +import json +import time +from typing import Any, Dict, List, Tuple + +import sentencepiece as spm +import torch +import torch.utils._pytree as pytree + +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.runner.generation import next_token +from executorch.examples.models.llama.static_attention import StaticAttentionIOManager +from executorch.runtime import Runtime + + +class Tokenizer: + """Wrapper to support both SentencePiece and Tiktoken tokenizers.""" + + def __init__(self, model_path: str): + try: + print("Trying to load sentencepiece") + sp = spm.SentencePieceProcessor() + sp.load(model_path) + self.tokenizer = sp + self._is_sentencepiece = True + except Exception: + print("Trying to load tiktoken") + from executorch.examples.models.llama.tokenizer import tiktoken + + self.tokenizer = tiktoken.Tokenizer(model_path) + self._is_sentencepiece = False + + def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]: + if self._is_sentencepiece: + bos_string = "<s>" if bos else "" + eos_string = "</s>" if eos else "" + return self.tokenizer.encode(f"{bos_string}{text}{eos_string}") + return self.tokenizer.encode(text, bos=bos, eos=eos) + + def decode(self, tokens: List[int]) -> str: + if self._is_sentencepiece: + return self.tokenizer.decode(tokens) + return self.tokenizer.decode(tokens) + + def decode_token(self, token: int) -> str: + if self._is_sentencepiece: + return self.tokenizer.decode([token]) + try: + return self.tokenizer.decode_token(token) + except UnicodeDecodeError: + return f"<{token}>" + + @property + def stop_tokens(self) -> List[int]: + if self._is_sentencepiece: + return [self.tokenizer.eos_id()] + return self.tokenizer.stop_tokens + + +def create_pte_wrapper( + decode_method,
prefill_method, + prefill_mgr: "StaticAttentionIOManager", + decode_mgr: "StaticAttentionIOManager", + prefill_seq_len: int, + prefill_cache_len: int, + decode_cache_len: int, +): + """ + Create a wrapper function that adapts PTE execution to the interface + expected by StaticAttentionIOManager. + + This multifunction version selects between prefill and decode methods + based on the input sequence length. It also uses the appropriate + StaticAttentionIOManager for each method since they have different + cache lengths. + + The wrapper: + - Takes (tokens, options_dict) like the eager model + - Selects prefill or decode method based on token count + - Adapts the options to use the correct manager's cache structure + - Flattens inputs using pytree + - Executes the appropriate PTE method + - Reconstructs outputs to match eager model format: (logits, {"out_cache_state": (k_dict, v_dict)}) + """ + + # Get cache keys from the prefill manager (same structure as decode) + k_cache_keys = list(prefill_mgr.k_caches.keys()) + v_cache_keys = list(prefill_mgr.v_caches.keys()) + + # Timing accumulators + timing_stats = { + "flatten_time": 0.0, + "execute_time": 0.0, + "reconstruct_time": 0.0, + "detection_time": 0.0, + "options_build_time": 0.0, + "call_count": 0, + } + + def wrapper( + tokens: torch.Tensor, options: Dict[str, Any] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + import time as time_module + + timing_stats["call_count"] += 1 + + # TIME: Detection logic + t0 = time_module.perf_counter() + + # Detect actual sequence length BEFORE padding. + # StaticAttentionIOManager._run_once pads tokens with zeros on the right: + # tokens = F.pad(tokens, (0, self.input_len - n_tokens)) + # So for decode (1 actual token), positions 1+ are all zeros. + # For prefill (32 actual tokens), positions have real token values. 
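+        # This check assumes token id 0 appears after position 0 only as padding, never as real prompt content.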
+ padded_seq_len = tokens.shape[1] + if padded_seq_len > 1 and (tokens[0, 1:] == 0).all(): + # Single token padded to prefill_seq_len - this is decode + actual_seq_len = 1 + else: + actual_seq_len = padded_seq_len + + # Select method and manager based on actual (pre-padding) sequence length + if actual_seq_len == prefill_seq_len: + method = prefill_method + mgr = prefill_mgr + # Use tokens and freqs as-is for prefill + adapted_tokens = tokens + adapted_freqs_cos = options["freqs_cos_override"] + adapted_freqs_sin = options["freqs_sin_override"] + # Use cache state as-is (prefill manager's cache size matches prefill method) + adapted_cache_state = options["in_cache_state"] + else: + method = decode_method + mgr = decode_mgr + # For decode, use tokens and freqs as-is (decode_mgr passes correct shapes) + # Note: decode_mgr.input_len=1, so tokens are NOT padded, just (1, 1) + adapted_tokens = tokens + adapted_freqs_cos = options["freqs_cos_override"] + adapted_freqs_sin = options["freqs_sin_override"] + # Use cache state as-is (decode_mgr's cache size matches decode method) + adapted_cache_state = options["in_cache_state"] + + t1 = time_module.perf_counter() + timing_stats["detection_time"] += t1 - t0 + + # TIME: Build options + t0 = time_module.perf_counter() + + # Build options with the correct mask and freqs for this method + adapted_options = { + "masks": mgr.masks, # Use correct manager's mask (has right shape) + "freqs_cos_override": adapted_freqs_cos, + "freqs_sin_override": adapted_freqs_sin, + "in_cache_state": adapted_cache_state, + } + + # Pass through last_valid_token_pos if present (needed for generate_full_logits=False) + if "last_valid_token_pos" in options: + adapted_options["last_valid_token_pos"] = options["last_valid_token_pos"] + + # Build the same input structure as during export + inputs = (adapted_tokens, adapted_options) + + t1 = time_module.perf_counter() + timing_stats["options_build_time"] += t1 - t0 + + # TIME: Flatten using pytree (same order as torch.export) + t0 = time_module.perf_counter() + flat_inputs, _ = pytree.tree_flatten(inputs) + t1 = time_module.perf_counter() + timing_stats["flatten_time"] += t1 - t0 + + # TIME: Execute PTE model + t0 = time_module.perf_counter() + outputs = method.execute(flat_inputs) + t1 = time_module.perf_counter() + timing_stats["execute_time"] += t1 - t0 + + # TIME: Reconstruct outputs + t0 = time_module.perf_counter() + + # First output is logits + logits = outputs[0] + + # Remaining outputs are k_cache updates then v_cache updates + num_layers = len(k_cache_keys) + k_updates = outputs[1 : 1 + num_layers] + v_updates = outputs[1 + num_layers : 1 + 2 * num_layers] + + # Reconstruct the output cache state dicts + k_cache_dict = dict(zip(k_cache_keys, k_updates)) + v_cache_dict = dict(zip(v_cache_keys, v_updates)) + + attn_updates = {"out_cache_state": (k_cache_dict, v_cache_dict)} + + t1 = time_module.perf_counter() + timing_stats["reconstruct_time"] += t1 - t0 + + return logits, attn_updates + + def print_timing_stats(): + n = timing_stats["call_count"] + if n > 0: + print(f"\n=== Wrapper Timing Stats ({n} calls) ===") + print( + f" Detection time: {timing_stats['detection_time']*1000:.2f}ms total, {timing_stats['detection_time']/n*1000:.4f}ms avg" + ) + print( + f" Options build: {timing_stats['options_build_time']*1000:.2f}ms total, {timing_stats['options_build_time']/n*1000:.4f}ms avg" + ) + print( + f" Flatten time: {timing_stats['flatten_time']*1000:.2f}ms total, {timing_stats['flatten_time']/n*1000:.4f}ms avg" + ) + print( 
+ f" Execute time: {timing_stats['execute_time']*1000:.2f}ms total, {timing_stats['execute_time']/n*1000:.3f}ms avg" + ) + print( + f" Reconstruct time: {timing_stats['reconstruct_time']*1000:.2f}ms total, {timing_stats['reconstruct_time']/n*1000:.4f}ms avg" + ) + total = ( + timing_stats["detection_time"] + + timing_stats["options_build_time"] + + timing_stats["flatten_time"] + + timing_stats["execute_time"] + + timing_stats["reconstruct_time"] + ) + print( + f" Total wrapper: {total*1000:.2f}ms total, {total/n*1000:.3f}ms avg" + ) + print( + f" Execute is {timing_stats['execute_time']/total*100:.1f}% of wrapper time" + ) + expected_tps = 1000 / (timing_stats["execute_time"] / n * 1000) + print(f" Expected tok/s from execute alone: {expected_tps:.1f}") + + wrapper.print_timing_stats = print_timing_stats + wrapper.timing_stats = timing_stats + + return wrapper + + +def main(): + parser = argparse.ArgumentParser( + description="Run multifunction static attention Llama model" + ) + + parser.add_argument( + "-m", + "--model", + required=True, + help="Path to exported .pte model", + ) + parser.add_argument( + "-p", + "--params", + required=True, + help="Path to params.json", + ) + parser.add_argument( + "-t", + "--tokenizer", + required=True, + help="Path to tokenizer model", + ) + parser.add_argument( + "--prompt", + type=str, + default="Once upon a time,", + help="Input prompt", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.6, + help="Sampling temperature", + ) + parser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top-p (nucleus) sampling threshold", + ) + parser.add_argument( + "--input_len", + type=int, + default=32, + help="Input sequence length for prefill (must match export)", + ) + parser.add_argument( + "--max_context_len", + type=int, + default=1024, + help="Maximum context length (must match export)", + ) + parser.add_argument( + "--lookahead", + action="store_true", + help="Enable lookahead (speculative) decoding", + ) + parser.add_argument( + "--ngram_size", + type=int, + default=5, + help="N-gram size for lookahead decoding", + ) + parser.add_argument( + "--window_size", + type=int, + default=4, + help="Window size for lookahead decoding", + ) + parser.add_argument( + "--n_verifications", + type=int, + default=4, + help="Number of verification branches for lookahead decoding", + ) + + args = parser.parse_args() + + # Load tokenizer + tokenizer = Tokenizer(args.tokenizer) + + # Load model params + with open(args.params, "r") as f: + params = json.loads(f.read()) + + # Create model args + # Multifunction models use generate_full_logits=False (only last token logits) + model_args = ModelArgs( + max_context_len=args.max_context_len, + generate_full_logits=False, + **params, + ) + model_args.attention_type = "static_mha" + + print(f"Model config: {model_args.n_layers} layers, dim={model_args.dim}") + print(f"Max context length: {args.max_context_len}, Input length: {args.input_len}") + + # Calculate cache lengths for each method + # The export script uses: cache_len = max_context_len - input_len + # So for multifunction models: + # - prefill: input_len=64, cache_len=960 (total=1024) + # - decode: input_len=1, cache_len=1023 (total=1024) + prefill_input_len = args.input_len # e.g., 64 + prefill_cache_len = args.max_context_len - args.input_len # e.g., 960 + decode_input_len = 1 + decode_cache_len = args.max_context_len - 
decode_input_len # e.g., 1023 + + print(f"Prefill: input_len={prefill_input_len}, cache_len={prefill_cache_len}") + print(f"Decode: input_len={decode_input_len}, cache_len={decode_cache_len}") + + # Create StaticAttentionIOManager for prefill + # This manager handles the main prefill/decode loop state + prefill_mgr = StaticAttentionIOManager( + model_args, + input_len=prefill_input_len, + cache_lens=prefill_cache_len, + batch_size=1, + dtype=torch.float16, + style="smart_mask", + mask_val=float("-inf"), + ) + + # Create a separate decode manager with correct cache length for mask shapes + # This is needed because decode method expects mask shape (1, 1, 1024) + # which requires cache_len=1023 when input_len=1 + decode_mgr = StaticAttentionIOManager( + model_args, + input_len=decode_input_len, + cache_lens=decode_cache_len, + batch_size=1, + dtype=torch.float16, + style="smart_mask", + mask_val=float("-inf"), + ) + + # Load PTE model with multifunction support + print(f"Loading multifunction model from {args.model}...") + runtime = Runtime.get() + program = runtime.load_program(args.model) + + # List available methods + method_names = program.method_names + # Separate executable methods from constant methods (metadata) + executable_methods = {"forward", "prefill"} + actual_methods = executable_methods & method_names + constant_methods = method_names - executable_methods + print(f"Executable methods: {actual_methods}") + print(f"Metadata methods: {constant_methods}") + + # Check for expected multifunction methods + if "forward" not in method_names or "prefill" not in method_names: + print( + f"Warning: Expected 'forward' and 'prefill' methods, found: {method_names}" + ) + print("Falling back to single 'forward' method...") + decode_method = program.load_method("forward") + prefill_method = decode_method + else: + # Load both methods + print("Loading 'forward' (decode) method...") + decode_method = program.load_method("forward") + print("Loading 'prefill' method...") + prefill_method = program.load_method("prefill") + + decode_metadata = decode_method.metadata + print( + f"Decode method metadata: num_inputs={decode_metadata.num_inputs()}, num_outputs={decode_metadata.num_outputs()}" + ) + + prefill_metadata = prefill_method.metadata + print( + f"Prefill method metadata: num_inputs={prefill_metadata.num_inputs()}, num_outputs={prefill_metadata.num_outputs()}" + ) + + # Get cache keys in insertion order (NOT sorted alphabetically!) 
+ # Pytree preserves dict insertion order in Python 3.7+ + # The caches are created in layer order (0, 1, 2, ..., n_layers-1) + # Note: cache keys are obtained inside create_pte_wrapper + + # Create wrapper function that adapts PTE to eager interface + # This wrapper will select between prefill and decode based on seq_len + model_fn = create_pte_wrapper( + decode_method, + prefill_method, + prefill_mgr, + decode_mgr, + prefill_input_len, + prefill_cache_len, + decode_cache_len, + ) + + # Encode prompt + prompt_tokens = tokenizer.encode(args.prompt, bos=True, eos=False) + print(f"\nPrompt: {args.prompt}") + print(f"Prompt tokens: {len(prompt_tokens)}") + print("-" * 50) + + # Reset manager (use prefill_mgr as the main state manager) + prefill_mgr.reset() + decode_mgr.reset() + + # Prefill using StaticAttentionIOManager.prefill + # This will call model_fn with seq_len=input_len, which selects the prefill method + print("Prefilling (using 'prefill' method)...", end=" ", flush=True) + start_time = time.time() + logits = prefill_mgr.prefill(model_fn, prompt_tokens) + prefill_time = time.time() - start_time + print(f"done in {prefill_time:.2f}s") + + # Get first token from prefill logits + # With generate_full_logits=False, logits is 2D [batch, vocab] + # With generate_full_logits=True, logits is 3D [batch, seq_len, vocab] + if logits.dim() == 2: + first_token = next_token(logits, args.temperature, args.top_p) + else: + first_token = next_token(logits[:, -1, :], args.temperature, args.top_p) + + # After prefill, copy the cache state from prefill_mgr to decode_mgr + # This is necessary because decode_mgr has larger caches (1023 vs 960) + # and we'll be using decode_mgr.decode() for generation + for key in prefill_mgr.k_caches: + src_k = prefill_mgr.k_caches[key] + src_v = prefill_mgr.v_caches[key] + # Copy to decode_mgr's larger cache + decode_mgr.k_caches[key][:, :, :prefill_cache_len, :] = src_k + decode_mgr.v_caches[key][:, :, :prefill_cache_len, :] = src_v + + # Sync the position counter + decode_mgr.pos = prefill_mgr.pos + + # Update decode_mgr's masks to reflect current position + # The mask needs to unmask the positions that have been filled by prefill + for mask in decode_mgr._masks.values(): + mask.reset() + mask.unmask(prefill_mgr.pos) + + # Decode using decode_mgr.decode() which will call model_fn + # The wrapper will detect seq_len=1 (after we unpad) and route to decode method + # Since we're using decode_mgr, the cache shapes will match + print(f"\n{args.prompt}", end="", flush=True) + print(tokenizer.decode_token(first_token), end="", flush=True) + + decode_start = time.time() + + if args.lookahead: + # Use lookahead (speculative) decoding + print( + f"\n[Using lookahead decoding: ngram={args.ngram_size}, window={args.window_size}, verifications={args.n_verifications}]" + ) + generated_tokens = decode_mgr.lookahead_decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + ngram_size=args.ngram_size, + window_size=args.window_size, + n_verifications=args.n_verifications, + stop_tokens=tokenizer.stop_tokens, + ) + else: + # Use standard autoregressive decoding (uses 'forward' method) + print("\n[Using 'forward' (decode) method for generation]") + generated_tokens = decode_mgr.decode( + model_fn, + first_token, + n=args.max_new_tokens - 1, # -1 because first_token counts + stop_tokens=tokenizer.stop_tokens, + ) + + # Print generated tokens (skip first as it's the init_token we already printed) + for token in generated_tokens[1:]: + if token in 
tokenizer.stop_tokens: + break + print(tokenizer.decode_token(token), end="", flush=True) + + decode_time = time.time() - decode_start + total_generated = len(generated_tokens) + tokens_per_sec = total_generated / decode_time if decode_time > 0 else 0 + + print("\n" + "-" * 50) + print(f"Prefill: {len(prompt_tokens)} tokens in {prefill_time:.2f}s") + print( + f"Decode: {total_generated} tokens in {decode_time:.2f}s ({tokens_per_sec:.2f} tok/s)" + ) + + # Print detailed timing breakdown + model_fn.print_timing_stats() + + print("\nMultifunction model test completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py index 593a270186b..b7556a14ead 100644 --- a/examples/apple/coreml/scripts/extract_coreml_models.py +++ b/examples/apple/coreml/scripts/extract_coreml_models.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import argparse +import json import os import shutil from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional from executorch.backends.apple.coreml import executorchcoreml from executorch.backends.apple.coreml.compiler import CoreMLBackend @@ -21,27 +22,75 @@ def extract_coreml_models(pte_data: bytes): - program = deserialize_pte_binary(pte_data).program + pte_file = deserialize_pte_binary(pte_data) + program = pte_file.program + + # Build a map from named_data keys to their data for multifunction model support. + # Multifunction models store a JSON reference in processed_bytes that points to + # the actual model data in named_data. + # After deserialization, pte_file.named_data is a NamedDataStoreOutput containing + # buffers and pte_data (key -> DataEntry mapping). 
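+    # Each such reference is a small JSON object with "version" and "key" fields; the key names the named_data entry holding the actual Core ML model bytes.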
+ named_data_map: Dict[str, bytes] = {} + if pte_file.named_data is not None: + for key, data_entry in pte_file.named_data.pte_data.items(): + named_data_map[key] = pte_file.named_data.buffers[data_entry.buffer_index] + delegates: List[BackendDelegate] = sum( [execution_plan.delegates for execution_plan in program.execution_plan], [] ) coreml_delegates: List[BackendDelegate] = [ delegate for delegate in delegates if delegate.id == CoreMLBackend.__name__ ] + + # Track extracted models to avoid duplicates (multifunction models share partitions) + extracted_keys: set = set() model_index: int = 1 + for coreml_delegate in coreml_delegates: coreml_delegate_data: BackendDelegateDataReference = coreml_delegate.processed coreml_processed_bytes: Optional[bytes] = None + model_name: Optional[str] = None + match coreml_delegate_data.location: case DataLocation.INLINE: - coreml_processed_bytes = program.backend_delegate_data[ + raw_bytes = program.backend_delegate_data[ coreml_delegate_data.index ].data + # Check if this is a JSON reference to named_data (multifunction models) + try: + reference = json.loads(raw_bytes.decode("utf-8")) + if ( + isinstance(reference, dict) + and "version" in reference + and "key" in reference + ): + key = reference.get("key") + if key in extracted_keys: + # Already extracted this partition, skip + continue + if key in named_data_map: + coreml_processed_bytes = named_data_map[key] + model_name = key # Use the key as model name + extracted_keys.add(key) + else: + print( + f"Warning: Named data key '{key}' not found in program" + ) + continue + except (json.JSONDecodeError, UnicodeDecodeError): + # Not JSON, treat as raw model data (legacy format) + coreml_processed_bytes = raw_bytes + case _: AssertionError("The loaded Program must have inline data.") - model_name: str = f"model_{model_index}" + if coreml_processed_bytes is None: + continue + + if model_name is None: + model_name = f"model_{model_index}" + model_path: Path = Path() / "extracted_coreml_models" / model_name if model_path.exists(): shutil.rmtree(model_path.absolute())