diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 4fc6bbf1cbc..8021c38b3a9 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -52,6 +52,7 @@ from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa +from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa from .decompose_index_select_to_gather_pass import ( # noqa DecomposeIndexSelectToGatherPass, ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index df2b85601bd..aea7239b5ff 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -59,6 +59,7 @@ DecomposeGluPass, DecomposeGroupedConvPass, DecomposeGroupNormPass, + DecomposeIndexCopyPass, DecomposeIndexSelectToGatherPass, DecomposeIntPowPass, DecomposeLayerNormPass, @@ -416,6 +417,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): # Transformation passes (pre scalar -> tensor) self.add_passes( [ + DecomposeIndexCopyPass(tfa_pass=True), DecomposeSelectScatterPass(tfa_pass=True), DecomposeSliceScatterPass(tfa_pass=True), ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), diff --git a/backends/arm/_passes/decompose_index_copy_pass.py b/backends/arm/_passes/decompose_index_copy_pass.py new file mode 100644 index 00000000000..3edee80bc54 --- /dev/null +++ b/backends/arm/_passes/decompose_index_copy_pass.py @@ -0,0 +1,33 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Set, Type
+
+import torch
+from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
+from executorch.backends.arm._passes.insert_int32_casts_after_int64_placeholders import (
+    InsertInt32CastsAfterInt64PlaceholdersPass,
+)
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeIndexCopyPass(GetDecompositionPass):
+    """Decomposes aten.index_copy into aten.index_put, as well as its
+    surrounding operators.
+
+    This pass is intended to be called in transform_for_annotation to prepare
+    the graph for quantization. After quantization, this operator will be
+    prepared for lowering to TOSA using the RewriteIndexPut pass.
+
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = {
+        InsertInt32CastsAfterInt64PlaceholdersPass
+    }
+
+    targeted_ops = [
+        torch.ops.aten.index_copy.default,
+        torch.ops.aten.index_copy_.default,
+    ]
diff --git a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
index 5d3322bf88a..7b1f9a99041 100644
--- a/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
+++ b/backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py
@@ -36,8 +36,6 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
     # Key: op overload; Value: zero-based indices of positional args that must be i64.
I64_INPUT_ARG_POSITIONS = { torch.ops.aten.one_hot.default: (0,), - torch.ops.aten.index_copy_.default: (2,), - torch.ops.aten.index_copy.default: (2,), } def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule): diff --git a/backends/arm/_passes/rewrite_index_put_pass.py b/backends/arm/_passes/rewrite_index_put_pass.py index 71998640b68..24752772129 100644 --- a/backends/arm/_passes/rewrite_index_put_pass.py +++ b/backends/arm/_passes/rewrite_index_put_pass.py @@ -85,7 +85,6 @@ def _expand_none_indices( source_shape: Sequence[int], indices: Iterable[Any], meta: NodeMetadata, - full_op, ) -> List[ProxyValue]: """Replace None indices with explicit ranges.""" expanded: List[ProxyValue] = [] @@ -189,9 +188,7 @@ def call_operator(self, op, args, kwargs, meta): "index_put with only None indices is not supported" ) - processed_indices = self._expand_none_indices( - source_shape, indices, plain_meta, full_op - ) + processed_indices = self._expand_none_indices(source_shape, indices, plain_meta) index_shapes = [tuple(idx.data.shape) for idx in processed_indices] try: broadcast_shape = torch.broadcast_shapes(*index_shapes) diff --git a/backends/arm/test/modules/test_static_cache.py b/backends/arm/test/modules/test_static_cache.py index e4ddea797df..a0e7d24cdac 100644 --- a/backends/arm/test/modules/test_static_cache.py +++ b/backends/arm/test/modules/test_static_cache.py @@ -124,25 +124,17 @@ def test_static_cache_tosa_FP(test_data): pipeline.run() -@pytest.mark.xfail( - reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported." 
-) @common.parametrize("test_data", test_configs) def test_static_cache_tosa_INT(test_data): module = StaticCacheModule(test_data).eval() pipeline = TosaPipelineINT[input_t]( - module, - module.get_inputs(), - aten_op=[], - exir_op=[], + module, module.get_inputs(), aten_op=[], exir_op=[], fold_quantize=False ) pipeline.run() @common.XfailIfNoCorstone300 -@pytest.mark.xfail( - reason="Quantization for StaticCache is not yet supported. Scatter operator is also not supported on U55." -) +@pytest.mark.xfail(reason="Scatter operator is not supported on U55.") @common.parametrize("test_data", test_configs) def test_static_cache_u55_INT(test_data): module = StaticCacheModule(test_data).eval() @@ -155,13 +147,17 @@ def test_static_cache_u55_INT(test_data): @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported." -) @common.parametrize("test_data", test_configs) def test_static_cache_u85_INT(test_data): module = StaticCacheModule(test_data).eval() - pipeline = EthosU85PipelineINT[input_t](module, module.get_inputs(), aten_ops=[]) + pipeline = EthosU85PipelineINT[input_t]( + module, + module.get_inputs(), + aten_ops=[], + fold_quantize=False, + ) + # U85: keep _to_dim_order_copy portable for int64->int32 cast of cache_position (not delegatable). + pipeline.tester.use_portable_ops = True pipeline.run() @@ -181,9 +177,6 @@ def test_static_cache_vgf_no_quant(test_data): @common.SkipIfNoModelConverter -@pytest.mark.xfail( - reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported." 
-) @common.parametrize("test_data", test_configs) def test_static_cache_vgf_quant(test_data): module = StaticCacheModule(test_data).eval() @@ -193,5 +186,7 @@ def test_static_cache_vgf_quant(test_data): aten_op=[], exir_op=[], quantize=True, + fold_quantize=False, + tosa_spec="TOSA-1.0+INT", ) pipeline.run() diff --git a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py index 8cae15927a0..1d6fb252804 100644 --- a/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py +++ b/backends/arm/test/passes/test_insert_int32_casts_after_int64_placeholders_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,7 +18,6 @@ class Int64InputModel(torch.nn.Module): - def forward(self, weights: torch.Tensor, indices: torch.Tensor): return torch.embedding(weights, indices) @@ -51,7 +50,10 @@ def test_insert_int32_casts_after_int64_placeholders_tosa_FP(): class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module): - aten_op = "torch.ops.aten.index_copy_.default" + aten_ops = [ + "torch.ops.dim_order_ops._to_dim_order_copy.default", + "torch.ops.aten.index_put_.default", + ] def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): return x.index_copy_(0, index, y) @@ -66,24 +68,22 @@ def get_inputs(self) -> input_t3: def test_insert_int32_casts_after_int64_placeholders_tosa_INT_upcast_for_index_copy_inplace(): module = UpcastToInt64ForIndexCopyInplaceModel() + + # In TOSA+INT index_copy_ decomposes to index_put_ + # There should also be cast from int64 to int32 pipeline = TosaPipelineINT[input_t3]( module, module.get_inputs(), - aten_op=module.aten_op, - ) - pipeline.pop_stage("check.quant_nodes") - 
pipeline.change_args( - "check_count.exir", - { - "torch.ops.higher_order.executorch_call_delegate": 0, - }, + aten_op=UpcastToInt64ForIndexCopyInplaceModel.aten_ops, ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() class UpcastToInt64ForIndexCopyModel(torch.nn.Module): - aten_op = "torch.ops.aten.index_copy.default" + aten_ops = [ + "torch.ops.dim_order_ops._to_dim_order_copy.default", + "torch.ops.aten.index_put.default", + ] def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor): return x.index_copy(0, index, y) @@ -98,17 +98,12 @@ def get_inputs(self) -> input_t3: def test_insert_int32_casts_after_int64_placeholders_tosa_INT_upcast_for_index_copy(): module = UpcastToInt64ForIndexCopyModel() + + # In TOSA+INT index_copy decomposes to index_put + # There should also be cast from int64 to int32 pipeline = TosaPipelineINT[input_t3]( module, module.get_inputs(), - aten_op=module.aten_op, - ) - pipeline.pop_stage("check.quant_nodes") - pipeline.change_args( - "check_count.exir", - { - "torch.ops.higher_order.executorch_call_delegate": 0, - }, + aten_op=UpcastToInt64ForIndexCopyModel.aten_ops, ) - pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run()