From eb732c06ee300d9018abf728bc20e87cadbaf8a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 12 Mar 2026 09:09:05 +0100 Subject: [PATCH 1/2] Arm backend: use full_like in LeakyReLU decomposition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace `full(x.shape, slope)` with `full_like(x, slope)`. This avoids capturing export-time shape metadata and lets the created tensor inherit dtype/device from the runtime input tensor, preventing shape and device mismatches across export flows (e.g. QAT/PT2E vs already-quantized export). Because of this, the LeakyReLU pass needs to go before the full_like_to_full conversion pass. Signed-off-by: Måns Nilsson Change-Id: I2cbae091c34eb3b905e402f2f122b10742b0ae1d --- backends/arm/_passes/arm_pass_manager.py | 2 +- .../arm/_passes/decompose_leaky_relu_pass.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 2cf8727ac8a..c234ce3db9b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -320,6 +320,7 @@ def _tosa_pipeline( [ ReplaceScalarWithTensorByProfilePass(), RewriteLeLtToGeGtPass(), + DecomposeLeakyReLUPass(), # Emits full_like so before ConvertFullLikeToFullPass ConvertFullLikeToFullPass(), MatchArgDtypePass(), UnsqueezeScalarPlaceholdersPass(exported_program), @@ -340,7 +341,6 @@ def _tosa_pipeline( FuseBatchNorm2dPass(exported_program), ConvertMmToBmmPass(), DecomposeGluPass(), - DecomposeLeakyReLUPass(), DecomposeDivPass(), # _safe_softmax results in a ReduceMax # which is not currently supported by TOSA in U55 diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index 5fbf8b9e035..d09efe92030 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -20,14 +20,14 @@ 
def _get_leaky_relu_ops(op) -> tuple: if op in edge_ops: return ( exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.add.Tensor, ) elif op in torch_ops: return ( torch.ops.aten.clamp.default, - torch.ops.aten.full.default, + torch.ops.aten.full_like.default, torch.ops.aten.mul.Tensor, torch.ops.aten.add.Tensor, ) @@ -42,7 +42,7 @@ class DecomposeLeakyReLUPass(ArmPass): Example: %op1 = clamp(x,0,None) (equivalent to max(0,x)) %op2 = clamp(x,None,0) (equivalent to min(0,x)) - %op3 = full(x.shape,slope) + %op3 = full_like(x,slope) %op4 = mul(%op3,%op2) %op5 = add(%op1,%op4) @@ -56,9 +56,7 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] slope = args[1] if len(args) > 1 else 0.01 - dtype = x.node.meta["val"].dtype - device = x.node.meta["val"].device - clamp, full, mul, add = _get_leaky_relu_ops(op) + clamp, full_like, mul, add = _get_leaky_relu_ops(op) op1 = super().call_operator( op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta ) @@ -66,9 +64,9 @@ def call_operator(self, op, args, kwargs, meta): op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta ) op3 = super().call_operator( - op=full, - args=(x.node.meta["val"].shape, slope), - kwargs={"dtype": dtype, "device": device}, + op=full_like, + args=(x, slope), + kwargs={}, meta=meta, ) op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta) From 3ec5eb1075cfe6759c079b399ca38c814430069f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 13 Mar 2026 08:06:55 +0100 Subject: [PATCH 2/2] Arm backend: Fix review comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Måns Nilsson Change-Id: If3947587dd806e349f4040fd0cb8b94214037145 --- backends/arm/_passes/decompose_leaky_relu_pass.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index d09efe92030..d9b5bbe96df 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -9,6 +9,9 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -48,7 +51,7 @@ class DecomposeLeakyReLUPass(ArmPass): """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta):