From f2e674e3bb0c8eca81ad317dc01941184a1855cd Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 12 Mar 2026 10:38:55 -0700 Subject: [PATCH 1/2] Migrate executorch/ tests from exir.capture to torch.export + to_edge (#18111) Summary: Migrate test_pass_infra, test_debug_handle_map, test_delegate_map_builder, test_backends_nested, and hta_partitioner_demo to use the torch.export.export() + to_edge() flow instead of the deprecated exir.capture() API. Key changes: - Replace exir.capture(model, inputs, CaptureConfig()).to_edge() with to_edge(export(model, inputs, strict=True)) - Wrap plain functions in nn.Module for torch.export compatibility - Use dynamic debug handle extraction instead of hardcoded values in test_delegate_map_builder (handle numbering changed) - Collapse lifted/unlifted pattern variants in hta_partitioner_demo since torch.export always produces lifted graphs Differential Revision: D95605454 --- exir/backend/test/BUCK | 32 - exir/backend/test/hta_partitioner_demo.py | 85 +- exir/backend/test/test_backends.py | 1463 ----------------- exir/backend/test/test_backends_nested.py | 20 +- exir/backend/test/test_debug_handle_map.py | 43 +- .../backend/test/test_delegate_map_builder.py | 77 +- exir/tests/test_pass_infra.py | 43 +- pytest-windows.ini | 1 - pytest.ini | 1 - 9 files changed, 121 insertions(+), 1644 deletions(-) delete mode 100644 exir/backend/test/test_backends.py diff --git a/exir/backend/test/BUCK b/exir/backend/test/BUCK index 22d5f0b56ba..10278befea0 100644 --- a/exir/backend/test/BUCK +++ b/exir/backend/test/BUCK @@ -158,38 +158,6 @@ fbcode_target(_kind = runtime.python_library, ], ) -fbcode_target(_kind = runtime.python_test, - name = "test_backends", - srcs = [ - "test_backends.py", - ], - preload_deps = [ - "//executorch/configurations:optimized_native_cpu_ops", - "//executorch/kernels/quantized:custom_ops_generated_lib", - "//executorch/runtime/executor/test:test_backend_compiler_lib", - ], - deps = [ - ":backend_with_compiler_demo", - ":hta_partitioner_demo", - ":op_partitioner_demo", - ":demo_backend", - "//caffe2:torch", - "//caffe2/functorch:functorch_src", - "//executorch/exir:delegate", - "//executorch/exir:graph_module", - "//executorch/exir:lib", - "//executorch/exir:lowered_backend_module", - "//executorch/exir:print_program", - "//executorch/exir:schema", - "//executorch/exir/backend:backend_api", - "//executorch/exir/backend:compile_spec_schema", - "//executorch/exir/backend:partitioner", - "//executorch/exir/dialects:lib", - "//executorch/extension/pybindings:portable_lib", # @manual - "//executorch/extension/pytree:pylib", - ], -) - fbcode_target(_kind = runtime.python_test, name = "test_to_backend_multi_method", srcs = [ diff --git a/exir/backend/test/hta_partitioner_demo.py b/exir/backend/test/hta_partitioner_demo.py index ba42c50b0f7..b83b95622a1 100644 --- a/exir/backend/test/hta_partitioner_demo.py +++ b/exir/backend/test/hta_partitioner_demo.py @@ -9,6 +9,7 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( generate_pattern_op_partitions, ) @@ -20,7 +21,7 @@ ) from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.utils import tag_constant_data -from torch.export import ExportedProgram +from torch.export import export, ExportedProgram from torch.fx.passes.infra.partitioner import Partition @@ -63,56 +64,30 @@ def forward(self, x_raw, h, c): input_h = torch.ones([1, 32]) input_c = torch.ones([1, 32]) - pattern_lstm_conv_lifted = ( - exir.capture( - LSTMConvPattern(), - (input_x, input_h, input_c), - exir.CaptureConfig(enable_aot=True), - ) - .to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - .exported_program.graph_module - ) pattern_lstm_conv = ( - exir.capture( - LSTMConvPattern(), - (input_x, input_h, input_c), - exir.CaptureConfig(), - ) - .to_edge( + to_edge( + export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True), # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), ) - .exported_program.graph_module + .exported_program() + .graph_module ) - def sub(x, y): - return torch.sub(x, y) + class SubModule(torch.nn.Module): + def forward(self, x, y): + return torch.sub(x, y) - pattern_sub_lifted = ( - exir.capture( - sub, - (input_x, input_h), - exir.CaptureConfig(enable_aot=True, _unlift=False), - ) - .to_edge(exir.EdgeCompileConfig(_use_edge_ops=True)) - .exported_program.graph_module - ) pattern_sub = ( - exir.capture( - sub, - (input_x, input_h), - exir.CaptureConfig(), + to_edge( + export(SubModule(), (input_x, input_h), strict=True), + compile_config=exir.EdgeCompileConfig(_use_edge_ops=True), ) - .to_edge() - .exported_program.graph_module + .exported_program() + .graph_module ) self.patterns = [ - pattern_lstm_conv_lifted.graph, pattern_lstm_conv.graph, - pattern_sub_lifted.graph, pattern_sub.graph, ] @@ -239,33 +214,17 @@ def forward(self, x_raw, h, c): input_h = torch.ones([1, 32]) input_c = torch.ones([1, 32]) - pattern_lstm_conv_lifted = ( - exir.capture( - LSTMConvPattern(), - (input_x, input_h, input_c), - exir.CaptureConfig(enable_aot=True), - ) - .to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - .exported_program.graph_module - ) - pattern_lstm_conv_unlifted = ( - exir.capture( - LSTMConvPattern(), - (input_x, input_h, input_c), - exir.CaptureConfig(), - ) - .to_edge( + pattern_lstm_conv = ( + to_edge( + export(LSTMConvPattern(), (input_x, input_h, input_c), strict=True), # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), ) - .exported_program.graph_module + .exported_program() + .graph_module ) self.patterns = [ - pattern_lstm_conv_lifted.graph, - pattern_lstm_conv_unlifted.graph, + pattern_lstm_conv.graph, ] # Only (lstm + conv) pattern is lowerable diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py deleted file mode 100644 index fa124c855db..00000000000 --- a/exir/backend/test/test_backends.py +++ /dev/null @@ -1,1463 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import operator -import unittest -from typing import Dict, List - -import executorch.exir as exir -import torch -from executorch.exir import to_edge -from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend -from executorch.exir.backend.backend_details import BackendDetails -from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( - AllNodePartitioner, -) -from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.backend.partitioner import ( - DelegationSpec, - Partitioner, - PartitionResult, -) - -# import the backend implementation -from executorch.exir.backend.test.backend_with_compiler_demo import ( - BackendWithCompilerDemo, -) -from executorch.exir.backend.test.demo_backend import DemoBackend -from executorch.exir.backend.test.hta_partitioner_demo import ( - HTAPartitionerMultiplePatternsDemo, - HTAPartitionerOnePatternDemo, -) -from executorch.exir.backend.test.op_partitioner_demo import ( - AddAttributePartitionerDemo, - AddMulPartitionerDemo, -) - -from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.graph_module import get_control_flow_submodules -from executorch.exir.lowered_backend_module import get_lowered_submodules -from executorch.exir.print_program import print_program -from executorch.exir.schema import ( - BackendDelegate, - BackendDelegateDataReference, - DataLocation, - DelegateCall, - Program, -) - -from executorch.extension.pybindings.portable_lib import ( # @manual - _load_for_executorch_from_buffer, -) -from executorch.extension.pytree import tree_flatten - -from functorch.experimental import control_flow -from torch.ao.quantization import get_default_qconfig_mapping # @manual -from torch.ao.quantization.backend_config.executorch import ( - get_executorch_backend_config, -) -from torch.ao.quantization.quantize_fx import ( - _convert_to_reference_decomposed_fx, - prepare_fx, -) -from torch.export import ExportedProgram -from torch.testing import FileCheck - - -def vary_segments(test_method): - """A decorator that calls the test method with `extract_delegate_segments` set to - True and False. - - Decorated test methods must expect a boolean parameter named - `extract_delegate_segments`, and they should pass that value to to_executorch() like: - - m.to_executorch( - config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments) - ) - - This will cause the delegate data blobs to be extracted from the program and - serialized as separate, freeable program segments. Backends should detect no - difference at runtime. - """ - - def wrapper(self): - for extract_delegate_segments in [False, True]: - # subTest will create a different top-level test entry for each - # value, whose full names have a suffix like - # "(extract_delegate_segments=True)". - with self.subTest(extract_delegate_segments=extract_delegate_segments): - test_method(self, extract_delegate_segments=extract_delegate_segments) - - return wrapper - - -class TestBackends(unittest.TestCase): - def check_delegate_input( - self, delegate: LoweredBackendModule, input_len: int - ) -> None: - counter = 0 - for node in delegate.original_module.graph.nodes: - if node.op == "placeholder": - counter += 1 - self.assertEqual(counter, input_len) - - def check_backend_delegate( - self, - program: Program, - delegate: BackendDelegate, - expected_id: str, - expected_processed: bytes, - ) -> None: - self.assertEqual(delegate.id, expected_id) - processed: BackendDelegateDataReference = delegate.processed - self.assertEqual(processed.location, DataLocation.INLINE) - self.assertLess(processed.index, len(program.backend_delegate_data)) - self.assertEqual( - program.backend_delegate_data[processed.index].data, expected_processed - ) - - @vary_segments - def test_backend_with_compiler(self, extract_delegate_segments: bool): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - # TODO(chenlai): add a test with a diffrent method name when - # it's resolved in compiler side. - def forward(self, x): - return torch.sin(x) - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - max_value = model_inputs[0].shape[0] - compile_specs = [CompileSpec("max_value", bytes([max_value]))] - lowered_sin_module = to_backend( - "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs - ) - - class CompositeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered_linear_sin = lowered_sin_module - - def forward(self, x): - return self.lowered_linear_sin(x) - - composite_model = CompositeModule() - model_inputs = (torch.ones(1),) - - composite_model(*model_inputs) - - exec_prog = ( - exir.capture(composite_model, model_inputs, exir.CaptureConfig()) - .to_edge() - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ) - ) - ) - graph_module = exec_prog.dump_graph_module() - - # Check that there is not an aten.sin node. - self.assertTrue( - exir_ops.edge.aten.sin - not in {node.target for node in graph_module.graph.nodes} - ) - - # Check that there exists a call_delegate, representing the call to the - # delegated function - FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( - graph_module.code - ) - lowered_submodules = get_lowered_submodules(graph_module) - self.assertEqual(len(lowered_submodules), 1) - - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target == executorch_call_delegate: - # Check that first arg is lowered_module_{unique_id} - self.assertEqual(node.args[0].target, "lowered_module_0") - - program = exec_prog.program - - # Check the program can be printed - print_program(program) - - # Check the backend delegate - self.check_backend_delegate( - program=program, - delegate=program.execution_plan[0].delegates[0], - expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", - ) - - # Check the delegate instruction - self.assertTrue( - isinstance( - program.execution_plan[0].chains[0].instructions[0].instr_args, - DelegateCall, - ) - ) - buff = exec_prog.buffer - - executorch_module = _load_for_executorch_from_buffer(buff) - model_inputs = torch.ones(1) - model_outputs = executorch_module.forward([model_inputs]) - self.assertEqual( - model_inputs, - torch.ones(1), - ) - expected_output = 0.8333 * torch.ones(1) - - self.assertTrue( - torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03) - ) - - @vary_segments - def test_lowered_add_mul(self, extract_delegate_segments: bool): - class AddMulModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, x, b): - y = torch.mm(a, x) - z = torch.add(y, b) - return z - - add_mul_module = AddMulModule() - model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) - edge_graph_module = exir.capture( - add_mul_module, model_inputs, exir.CaptureConfig() - ).to_edge() - max_value = model_inputs[0].shape[0] - compile_specs = [CompileSpec("max_value", bytes([max_value]))] - lowered_add_mul = to_backend( - "BackendWithCompilerDemo", edge_graph_module.exported_program, compile_specs - ) - - class CompositeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered_add_mul = lowered_add_mul - - def forward(self, a, x, b): - return self.lowered_add_mul(a, x, b) - - composite_model = CompositeModule() - - composite_model(*model_inputs) - - exec_prog = ( - exir.capture(composite_model, model_inputs, exir.CaptureConfig()) - .to_edge() - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ) - ) - ) - buff = exec_prog.buffer - - executorch_module = _load_for_executorch_from_buffer(buff) - - inputs_flattened, _ = tree_flatten(model_inputs) - model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) - ref_output = add_mul_module(*model_inputs) - - self.assertTrue( - torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03) - ) - - def run_model_in_unsupported_backend(self, extract_delegate_segments: bool): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.sin(x) - - sin_module = SinModule() - # the backend only accepts shape <= 4 - model_inputs = (torch.ones(6),) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - max_value = model_inputs[0].shape[0] - compile_specs = [CompileSpec("max_value", bytes([max_value]))] - lowered_sin_module = to_backend( - "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs - ) - - class CompositeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered_linear_sin = lowered_sin_module - - def forward(self, x): - return self.lowered_linear_sin(x) - - composite_model = CompositeModule() - model_inputs = (torch.zeros(6),) - - composite_model(*model_inputs) - - exec_prog = ( - exir.capture(composite_model, model_inputs, exir.CaptureConfig()) - .to_edge() - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - ) - - buff = exec_prog.buffer - executorch_module = _load_for_executorch_from_buffer(buff) - # This line should raise an exception like - # RuntimeError: failed with error 0x12 - inputs_flattened, _ = tree_flatten(model_inputs) - executorch_module.run_method("forward", tuple(inputs_flattened)) - - @vary_segments - def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool): - with self.assertRaisesRegex( - RuntimeError, - "Failed to execute method forward, error: 0x12", - ): - self.run_model_in_unsupported_backend( - extract_delegate_segments=extract_delegate_segments - ) - - @vary_segments - def test_backend_with_compiler_delegate_and_operator( - self, extract_delegate_segments: bool - ): - # Test includes both delegates and operator - # import the backend implementation - from executorch.exir.backend.test.backend_with_compiler_demo import ( - BackendWithCompilerDemo, - ) - - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - # TODO(chenlai): add a test with a diffrent method name when - # it's resolved in compiler side. - def forward(self, x): - return [torch.sin(x)] - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - max_value = model_inputs[0].shape[0] - compile_specs = [CompileSpec("max_value", bytes([max_value]))] - lowered_sin_module = to_backend( - "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs - ) - - class CompositeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered_linear_sin = lowered_sin_module - - def forward(self, x): - a = self.lowered_linear_sin(x)[0] - b = self.lowered_linear_sin(x)[0] - return torch.add(a, b) - - composite_model = CompositeModule() - model_inputs = (torch.ones(1),) - - composite_model(*model_inputs) - - exec_prog = ( - exir.capture(composite_model, model_inputs, exir.CaptureConfig()) - .to_edge() - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - ) - graph_module = exec_prog.dump_graph_module() - program = exec_prog.program - buff = exec_prog.buffer - - # Check that there is not an aten.sin node. - self.assertTrue( - exir_ops.edge.aten.sin.default - not in {node.target for node in graph_module.graph.nodes} - ) - - # Check that there exists a call_delegate op, representing the call to the - # delegated function - FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( - graph_module.code - ) - - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target == executorch_call_delegate: - # Check that first arg is lowered_module_{unique_id} - self.assertEqual(node.args[0].target, "lowered_module_0") - - # Check the backend delegate - self.check_backend_delegate( - program=program, - delegate=program.execution_plan[0].delegates[0], - expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", - ) - - # Check the delegate instruction - self.assertTrue( - isinstance( - program.execution_plan[0].chains[0].instructions[0].instr_args, - DelegateCall, - ) - ) - - executorch_module = _load_for_executorch_from_buffer(buff) - model_inputs = torch.ones(1) - - model_outputs = executorch_module.forward([model_inputs]) - - self.assertEqual( - model_inputs, - torch.ones(1), - ) - expected_output = 1.666667 * torch.ones(1) - - self.assertTrue( - torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03) - ) - - def test_backend_with_compiler_backend_runtime_exception(self): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - # TODO(chenlai): add a test with a diffrent method name when - # it's resolved in compiler side. - def forward(self, x): - return torch.sin(x) + torch.cos(x) - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo" - - with self.assertRaisesRegex( - RuntimeError, - error_msg, - ): - _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program, []) - - def test_backend_with_compiler_backend_not_found_exception(self): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - # TODO(chenlai): add a test with a diffrent method name when - # it's resolved in compiler side. - def forward(self, x): - return torch.sin(x) + torch.cos(x) - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - error_msg = r"Backend FakeBackendWithCompilerDemo was not found." - - with self.assertRaisesRegex( - NotImplementedError, - error_msg, - ): - _ = to_backend("FakeBackendWithCompilerDemo", edgeir_m.exported_program, []) - - @vary_segments - def test_backend_with_compiler_delegate_and_operator_with_two_modules( - self, extract_delegate_segments: bool - ): - # the submodule runs in a specific backend. In this example, `BackendWithCompilerDemo` backend - class LowerableSubModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.sin(x) - - # sin_module is an nn.Module - to_be_lowered = LowerableSubModel() - example_input = (torch.ones(1),) - to_be_lowered_exir_submodule = exir.capture( - to_be_lowered, example_input, exir.CaptureConfig() - ).to_edge() - - max_value = example_input[0].shape[0] - compile_specs = [CompileSpec("max_value", bytes([max_value]))] - lowered_module = to_backend( - "BackendWithCompilerDemo", - to_be_lowered_exir_submodule.exported_program, - compile_specs, - ) - - class NonLowerableSubModel(torch.nn.Module): - def __init__(self, bias): - super().__init__() - self.bias = bias - - def forward(self, a, b): - return torch.add(torch.add(a, b), self.bias) - - # the composite modules, including lower part and non-lowerpart - class CompositeModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3) - self.lowerable = lowered_module - - def forward(self, x): - a = self.lowerable(x) - b = self.lowerable(a) - ret = self.non_lowerable(a, b) - return a, b, ret - - composite_model = CompositeModel() - - # Prepare the model input - model_inputs = (torch.ones(1),) - - # Verify the input works with eager module - composite_model(*model_inputs) - - exec_prog = ( - exir.capture(composite_model, model_inputs, exir.CaptureConfig()) - .to_edge() - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - ) - flatbuffer = exec_prog.buffer - - executorch_module = _load_for_executorch_from_buffer(flatbuffer) - model_outputs = executorch_module.forward([*model_inputs]) - - expected_outputs = [ - 0.8333 * torch.ones(1), - 0.7369 * torch.ones(1), - 1.8702 * torch.ones(1), - ] - - for index, expected_output in enumerate(expected_outputs): - self.assertTrue( - torch.allclose( - model_outputs[index], expected_output, atol=1e-03, rtol=1e-03 - ) - ) - - @vary_segments - def test_partition_delegate_graph_with_multiple_patterns( - self, extract_delegate_segments: bool - ): - class CompositeModel(torch.nn.Module): - def __init__(self, _weight): - super().__init__() - self.weight = _weight - self.lstm = torch.nn.LSTM( - input_size=32, - hidden_size=32, - num_layers=1, - ) - self.conv = torch.nn.Conv1d(1, 1, 1, stride=2) - - def forward(self, x_raw, h, c): - output, (hn, cn) = self.lstm(x_raw, (h, c)) - k = self.conv(output) - x = output - y = cn - a = torch.sub(x, y) - b = torch.sub(x, a) - c = torch.sub(x, b) - d = torch.add(x, self.weight) - e = torch.mul(c, d) - return e, hn, k - - # Prepare input and trace it - input_x = torch.ones([1, 32]) - input_h = torch.ones([1, 32]) - input_c = torch.ones([1, 32]) - inputs = (input_x, input_h, input_c) - - composite_m = CompositeModel(3) - orig_res = composite_m(*inputs) - - traced = exir.capture(composite_m, inputs, exir.CaptureConfig()).to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - - program_without_delegates = ( - exir.capture(CompositeModel(3), inputs) - .to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - ) - # after this step, part of the graph will be lowered to backend, depending on - # HTAPartitionerDemo's rule. - program_with_delegates = traced - program_with_delegates.exported_program = to_backend( - traced.exported_program, HTAPartitionerMultiplePatternsDemo() - ) - program_with_delegates = program_with_delegates.to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - - new_res = program_with_delegates.dump_graph_module()(*inputs) - for t1, t2 in zip(new_res, orig_res, strict=True): - self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03)) - - # Check the backend delegate - self.check_backend_delegate( - program=program_with_delegates.program, - delegate=program_with_delegates.program.execution_plan[0].delegates[0], - expected_id=DemoBackend.__name__, - expected_processed=b"imqnncompiled", - ) - - # Check add not in the program with delegates - self.assertEqual( - 0, - len( - [ - op - for op in program_with_delegates.program.execution_plan[0].operators - if op.name == "aten::sub" - ] - ), - ) - - # Check convolution not in the program with delegates - self.assertEqual( - 0, - len( - [ - op - for op in program_with_delegates.program.execution_plan[0].operators - if op.name == "aten::convolution" - ] - ), - ) - - # Check convolution in the program without delegates - self.assertEqual( - 1, - len( - [ - op - for op in program_without_delegates.program.execution_plan[ - 0 - ].operators - if op.name == "aten::convolution" - ] - ), - ) - - @vary_segments - def test_partition_delegate_graph_with_one_patterns( - self, extract_delegate_segments: bool - ): - class CompositeModel(torch.nn.Module): - def __init__(self, _weight): - super().__init__() - self.weight = _weight - self.lstm = torch.nn.LSTM( - input_size=32, - hidden_size=32, - num_layers=1, - ) - self.conv = torch.nn.Conv1d(1, 1, 1, stride=2) - - def forward(self, x_raw, h, c): - output, (hn, cn) = self.lstm(x_raw, (h, c)) - k = self.conv(output) - x = output - y = cn - a = torch.sub(x, y) - b = torch.sub(x, a) - c = torch.sub(x, b) - d = torch.add(x, self.weight) - e = torch.mul(c, d) - return e, hn, k - - # Prepare input and trace it - input_x = torch.ones([1, 32]) - input_h = torch.ones([1, 32]) - input_c = torch.ones([1, 32]) - inputs = (input_x, input_h, input_c) - - composite_m = CompositeModel(3) - orig_res = composite_m(*inputs) - - traced = exir.capture( - composite_m, - inputs, - exir.CaptureConfig(), - ).to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - - program_without_delegates = ( - exir.capture( - CompositeModel(3), - (input_x, input_h, input_c), - exir.CaptureConfig(), - ) - .to_edge( - # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical. - exir.EdgeCompileConfig(_check_ir_validity=False) - ) - .to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - ) - # after this step, part of the graph will be lowered to backend, depending on - # HTAPartitionerDemo's rule. - traced_with_delegate = traced - traced_with_delegate.exported_program = to_backend( - traced.exported_program, HTAPartitionerOnePatternDemo() - ) - - new_res = traced_with_delegate(*inputs) - for t1, t2 in zip(new_res, orig_res, strict=True): - self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03)) - - program_with_delegates = traced_with_delegate.to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - - # TODO(T143084047): Currently not retraceable - # Retracing is not needed, but keeping this here to make sure the result - # of to_backend is retraceable - # graph_module_with_delegate = exir.capture( - # traced_with_delegate, - # (input_x, input_h, input_c), - # exir.CaptureConfig(), - # ).to_edge() - - # program_with_delegates = graph_module_with_delegate.to_executorch( - # config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments), - # ) - - new_res = program_with_delegates.dump_graph_module()(*inputs) - for t1, t2 in zip(new_res, orig_res, strict=True): - self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03)) - - # Check the backend delegate - self.check_backend_delegate( - program=program_with_delegates.program, - delegate=program_with_delegates.program.execution_plan[0].delegates[0], - expected_id=DemoBackend.__name__, - expected_processed=b"imqnncompiled", - ) - - # Check add is in the program with delegates - self.assertEqual( - 1, - len( - [ - op - for op in program_with_delegates.program.execution_plan[0].operators - if op.name == "aten::sub" - ] - ), - ) - - # Check convolution not in the program with delegates - self.assertEqual( - 0, - len( - [ - op - for op in program_with_delegates.program.execution_plan[0].operators - if op.name == "aten::convolution" - ] - ), - ) - - # Check convolution in the program without delegates - self.assertEqual( - 1, - len( - [ - op - for op in program_without_delegates.program.execution_plan[ - 0 - ].operators - if op.name == "aten::convolution" - ] - ), - ) - - @vary_segments - def test_add_mul_partitioner(self, extract_delegate_segments: bool): - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, x, b): - y = torch.mm(a, x) - z = y + b - a = z - a - y = torch.mm(a, x) - z = y + b - return z - - m = Model() - inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) - orig_res = m(*inputs) - - ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge() - executorch_prog = ep - executorch_prog.exported_program = to_backend( - ep.exported_program, AddMulPartitionerDemo() - ) - - for node in executorch_prog.exported_program.graph.nodes: - if node.op == "call_function" and node.target is executorch_call_delegate: - for user in node.users: - self.assertTrue( - user.op == "call_function" and user.target == operator.getitem - ) - self.assertTrue(user.meta.get("source_fn_stack", None) is None) - self.assertTrue(user.meta.get("nn_module_stack", None) is None) - - executorch_prog = executorch_prog.to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - - new_res = executorch_prog.dump_graph_module()(*inputs) - self.assertTrue(torch.allclose(new_res[0], orig_res)) - - counter = 0 - for node in executorch_prog.dump_graph_module().graph.nodes: - if node.op == "get_attr": - self.assertEqual(node.target, f"lowered_module_{counter}") - counter += 1 - # There should be 2 delegated modules - self.assertEqual(counter, 2) - - executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer) - inputs_flattened, _ = tree_flatten(inputs) - model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) - ref_output = m(*inputs) - - self.assertTrue( - torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03), - ) - - @vary_segments - def test_partitioner_with_attributes(self, extract_delegate_segments: bool): - """ - Check that if we tag the getattr nodes, the attributes will be added to - the lowered submodule rather than being passed into the delegate as - inputs. - """ - - class AddOne(torch.nn.Module): - def __init__(self): - super().__init__() - self.one = torch.ones(1, 3) - - def forward(self, x): - return x + self.one - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.add_one = AddOne() - - def forward(self, x, y): - x = self.add_one(x) * y - return self.add_one(x), self.add_one(y) - - inputs = (torch.randn(1, 3), torch.randn(1, 3)) - orig_res = Model()(*inputs) - ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge() - executorch_prog = ep - executorch_prog.exported_program = to_backend( - ep.exported_program, AddAttributePartitionerDemo() - ) - - for node in executorch_prog.exported_program.graph.nodes: - if node.op == "call_function" and node.target is executorch_call_delegate: - for user in node.users: - self.assertTrue( - user.op == "call_function" and user.target == operator.getitem - ) - self.assertTrue(user.meta.get("source_fn_stack", None) is None) - self.assertTrue(user.meta.get("nn_module_stack", None) is None) - - executorch_prog = executorch_prog.to_executorch( - config=exir.ExecutorchBackendConfig( - extract_delegate_segments=extract_delegate_segments - ), - ) - - # Check the delegated submodules - lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module()) - self.assertEqual(len(lowered_submodules), 2) - # Attributes should be stored in the lowered module - self.check_delegate_input(lowered_submodules[0][1], 1) - self.check_delegate_input(lowered_submodules[1][1], 2) - - executorch_prog.buffer - - new_res = executorch_prog.dump_graph_module()(*inputs) - self.assertTrue(torch.allclose(orig_res[0], new_res[0])) - self.assertTrue(torch.allclose(orig_res[1], new_res[1])) - - def test_bad_partitioner(self): - """ - Checks that we throw an error if user provided partitioner modifies the - graph module - """ - inputs = (torch.randn(1, 3), torch.randn(1, 3)) - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - x = x + y - x = x * y - x = x - y - x = x / y - x = x * y - x = x + y - return x - - class BadPartitioner(Partitioner): - def partition(self, exported_program: ExportedProgram) -> PartitionResult: - # Partitioner should not modify the given graph module - for node in exported_program.graph.nodes: - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.add.Tensor - ): - node.target = exir_ops.edge.aten.mul.Tensor - return PartitionResult( - tagged_exported_program=exported_program, - partition_tags={ - "tag1": DelegationSpec("BackendWithCompilerDemo", []) - }, - ) - - ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge() - with self.assertRaises(AssertionError): - _ = to_backend(ep.exported_program, BadPartitioner()) - - def test_quantized_with_delegate(self) -> None: - torch.ops.load_library( - "//executorch/kernels/quantized:custom_ops_generated_lib" - ) - qconfig_mapping = get_default_qconfig_mapping("qnnpack") - in_size = 2 - input_size = 3 - output_size = 4 - linear = torch.nn.Linear(input_size, output_size).eval() - example_inputs = (torch.ones(in_size, input_size),) - prepared_linear = prepare_fx( - linear, - qconfig_mapping, - example_inputs, - backend_config=get_executorch_backend_config(), - ) - converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx( - prepared_linear, - ) - - # fails to trace here - converted_linear_gm = exir.capture( - converted_linear, - example_inputs, - exir.CaptureConfig( - enable_aot=True, - ), - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) - FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run( - converted_linear_gm.exported_program.graph_module.code - ) - - def test_partition_with_control_flow(self) -> None: - def true_fn(x, y): - x = x - y - x = x + y - x = x - y - return x - - def false_fn(x, y): - x = x - y - x = torch.mm(x, y) - x = x - y - return x - - def f(x, y): - x = x + y - x = torch.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) - x = x - y - return x - - inputs = (torch.ones(2, 2), torch.ones(2, 2)) - orig_res = f(*inputs) - orig = exir.capture( - f, - inputs, - exir.CaptureConfig(), - ).to_edge() - partitioned = orig - partitioned.exported_program = to_backend( - orig.exported_program, AddMulPartitionerDemo() - ) - - new_res = partitioned(*inputs) - self.assertTrue(torch.allclose(orig_res, new_res[0])) - - toplevel_lowered = get_lowered_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(toplevel_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run( - toplevel_lowered[0][1].original_module.graph_module.code - ) - - # Toplevel module only has the cond submodules - partitioned_submodules = get_control_flow_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(partitioned_submodules), 2) - - true_gm = partitioned_submodules[0][1] - true_lowered = get_lowered_submodules(true_gm) - self.assertEqual(len(true_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run( - true_lowered[0][1].original_module.graph_module.code - ) - - false_gm = partitioned_submodules[1][1] - false_lowered = get_lowered_submodules(false_gm) - self.assertEqual(len(true_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run( - false_lowered[0][1].original_module.graph_module.code - ) - - def test_partition_with_map(self) -> None: - def map_fn(x, y): - x = x - y - x = x + y - return x - - def f(xs, y): - y = torch.mm(y, y) - return control_flow.map(map_fn, xs, y) - - inputs = (torch.ones(2, 2), torch.ones(2, 2)) - orig_res = f(*inputs) - orig = exir.capture( - f, - inputs, - exir.CaptureConfig(), - ).to_edge() - partitioned = orig - partitioned.exported_program = to_backend( - orig.exported_program, AddMulPartitionerDemo() - ) - - toplevel_lowered = get_lowered_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(toplevel_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run( - toplevel_lowered[0][1].original_module.graph_module.code - ) - - # Toplevel module only has the map submodule - partitioned_submodules = get_control_flow_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(partitioned_submodules), 1) - - map_fn_gm = partitioned_submodules[0][1] - map_fn_lowered = get_lowered_submodules(map_fn_gm) - self.assertEqual(len(map_fn_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run( - map_fn_lowered[0][1].original_module.graph_module.code - ) - - new_res = partitioned(*inputs) - - self.assertTrue(torch.allclose(orig_res, new_res[0])) - - def test_partition_with_nested_control_flow(self) -> None: - """ - Partitions the add and mul ops, including the ones inside the submodules - """ - - def true_nested(y): - y = y + y - y = torch.mm(y, y) - return y - - def false_nested(y): - return torch.mm(y, y) - - def true_fn(x, pred2): - z = control_flow.cond(pred2, true_nested, false_nested, [x]) - return x + z - - def false_fn(x, _): - return x.cos() - - def map_fn(x, pred1, pred2, y): - x = x.cos() - y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) - x = x + y - return x.sin() - - def f(xs, pred1, pred2, y): - y = torch.mm(y, y) - return control_flow.map(map_fn, xs, pred1, pred2, y) - - inputs = ( - torch.ones(2, 2), - torch.tensor([False]), - torch.Tensor([False]), - torch.ones(2, 2), - ) - - orig_res = f(*inputs) - orig = exir.capture( - f, - inputs, - exir.CaptureConfig(), - ).to_edge() - partitioned = orig - partitioned.exported_program = to_backend( - orig.exported_program, AddMulPartitionerDemo() - ) - - new_res = partitioned(*inputs) - self.assertTrue(torch.allclose(orig_res, new_res[0])) - - toplevel_lowered = get_lowered_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(toplevel_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run( - toplevel_lowered[0][1].original_module.graph_module.code - ) - - # Toplevel module only has the map submodule - partitioned_submodules = get_control_flow_submodules( - partitioned.exported_program.graph_module - ) - self.assertEqual(len(partitioned_submodules), 1) - - # Map module has the cond submodules - map_submodules = get_control_flow_submodules(partitioned_submodules[0][1]) - self.assertEqual(len(map_submodules), 2) - - # True module - true_module = map_submodules[0][1] - true_lowered = get_lowered_submodules(true_module) - self.assertEqual(len(true_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run( - true_lowered[0][1].original_module.graph_module.code - ) - - # False module - false_lowered = get_lowered_submodules(map_submodules[1][1]) - self.assertEqual(len(false_lowered), 0) - - # True module has the nested cond submodules - true_submodules = get_control_flow_submodules(true_module) - self.assertEqual(len(true_submodules), 2) - - # Nested True module - true_true_lowered = get_lowered_submodules(true_submodules[0][1]) - self.assertEqual(len(true_true_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check( - "executorch_exir_dialects_edge__ops_aten_mm_default" - ).run(true_true_lowered[0][1].original_module.graph_module.code) - - # Nested False module - true_false_lowered = get_lowered_submodules(true_submodules[1][1]) - self.assertEqual(len(true_false_lowered), 1) - FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run( - true_false_lowered[0][1].original_module.graph_module.code - ) - - def test_list_input(self): - def f(x: List[torch.Tensor]): - y = x[0] + x[1] - return y - - inputs = ([torch.randn(2, 2), torch.randn(2, 2)],) - edge_prog = exir.capture(f, inputs, exir.CaptureConfig()).to_edge() - lowered_gm = to_backend( - BackendWithCompilerDemo.__name__, edge_prog.exported_program, [] - ) - - class ComposedM(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered = lowered_gm - - def forward(self, x: List[torch.Tensor]): - return self.lowered(x) - - gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge() - gm(*inputs) - - def test_dict_input(self): - class M(torch.nn.Module): - def forward(self, x: Dict[str, torch.Tensor]): - y = x["a"] + x["b"] - return y - - inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},) - edge_prog = exir.to_edge(torch.export.export(M(), inputs, strict=True)) - lowered_gm = to_backend( - BackendWithCompilerDemo.__name__, edge_prog.exported_program(), [] - ) - - class ComposedM(torch.nn.Module): - def __init__(self): - super().__init__() - self.lowered = lowered_gm - - def forward(self, x: List[torch.Tensor]): - return self.lowered(x) - - gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge() - gm(*inputs) - - def test_to_backend_delegation_spec(self): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return [torch.sin(x)] - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - max_value = model_inputs[0].shape[0] - - partitioner = AllNodePartitioner( - "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))] - ) - - edgeir_m = to_edge(torch.export.export(sin_module, model_inputs)) - edgeir_m = edgeir_m.to_backend(partitioner) - exec_prog = edgeir_m.to_executorch() - graph_module = exec_prog.exported_program().graph_module - # Check that there is not an aten.sin node. - self.assertTrue( - exir_ops.edge.aten.sin - not in {node.target for node in graph_module.graph.nodes} - ) - - # Check that there exists a call_delegate, representing the call to the - # delegated function - FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( - graph_module.code - ) - lowered_submodules = get_lowered_submodules(graph_module) - self.assertEqual(len(lowered_submodules), 1) - - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target == executorch_call_delegate: - # Check that first arg is lowered_module_{unique_id} - self.assertEqual(node.args[0].target, "lowered_module_0") - - program = exec_prog.executorch_program - - # Check the program can be printed - print_program(program) - - # Check the backend delegate - self.check_backend_delegate( - program=program, - delegate=program.execution_plan[0].delegates[0], - expected_id=BackendWithCompilerDemo.__name__, - expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", - ) - - # Check the delegate instruction - self.assertTrue( - isinstance( - program.execution_plan[0].chains[0].instructions[0].instr_args, - DelegateCall, - ) - ) - buff = exec_prog.buffer - - executorch_module = _load_for_executorch_from_buffer(buff) - model_inputs = torch.ones(1) - model_outputs = executorch_module.forward([model_inputs]) - self.assertEqual( - model_inputs, - torch.ones(1), - ) - expected_output = 0.8333 * torch.ones(1) - - self.assertTrue( - torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03) - ) - - def test_to_backend_multimethod_delegation_spec(self): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.sin(x) - - def inputs(self): - return (torch.ones(1),) - - class AddMulModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, a, x, b): - y = torch.mm(a, x) - z = torch.add(y, b) - return z - - def inputs(self): - return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) - - sin_module = SinModule() - max_value_sin = sin_module.inputs()[0].shape[0] - sin_partitioner = AllNodePartitioner( - "BackendWithCompilerDemo", - [CompileSpec("max_value", bytes([max_value_sin]))], - ) - - add_mul_module = AddMulModule() - max_value_add_mul = add_mul_module.inputs()[0].shape[0] - add_mul_partitioner = AllNodePartitioner( - "BackendWithCompilerDemo", - [CompileSpec("max_value", bytes([max_value_add_mul]))], - ) - - edgeir_m = to_edge( - { - "sin": torch.export.export(sin_module, sin_module.inputs()), - "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()), - } - ) - edgeir_m = edgeir_m.to_backend( - { - "sin": sin_partitioner, - "add_mul": add_mul_partitioner, - } - ) - exec_prog = edgeir_m.to_executorch() - - for method_name in ["sin", "add_mul"]: - graph_module = exec_prog.exported_program(method_name).graph_module - # Check delegated nodes are gone - self.assertTrue( - exir_ops.edge.aten.sin - not in {node.target for node in graph_module.graph.nodes} - ) - self.assertTrue( - exir_ops.edge.aten.add - not in {node.target for node in graph_module.graph.nodes} - ) - self.assertTrue( - exir_ops.edge.aten.mm - not in {node.target for node in graph_module.graph.nodes} - ) - # Check that there exists a call_delegate, representing the call to the - # delegated function - FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( - graph_module.code - ) - lowered_submodules = get_lowered_submodules(graph_module) - self.assertEqual(len(lowered_submodules), 1) - - program = exec_prog.executorch_program - - # Check the program can be printed - print_program(program) - - buff = exec_prog.buffer - - executorch_module = _load_for_executorch_from_buffer(buff) - - for method_name, module in { - "sin": sin_module, - "add_mul": add_mul_module, - }.items(): - inputs_flattened, _ = tree_flatten(module.inputs()) - model_outputs = executorch_module.run_method( - method_name, tuple(inputs_flattened) - ) - - if method_name == "sin": - # backend with compiler demo does a taylor approximation of sin - ref_output = 0.8333 * torch.ones(1) - else: - ref_output = module(*module.inputs()) - self.assertTrue( - torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03) - ) - - def test_prohibited_nested_backends(self): - class MyBackend(BackendDetails): - @staticmethod - def preprocess(edge_program, compile_specs): - return None - - with self.assertRaises(TypeError) as ctx: - - class MyOtherBackend(MyBackend): - pass - - self.assertIn( - "'MyBackend' should be a final backend implementation and should not be subclassed (attempted by 'MyOtherBackend')", - str(ctx.exception), - ) diff --git a/exir/backend/test/test_backends_nested.py b/exir/backend/test/test_backends_nested.py index 5751706959b..c4bf44aab44 100644 --- a/exir/backend/test/test_backends_nested.py +++ b/exir/backend/test/test_backends_nested.py @@ -12,6 +12,7 @@ import torch +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -33,7 +34,7 @@ from executorch.exir.graph_module import _get_submodule, get_control_flow_submodules from executorch.exir.lowered_backend_module import get_lowered_submodules from functorch.experimental import control_flow -from torch.export import ExportedProgram +from torch.export import export, ExportedProgram from torch.fx.passes.operator_support import any_chain, OperatorSupportBase @@ -221,23 +222,22 @@ def test(self) -> None: m = M() orig_res = m(*m.get_example_inputs()) - orig = exir.capture( - m, - m.get_example_inputs(), - exir.CaptureConfig(), - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) + orig = to_edge( + export(m, m.get_example_inputs(), strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), + ) partitioned = orig - partitioned.exported_program = to_backend( - orig.exported_program, Backend1PartitionerDemo() + partitioned._edge_programs["forward"] = to_backend( + orig.exported_program(), Backend1PartitionerDemo() ) - new_res = partitioned(*m.get_example_inputs())[0] + new_res = partitioned.exported_program().module()(*m.get_example_inputs())[0] self.assertTrue(torch.allclose(orig_res, new_res)) # The toplevel module should have lowered the cond and add op toplevel_lowered = get_lowered_submodules( - partitioned.exported_program.graph_module + partitioned.exported_program().graph_module ) self.assertEqual(len(toplevel_lowered), 1) toplevel_lowered = toplevel_lowered[0][1] diff --git a/exir/backend/test/test_debug_handle_map.py b/exir/backend/test/test_debug_handle_map.py index a82207239ac..df4af5afd34 100644 --- a/exir/backend/test/test_debug_handle_map.py +++ b/exir/backend/test/test_debug_handle_map.py @@ -10,11 +10,13 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.test.demo_backend import DemoBackend from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo from executorch.exir.delegate import executorch_call_delegate -from hypothesis import given, settings, strategies as st +from hypothesis import settings +from torch.export import export class TestBackendDebugHandle(unittest.TestCase): @@ -34,14 +36,14 @@ def forward(self, a, x, b): m = Model() inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2)) - ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge() + ep = to_edge(export(m, inputs, strict=True)) executorch_prog = ep - executorch_prog.exported_program = to_backend( - ep.exported_program, AddMulPartitionerDemo() + executorch_prog._edge_programs["forward"] = to_backend( + ep.exported_program(), AddMulPartitionerDemo() ) lowered_nodes = [ - getattr(executorch_prog.exported_program.graph_module, node.target) - for node in executorch_prog.exported_program.graph.nodes + getattr(executorch_prog.exported_program().graph_module, node.target) + for node in executorch_prog.exported_program().graph.nodes if node.op == "get_attr" ] for lowered_node in lowered_nodes: @@ -49,18 +51,15 @@ def forward(self, a, x, b): call_delegate_nodes = [ node - for node in executorch_prog.exported_program.graph.nodes + for node in executorch_prog.exported_program().graph.nodes if node.target == executorch_call_delegate ] for call_delegate_node in call_delegate_nodes: self.assertIsNotNone(call_delegate_node.meta["debug_handle"]) - @given( - unlift=st.booleans(), # verify both lifted and unlifted graph - ) @settings(deadline=500000) - def test_lowered_the_whole_model(self, unlift): + def test_lowered_the_whole_model(self): module_list = [ models.Emformer(), models.Repeat(), @@ -69,10 +68,6 @@ def test_lowered_the_whole_model(self, unlift): models.ModelWithUnusedArg(), ] - capture_config = ( - exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig() - ) - edge_compile_config = exir.EdgeCompileConfig( _check_ir_validity=False, _use_edge_ops=True ) @@ -80,11 +75,12 @@ def test_lowered_the_whole_model(self, unlift): for model in module_list: model_inputs = model.get_random_inputs() - edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge( - edge_compile_config + edgeir_m = to_edge( + export(model, model_inputs, strict=True), + compile_config=edge_compile_config, ) lowered_model = to_backend( - DemoBackend.__name__, edgeir_m.exported_program, [] + DemoBackend.__name__, edgeir_m.exported_program(), [] ) # DemoBackend compile all nodes as one node. The debug_handle_map will be like (1: (debug handle from all nodes)) @@ -114,12 +110,13 @@ def __init__(self, lowered_model): def forward(self, *args): return self.back_bone(*args) - edge = exir.capture( - ComposedModel(lowered_model), model_inputs, capture_config - ).to_edge(edge_compile_config) + edge = to_edge( + export(ComposedModel(lowered_model), model_inputs, strict=True), + compile_config=edge_compile_config, + ) lowered_nodes = [ - getattr(edge.exported_program.graph_module, node.target) - for node in edge.exported_program.graph.nodes + getattr(edge.exported_program().graph_module, node.target) + for node in edge.exported_program().graph.nodes if node.op == "get_attr" ] for lowered_node in lowered_nodes: diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py index 2c30e4d9531..5d63493684c 100644 --- a/exir/backend/test/test_delegate_map_builder.py +++ b/exir/backend/test/test_delegate_map_builder.py @@ -9,12 +9,14 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.test.backend_with_delegate_mapping_demo import ( BackendWithDelegateMappingDemo, ) from executorch.exir.backend.utils import DelegateMappingBuilder +from torch.export import export class TestDelegateMapBuilder(unittest.TestCase): @@ -30,46 +32,61 @@ def forward(self, x): model = Model() model_inputs = (torch.ones(1, 1),) - program = ( - exir.capture(model, model_inputs, exir.CaptureConfig(pt2_mode=True)) - .to_edge() - .to_executorch() - ) + program = to_edge(export(model, model_inputs, strict=True)).to_executorch() # Create nodes for testing mapping - # nodes: [arg0_1, alloc, aten_sin_default, alloc_1, aten_cos_default, output] - # debug handles: [None, None, 1, None, 2, None] - self.nodes = list(program.graph_module.graph.nodes) + self.nodes = list(program.exported_program().graph_module.graph.nodes) self.handles = [node.meta.get("debug_handle") for node in self.nodes] + # Extract the actual debug handle values for sin and cos nodes + non_none_handles = [h for h in self.handles if h is not None] + self.sin_handle = non_none_handles[0] + self.cos_handle = non_none_handles[1] + # Find the index of the sin node in self.nodes + self.sin_node_idx = next( + i + for i, n in enumerate(self.nodes) + if n.meta.get("debug_handle") == self.sin_handle + ) + self.cos_node_idx = next( + i + for i, n in enumerate(self.nodes) + if n.meta.get("debug_handle") == self.cos_handle + ) def test_basic_generated_identifier(self): delegate_builder = DelegateMappingBuilder(generated_identifiers=True) + sh, ch = self.sin_handle, self.cos_handle - expected_mapping = {0: (1, 2)} + expected_mapping = {0: (sh, ch)} self.assertEqual( delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0 ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (1, 2), 1: (1,)} + expected_mapping = {0: (sh, ch), 1: (sh,)} self.assertEqual( - delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[2]), 1 + delegate_builder.insert_delegate_mapping_entry( + nodes=self.nodes[self.sin_node_idx] + ), + 1, ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) - expected_mapping = {0: (1, 2), 1: (1,), 2: (2,)} + expected_mapping = {0: (sh, ch), 1: (sh,), 2: (ch,)} self.assertEqual( - delegate_builder.insert_delegate_mapping_entry(handles=self.handles[4]), + delegate_builder.insert_delegate_mapping_entry( + handles=self.handles[self.cos_node_idx] + ), 2, ) self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping) expected_mapping = { - 0: (1, 2), - 1: (1,), - 2: (2,), - 3: (1, 2), + 0: (sh, ch), + 1: (sh,), + 2: (ch,), + 3: (sh, ch), } self.assertEqual( delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3 @@ -115,7 +132,7 @@ def test_omitting_identifier_when_not_generated(self): def test_reinsert_delegate_debug_identifier(self): delegate_builder = DelegateMappingBuilder() delegate_builder.insert_delegate_mapping_entry( - nodes=self.nodes[2], identifier="1" + nodes=self.nodes[self.sin_node_idx], identifier="1" ) self.assertRaises( @@ -134,23 +151,24 @@ def test_reinsert_delegate_debug_identifier(self): self.assertRaises( Exception, lambda: delegate_builder.insert_delegate_mapping_entry( - nodes=self.nodes[2], identifier="1" + nodes=self.nodes[self.sin_node_idx], identifier="1" ), ) self.assertRaises( Exception, lambda: delegate_builder.insert_delegate_mapping_entry( - handles=self.handles[2], identifier="1" + handles=self.handles[self.sin_node_idx], identifier="1" ), ) def test_backend_with_delegate_mapping(self) -> None: model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs() - edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge( - exir.EdgeCompileConfig(_check_ir_validity=False) + edgeir_m = to_edge( + export(model, inputs, strict=True), + compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), ) lowered_module = to_backend( - "BackendWithDelegateMappingDemo", edgeir_m.exported_program, [] + "BackendWithDelegateMappingDemo", edgeir_m.exported_program(), [] ) debug_handle_map = lowered_module.meta.get("debug_handle_map") self.assertIsNotNone(debug_handle_map) @@ -172,9 +190,7 @@ def forward(self, x): composite_model = CompositeModule() # TODO: Switch this to lowered_module.program() once lowered_module has support # for storing debug delegate identifier maps. - exir.capture( - composite_model, inputs, exir.CaptureConfig() - ).to_edge().to_executorch() + to_edge(export(composite_model, inputs, strict=True)).to_executorch() def test_passing_both_nodes_and_handles(self): delegate_builder = DelegateMappingBuilder() @@ -211,10 +227,11 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): delegate_builder_nodes = DelegateMappingBuilder() delegate_builder_handles = DelegateMappingBuilder() + sh, ch = self.sin_handle, self.cos_handle # Entry with a list of nodes iden_1 = next(identifiers) - expected_mapping = {iden_1: (1, 2)} + expected_mapping = {iden_1: (sh, ch)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( nodes=self.nodes, identifier=iden_1 @@ -236,16 +253,16 @@ def _test_basic_manual_identifier(self, identifiers: Iterator[Union[int, str]]): # Entry with a single node iden_2 = next(identifiers) - expected_mapping = {iden_1: (1, 2), iden_2: (1,)} + expected_mapping = {iden_1: (sh, ch), iden_2: (sh,)} self.assertEqual( delegate_builder_nodes.insert_delegate_mapping_entry( - nodes=self.nodes[2], identifier=iden_2 + nodes=self.nodes[self.sin_node_idx], identifier=iden_2 ), iden_2, ) self.assertEqual( delegate_builder_handles.insert_delegate_mapping_entry( - handles=self.handles[2], identifier=iden_2 + handles=self.handles[self.sin_node_idx], identifier=iden_2 ), iden_2, ) diff --git a/exir/tests/test_pass_infra.py b/exir/tests/test_pass_infra.py index 0b9ba223f0d..c3788a5a38e 100644 --- a/exir/tests/test_pass_infra.py +++ b/exir/tests/test_pass_infra.py @@ -8,12 +8,12 @@ import unittest -import executorch.exir as exir - import torch +from executorch.exir import to_edge from executorch.exir.pass_manager import PassManager from executorch.exir.passes import ScalarToTensorPass from executorch.exir.passes.pass_registry import PassRegistry +from torch.export import export from torch.fx.passes.infra.pass_base import PassBase @@ -99,15 +99,16 @@ def replace_mul_with_div(gm: torch.fx.GraphModule) -> None: if node.op == "call_function" and node.target == torch.mul: node.target = torch.div - def f(x: torch.Tensor) -> torch.Tensor: - y = torch.add(x, x) - z = torch.add(y, x) - return z + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z f = ( - exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module + to_edge(export(AddModule(), (torch.randn(10),), strict=True)) + .exported_program() + .graph_module ) pm = PassManager(passes=[replace_add_with_mul, replace_mul_with_div]) self.assertEqual(len(pm.passes), 2) @@ -144,15 +145,16 @@ def introduce_call_module(gm: torch.fx.GraphModule) -> None: new_node = gm.graph.call_module("foo", (torch.randn(2),)) node.replace_all_uses_with(new_node) - def f(x: torch.Tensor) -> torch.Tensor: - y = torch.add(x, x) - z = torch.add(y, x) - return z + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z traced_f1 = ( - exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) - .to_edge() - .exported_program.graph_module + to_edge(export(AddModule(), (torch.randn(10),), strict=True)) + .exported_program() + .graph_module ) pm1 = PassManager( passes=[introduce_call_method], run_checks_after_each_pass=True @@ -162,13 +164,12 @@ def f(x: torch.Tensor) -> torch.Tensor: pm1(traced_f1) def test_pass_metadata(self) -> None: - def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y sample_inputs = (torch.randn(1, 3), torch.randn(1, 3)) - gm = exir.capture( - f, sample_inputs, exir.CaptureConfig() - ).exported_program.graph_module + gm = export(AddModule(), sample_inputs, strict=True).module() pass_result = ScalarToTensorPass()(gm) self.assertIsNotNone(pass_result) diff --git a/pytest-windows.ini b/pytest-windows.ini index 3dce0647367..6d5b5c2881e 100644 --- a/pytest-windows.ini +++ b/pytest-windows.ini @@ -78,7 +78,6 @@ addopts = # T200992559: Add torchao to ET as core dependency --ignore=examples/models/llama/tests/test_pre_quantization_transforms.py --ignore=exir/backend/test/demos - --ignore=exir/backend/test/test_backends.py --ignore=exir/backend/test/test_backends_lifted.py --ignore=exir/backend/test/test_partitioner.py --ignore=exir/tests/test_common.py diff --git a/pytest.ini b/pytest.ini index f2a7abe06ed..05aea9d4da6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -41,7 +41,6 @@ addopts = --ignore=exir/verification/test/test_verifier.py # Ignore failing tests --ignore=exir/backend/test/demos/rpc/test_rpc.py - --ignore=exir/backend/test/test_backends.py --ignore=exir/backend/test/test_backends_lifted.py --ignore=exir/backend/test/test_partitioner.py --ignore=exir/operator/test/test_operator.py From 31935131c46bd20e6b6662c237d85d10928f4f0a Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 12 Mar 2026 10:38:55 -0700 Subject: [PATCH 2/2] Remove deprecated XNNPACK capture utilities and migrate tests Summary: Delete capture_graph_for_xnnpack() and get_xnnpack_capture_config() which were only used in test files and relied on the deprecated exir.capture API. Migrate test_xnnpack_utils.py to use inline to_edge(export(...)) calls. Remove the deprecated exports from xnnpack/__init__.py. Differential Revision: D95605468 --- backends/xnnpack/__init__.py | 6 -- backends/xnnpack/test/test_xnnpack_utils.py | 67 +++++++++---------- backends/xnnpack/utils/configs.py | 17 +---- backends/xnnpack/utils/utils.py | 24 ------- .../size_analysis_tool_test.py | 16 ++--- 5 files changed, 42 insertions(+), 88 deletions(-) diff --git a/backends/xnnpack/__init__.py b/backends/xnnpack/__init__.py index b87dfab4f02..59aad34b543 100644 --- a/backends/xnnpack/__init__.py +++ b/backends/xnnpack/__init__.py @@ -12,14 +12,10 @@ # Exposed Configs in XNNPACK Package from .utils.configs import ( - get_xnnpack_capture_config, get_xnnpack_edge_compile_config, get_xnnpack_executorch_backend_config, ) -# Easy util functions -from .utils.utils import capture_graph_for_xnnpack - # XNNPACK Backend from .xnnpack_preprocess import XnnpackBackend @@ -27,8 +23,6 @@ "XnnpackDynamicallyQuantizedPartitioner", "XnnpackPartitioner", "XnnpackBackend", - "capture_graph_for_xnnpack", - "get_xnnpack_capture_config", "get_xnnpack_edge_compile_config", "get_xnnpack_executorch_backend_config", ] diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 5a6c529b497..0736cb01b0a 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -25,7 +25,6 @@ get_xnnpack_edge_compile_config, get_xnnpack_executorch_backend_config, ) -from executorch.backends.xnnpack.utils.utils import capture_graph_for_xnnpack # import the xnnpack backend implementation from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend @@ -35,7 +34,7 @@ from executorch.devtools.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) -from executorch.exir import ExecutorchProgram, ExirExportedProgram +from executorch.exir import EdgeProgramManager, to_edge from executorch.exir.backend.backend_api import to_backend, validation_disabled from executorch.exir.passes.spec_prop_pass import SpecPropPass @@ -157,6 +156,14 @@ def assert_outputs_equal(self, model_output, ref_output): torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03) ) + def _capture_graph_for_xnnpack( + self, module: torch.nn.Module, sample_inputs: Tuple[torch.Tensor] + ) -> EdgeProgramManager: + return to_edge( + export(module, sample_inputs, strict=True), + compile_config=get_xnnpack_edge_compile_config(), + ).transform(*get_transform_passes()) + def lower_module_and_test_output( self, module: Any, @@ -167,7 +174,7 @@ def lower_module_and_test_output( # TODO: remove this after we migrate to use long term flow quantizer_api_test: bool = False, dump_bundled_program: bool = False, # for debugging, dump the generated bundled program file - ) -> ExirExportedProgram: + ) -> EdgeProgramManager: """ Helper testing function that takes a torch.nn.Module and lowers it to XNNPACK with the given sample inputs. It then runs the lowered module and compares its @@ -175,7 +182,7 @@ def lower_module_and_test_output( """ if quantizer_api_test: - assert isinstance(module, ExirExportedProgram) + assert isinstance(module, EdgeProgramManager) edge_program = module else: @@ -187,7 +194,7 @@ def __init__(self): def forward(self, *args): return self.one_module(*args) - edge_program = capture_graph_for_xnnpack(WrappedModule(), sample_inputs) + edge_program = self._capture_graph_for_xnnpack(WrappedModule(), sample_inputs) partitioner = None if quantized: @@ -201,35 +208,32 @@ def forward(self, *args): if use_partitioner: with validation_disabled(): delegated_program = edge_program - delegated_program.exported_program = to_backend( - edge_program.exported_program, partitioner + delegated_program._edge_programs["forward"] = to_backend( + edge_program.exported_program(), partitioner ) - executorch_program: ExecutorchProgram = delegated_program.to_executorch( + executorch_program = delegated_program.to_executorch( get_xnnpack_executorch_backend_config([SpecPropPass()]), ) else: - delegated_program = to_backend( - "XnnpackBackend", edge_program.exported_program, [] + delegated_module = to_backend( + "XnnpackBackend", edge_program.exported_program(), [] ) - exported_program: ExirExportedProgram = capture_graph_for_xnnpack( - delegated_program, sample_inputs + exported_program = self._capture_graph_for_xnnpack( + delegated_module, sample_inputs ) - executorch_program: ExecutorchProgram = exported_program.to_executorch( + executorch_program = exported_program.to_executorch( get_xnnpack_executorch_backend_config(), ) - # print("Graph Module with delegate:") - # delegated_module.print_readable() - # Assert the backend name is xnnpack self.assertEqual( - executorch_program.program.execution_plan[0].delegates[0].id, + executorch_program.executorch_program.execution_plan[0].delegates[0].id, XnnpackBackend.__name__, ) - ref_output = delegated_program(*sample_inputs) + ref_output = delegated_program.exported_program().module()(*sample_inputs) if dump_bundled_program: save_bundled_program( representative_inputs=sample_inputs, @@ -325,14 +329,9 @@ def quantize_and_test_model_with_quantizer( prepared = prepare_pt2e(m, quantizer) converted = convert_pt2e(prepared) - captured_program = exir.capture( - converted, - example_inputs, - config=exir.CaptureConfig(enable_aot=True, _unlift=True), - ) - - edge_program = captured_program.to_edge( - get_xnnpack_edge_compile_config() + edge_program = to_edge( + export(converted, example_inputs, strict=True), + compile_config=get_xnnpack_edge_compile_config(), ).transform(*get_transform_passes()) delegated_module = self.lower_module_and_test_output( module=edge_program, @@ -350,7 +349,7 @@ def quantize_and_test_model_with_quantizer( } for op in supported_ops: FileCheck().check_count(op, 0, exactly=True).run( - delegated_module.exported_program.graph_module.code + delegated_module.exported_program().graph_module.code ) def _test_xnnpack_dqlinear( @@ -398,12 +397,12 @@ def _test_xnnpack_dqlinear( prepared_linear, ) - captured_dqlinear = capture_graph_for_xnnpack(converted_linear, example_inputs) + captured_dqlinear = self._capture_graph_for_xnnpack(converted_linear, example_inputs) - captured_dqlinear.exported_program.graph_module.graph.print_tabular() + captured_dqlinear.exported_program().graph_module.graph.print_tabular() lowered_module = to_backend( - "XnnpackBackend", captured_dqlinear.exported_program, [] + "XnnpackBackend", captured_dqlinear.exported_program(), [] ) class CompositeModule(torch.nn.Module): @@ -417,19 +416,19 @@ def forward(self, x): composite_model = CompositeModule() composite_model(*example_inputs) - exported_program: ExirExportedProgram = capture_graph_for_xnnpack( + exported_program = self._capture_graph_for_xnnpack( composite_model, example_inputs ) - executorch_program: ExecutorchProgram = exported_program.to_executorch( + executorch_program = exported_program.to_executorch( get_xnnpack_executorch_backend_config(), ) self.assertEqual( - executorch_program.program.execution_plan[0].delegates[0].id, + executorch_program.executorch_program.execution_plan[0].delegates[0].id, XnnpackBackend.__name__, ) - ref_output = captured_dqlinear(*example_inputs) + ref_output = captured_dqlinear.exported_program().module()(*example_inputs) ref_output = composite_model(*example_inputs) print("ref_output:", ref_output) diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py index 39314eb16d4..d407ea5bd5f 100644 --- a/backends/xnnpack/utils/configs.py +++ b/backends/xnnpack/utils/configs.py @@ -4,10 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +from typing import List import executorch.exir as exir -from executorch.exir import CaptureConfig from executorch.exir.pass_manager import PassType @@ -33,17 +32,3 @@ def get_xnnpack_executorch_backend_config( passes=additional_passes, extract_delegate_segments=True, ) - - -def get_xnnpack_capture_config( - dynamic_shape=False, - enable_aot: Optional[bool] = None, - unlift: Optional[bool] = None, -): - if enable_aot is None: - return CaptureConfig(enable_dynamic_shape=dynamic_shape) - else: - unlift = unlift if unlift is not None else enable_aot - return CaptureConfig( - enable_dynamic_shape=dynamic_shape, enable_aot=enable_aot, _unlift=unlift - ) diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index a8f3178f98f..a41d5bc634a 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -6,14 +6,8 @@ from typing import Any, cast, Optional, Tuple -import executorch.exir as exir import torch -from executorch.backends.xnnpack.utils.configs import ( - get_transform_passes, - get_xnnpack_capture_config, - get_xnnpack_edge_compile_config, -) from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -28,24 +22,6 @@ from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node -### XNNPACK Capture ### -def capture_graph_for_xnnpack( - module: torch.nn.Module, - inputs: Tuple[torch.Tensor], - enable_aot: Optional[bool] = None, - unlift: Optional[bool] = None, -) -> exir.ExirExportedProgram: - return ( - exir.capture( - module, - inputs, - get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift), - ) - .to_edge(get_xnnpack_edge_compile_config()) - .transform(*get_transform_passes()) - ) - - ### XNNPACK Utils ### PERM_NCHW_TO_NHWC = [0, 2, 3, 1] PERM_NHWC_TO_NCHW = [0, 3, 1, 2] diff --git a/devtools/size_analysis_tool/size_analysis_tool_test.py b/devtools/size_analysis_tool/size_analysis_tool_test.py index 00e1c9567a4..016e7ceb718 100644 --- a/devtools/size_analysis_tool/size_analysis_tool_test.py +++ b/devtools/size_analysis_tool/size_analysis_tool_test.py @@ -11,15 +11,16 @@ XnnpackFloatingPointPartitioner, ) from executorch.backends.xnnpack.utils.configs import ( + get_xnnpack_edge_compile_config, get_xnnpack_executorch_backend_config, ) -from executorch.backends.xnnpack.utils.utils import capture_graph_for_xnnpack from executorch.devtools.size_analysis_tool.size_analysis_tool import ( generate_model_size_information, ) -from executorch.exir.backend.backend_api import to_backend, validation_disabled +from executorch.exir import to_edge from executorch.exir.passes.spec_prop_pass import SpecPropPass +from torch.export import export class SizeAnalysisToolTest(unittest.TestCase): @@ -52,14 +53,13 @@ def forward(self, x): test_input = torch.ones(size=(4, 7, 5, 6), dtype=torch.float) - edge_program = capture_graph_for_xnnpack(mm, (test_input,)) + edge_program = to_edge( + export(mm, (test_input,), strict=True), + compile_config=get_xnnpack_edge_compile_config(), + ) partitioner = XnnpackFloatingPointPartitioner() - with validation_disabled(): - delegated_program = edge_program - delegated_program.exported_program = to_backend( - edge_program.exported_program, partitioner - ) + delegated_program = edge_program.to_backend(partitioner) program = delegated_program.to_executorch( get_xnnpack_executorch_backend_config([SpecPropPass()]),