Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_index_copy_pass import DecomposeIndexCopyPass # noqa
from .decompose_index_select_to_gather_pass import ( # noqa
DecomposeIndexSelectToGatherPass,
)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
DecomposeGluPass,
DecomposeGroupedConvPass,
DecomposeGroupNormPass,
DecomposeIndexCopyPass,
DecomposeIndexSelectToGatherPass,
DecomposeIntPowPass,
DecomposeLayerNormPass,
Expand Down Expand Up @@ -416,6 +417,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
# Transformation passes (pre scalar -> tensor)
self.add_passes(
[
DecomposeIndexCopyPass(tfa_pass=True),
DecomposeSelectScatterPass(tfa_pass=True),
DecomposeSliceScatterPass(tfa_pass=True),
ConvertInt64ConstOpsToInt32Pass(tfa_pass=True),
Expand Down
33 changes: 33 additions & 0 deletions backends/arm/_passes/decompose_index_copy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
from executorch.backends.arm._passes.insert_int32_casts_after_int64_placeholders import (
InsertInt32CastsAfterInt64PlaceholdersPass,
)
from executorch.exir.pass_base import ExportPass


class DecomposeIndexCopyPass(GetDecompositionPass):
    """Decomposes aten.index_copy into aten.index_put, as well as its
    surrounding operators.

    This pass is intended to be called in transform_for_annotation to prepare
    the graph for quantization. After quantization, this operator will be
    prepared for lowering to TOSA using the RewriteIndexPut pass.
    """

    # NOTE(review): presumably the decomposed graph carries int64 index
    # placeholders that still need int32 casts inserted afterwards — confirm
    # against InsertInt32CastsAfterInt64PlaceholdersPass.
    _passes_required_after: Set[Type[ExportPass]] = {
        InsertInt32CastsAfterInt64PlaceholdersPass
    }

    # Decompose both the functional and the in-place variants.
    targeted_ops = [
        torch.ops.aten.index_copy.default,
        torch.ops.aten.index_copy_.default,
    ]
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
# Key: op overload; Value: zero-based indices of positional args that must be i64.
I64_INPUT_ARG_POSITIONS = {
torch.ops.aten.one_hot.default: (0,),
torch.ops.aten.index_copy_.default: (2,),
torch.ops.aten.index_copy.default: (2,),
}

def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
Expand Down
5 changes: 1 addition & 4 deletions backends/arm/_passes/rewrite_index_put_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def _expand_none_indices(
source_shape: Sequence[int],
indices: Iterable[Any],
meta: NodeMetadata,
full_op,
) -> List[ProxyValue]:
"""Replace None indices with explicit ranges."""
expanded: List[ProxyValue] = []
Expand Down Expand Up @@ -189,9 +188,7 @@ def call_operator(self, op, args, kwargs, meta):
"index_put with only None indices is not supported"
)

processed_indices = self._expand_none_indices(
source_shape, indices, plain_meta, full_op
)
processed_indices = self._expand_none_indices(source_shape, indices, plain_meta)
index_shapes = [tuple(idx.data.shape) for idx in processed_indices]
try:
broadcast_shape = torch.broadcast_shapes(*index_shapes)
Expand Down
29 changes: 12 additions & 17 deletions backends/arm/test/modules/test_static_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,17 @@ def test_static_cache_tosa_FP(test_data):
pipeline.run()


@pytest.mark.xfail(
reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported."
)
@common.parametrize("test_data", test_configs)
def test_static_cache_tosa_INT(test_data):
module = StaticCacheModule(test_data).eval()
pipeline = TosaPipelineINT[input_t](
module,
module.get_inputs(),
aten_op=[],
exir_op=[],
module, module.get_inputs(), aten_op=[], exir_op=[], fold_quantize=False
)
pipeline.run()


@common.XfailIfNoCorstone300
@pytest.mark.xfail(
reason="Quantization for StaticCache is not yet supported. Scatter operator is also not supported on U55."
)
@pytest.mark.xfail(reason="Scatter operator is not supported on U55.")
@common.parametrize("test_data", test_configs)
def test_static_cache_u55_INT(test_data):
module = StaticCacheModule(test_data).eval()
Expand All @@ -155,13 +147,17 @@ def test_static_cache_u55_INT(test_data):


@common.XfailIfNoCorstone320
@pytest.mark.xfail(
reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported."
)
@common.parametrize("test_data", test_configs)
def test_static_cache_u85_INT(test_data):
module = StaticCacheModule(test_data).eval()
pipeline = EthosU85PipelineINT[input_t](module, module.get_inputs(), aten_ops=[])
pipeline = EthosU85PipelineINT[input_t](
module,
module.get_inputs(),
aten_ops=[],
fold_quantize=False,
)
# U85: keep _to_dim_order_copy portable for int64->int32 cast of cache_position (not delegatable).
pipeline.tester.use_portable_ops = True
pipeline.run()


Expand All @@ -181,9 +177,6 @@ def test_static_cache_vgf_no_quant(test_data):


@common.SkipIfNoModelConverter
@pytest.mark.xfail(
reason="TODO(MLETORCH-1818): Quantization for StaticCache is not yet supported."
)
@common.parametrize("test_data", test_configs)
def test_static_cache_vgf_quant(test_data):
module = StaticCacheModule(test_data).eval()
Expand All @@ -193,5 +186,7 @@ def test_static_cache_vgf_quant(test_data):
aten_op=[],
exir_op=[],
quantize=True,
fold_quantize=False,
tosa_spec="TOSA-1.0+INT",
)
pipeline.run()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -18,7 +18,6 @@


class Int64InputModel(torch.nn.Module):

def forward(self, weights: torch.Tensor, indices: torch.Tensor):
return torch.embedding(weights, indices)

Expand Down Expand Up @@ -51,7 +50,10 @@ def test_insert_int32_casts_after_int64_placeholders_tosa_FP():


class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module):
aten_op = "torch.ops.aten.index_copy_.default"
aten_ops = [
"torch.ops.dim_order_ops._to_dim_order_copy.default",
"torch.ops.aten.index_put_.default",
]

def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
return x.index_copy_(0, index, y)
Expand All @@ -66,24 +68,22 @@ def get_inputs(self) -> input_t3:

def test_insert_int32_casts_after_int64_placeholders_tosa_INT_upcast_for_index_copy_inplace():
module = UpcastToInt64ForIndexCopyInplaceModel()

# In TOSA+INT index_copy_ decomposes to index_put_
# There should also be cast from int64 to int32
pipeline = TosaPipelineINT[input_t3](
module,
module.get_inputs(),
aten_op=module.aten_op,
)
pipeline.pop_stage("check.quant_nodes")
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
},
aten_op=UpcastToInt64ForIndexCopyInplaceModel.aten_ops,
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()


class UpcastToInt64ForIndexCopyModel(torch.nn.Module):
aten_op = "torch.ops.aten.index_copy.default"
aten_ops = [
"torch.ops.dim_order_ops._to_dim_order_copy.default",
"torch.ops.aten.index_put.default",
]

def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor):
return x.index_copy(0, index, y)
Expand All @@ -98,17 +98,12 @@ def get_inputs(self) -> input_t3:

def test_insert_int32_casts_after_int64_placeholders_tosa_INT_upcast_for_index_copy():
module = UpcastToInt64ForIndexCopyModel()

# In TOSA+INT index_copy decomposes to index_put
# There should also be cast from int64 to int32
pipeline = TosaPipelineINT[input_t3](
module,
module.get_inputs(),
aten_op=module.aten_op,
)
pipeline.pop_stage("check.quant_nodes")
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
},
aten_op=UpcastToInt64ForIndexCopyModel.aten_ops,
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()
Loading