Arm backend: Make aot_arm_compiler.py functions importable#18039
Arm backend: Make aot_arm_compiler.py functions importable#18039martinlsm wants to merge 1 commit intopytorch:mainfrom
Conversation
The model evaluation feature, i.e. to compute a model's top-1/top-5 accuracy, is to be moved into a new Python program. Some functions in aot_arm_compiler will then be needed to be imported by the evaluation program (to get a similar compilation flow and to reuse existing and working code). Make the functions in aot_arm_compiler "importable" by no longer passing in program args to them; the importing program will not have any such args. Also consistently categorize public/private functions and variables with the "leading underscore convention". Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: Iaf2a2661956fe1d122d6ef2bbe023e6b8ddc501c
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18039
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 Awaiting Approval, 9 New Failures, 1 Unrelated FailureAs of commit 63b46ff with merge base e458023 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label ciflow/trunk |
|
To add these label(s) (ciflow/trunk) to the PR, please first approve the workflows that are awaiting approval (scroll to the bottom of this page). This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows. |
|
@pytorchbot label "partner: arm" |
|
@pytorchbot label "release notes: none" |
There was a problem hiding this comment.
Pull request overview
This PR refactors examples/arm/aot_arm_compiler.py to make its public functions importable by an upcoming model evaluation program. It replaces the args namespace parameter with explicit individual parameters in several functions, introduces a QuantMode enum to replace the boolean is_int16x8 flag, and uses the leading underscore convention to categorize functions as public (importable) or private (internal).
Changes:
- Introduced a
QuantModeenum (INT8,A16W8) to replace theis_int16x8boolean flag, and updatedquantize()andquantize_model()to use it. - Renamed internal functions (
get_args,get_compile_spec,save_bpte_program,to_edge_TOSA_delegate,to_edge_no_delegate,to_edge_cortex_m) with leading underscores to mark them private, while keeping importable functions (quantize,quantize_model,get_model_and_inputs_from_name,dump_delegation_info) public. - Refactored
_to_edge_TOSA_delegate,_to_edge_no_delegate, andquantize_modelto accept explicit parameters instead of theargsnamespace, and movedcompile_specconstruction andquant_modedetermination into the__main__block.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _to_edge_no_delegate( | ||
| exported_program: ExportedProgram, | ||
| args, | ||
| compile_spec, | ||
| model: GraphModule, | ||
| quant_mode: Optional[QuantMode], | ||
| example_inputs: Tuple[torch.Tensor], | ||
| model_name: str, | ||
| strict_export: bool, | ||
| ): | ||
| model_quant = None | ||
| if args.quantize: | ||
| if quant_mode is not None: | ||
| # As we can target multiple output encodings, one must | ||
| # be specified. | ||
| compile_spec = get_compile_spec(args) | ||
| model, exported_program = quantize_model( | ||
| args, model, example_inputs, compile_spec | ||
| model, | ||
| example_inputs, | ||
| compile_spec, | ||
| model_name, | ||
| strict_export, | ||
| quant_mode, | ||
| ) | ||
| model_quant = model | ||
|
|
There was a problem hiding this comment.
_to_edge_no_delegate no longer receives args as a parameter (it was replaced by individual parameters), but the function body at line 815 still references args via _apply_replace_quant_nodes(edge, args). This will raise NameError when the function is called from an importing module rather than from __main__. Similar to _to_edge_TOSA_delegate, the required args fields (target and direct_drive, as used by _apply_replace_quant_nodes) need to be passed explicitly as parameters.
| match quant_mode: | ||
| case QuantMode.INT8: | ||
| operator_config = get_symmetric_quantization_config(is_per_channel=True) | ||
| case QuantMode.A16W8: | ||
| if compile_specs.tosa_spec.support_extension("int16"): | ||
| operator_config = get_symmetric_a16w8_quantization_config( | ||
| is_per_channel=True | ||
| ) | ||
| else: | ||
| raise ValueError( | ||
| f"Context TOSA spec {compile_specs.tosa_spec} doesn't support int16" | ||
| ) |
There was a problem hiding this comment.
The match statement does not have a default/wildcard case. If a new QuantMode variant is added in the future, operator_config will be unbound at line 264, causing an UnboundLocalError. Consider adding a wildcard case that raises a ValueError with a descriptive message (e.g., case _: raise ValueError(f"Unsupported quant_mode: {quant_mode}")).
| def _to_edge_TOSA_delegate( | ||
| exported_program: ExportedProgram, | ||
| args, | ||
| compile_spec, | ||
| model: GraphModule, | ||
| quant_mode: Optional[QuantMode], | ||
| example_inputs: Tuple[torch.Tensor], | ||
| model_name: str, | ||
| strict_export: bool, | ||
| ): | ||
| # As we can target multiple output encodings, one must | ||
| # be specified. | ||
| compile_spec = get_compile_spec(args) | ||
|
|
||
| model_quant = None | ||
| if args.quantize: | ||
| if quant_mode is not None: | ||
| model_quant, exported_program = quantize_model( | ||
| args, model, example_inputs, compile_spec | ||
| model, | ||
| example_inputs, | ||
| compile_spec, | ||
| model_name, | ||
| strict_export, | ||
| quant_mode, | ||
| ) | ||
|
|
||
| partitioner = create_partitioner(compile_spec) |
There was a problem hiding this comment.
_to_edge_TOSA_delegate no longer receives args as a parameter (it was replaced by individual parameters), but the function body at line 714 still references args via _apply_replace_quant_nodes(edge, args). This will resolve to the module-level args global when run as __main__, but will raise NameError when imported and called from another module — which is the use case this PR is enabling. The args dependency needs to be replaced, e.g., by passing the required values (args.target and args.direct_drive per _apply_replace_quant_nodes) explicitly as parameters to this function.
The model evaluation feature, i.e. to compute a model's top-1/top-5 accuracy, is to be moved into a new Python program. Some functions in aot_arm_compiler will then be needed to be imported by the evaluation program (to get a similar compilation flow and to reuse existing and working code). Make the functions in aot_arm_compiler "importable" by no longer passing in program args to them; the importing program will not have any such args. Also consistently categorize public/private functions and variables with the "leading underscore convention".
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell