diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 2cf8727ac8a..c234ce3db9b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -320,6 +320,7 @@ def _tosa_pipeline( [ ReplaceScalarWithTensorByProfilePass(), RewriteLeLtToGeGtPass(), + DecomposeLeakyReLUPass(), # Emits full_like so before ConvertFullLikeToFullPass ConvertFullLikeToFullPass(), MatchArgDtypePass(), UnsqueezeScalarPlaceholdersPass(exported_program), @@ -340,7 +341,6 @@ def _tosa_pipeline( FuseBatchNorm2dPass(exported_program), ConvertMmToBmmPass(), DecomposeGluPass(), - DecomposeLeakyReLUPass(), DecomposeDivPass(), # _safe_softmax results in a ReduceMax # which is not currently supported by TOSA in U55 diff --git a/backends/arm/_passes/decompose_leaky_relu_pass.py b/backends/arm/_passes/decompose_leaky_relu_pass.py index 5fbf8b9e035..d9b5bbe96df 100644 --- a/backends/arm/_passes/decompose_leaky_relu_pass.py +++ b/backends/arm/_passes/decompose_leaky_relu_pass.py @@ -9,6 +9,9 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.convert_full_like_to_full_pass import ( + ConvertFullLikeToFullPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -20,14 +23,14 @@ def _get_leaky_relu_ops(op) -> tuple: if op in edge_ops: return ( exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.add.Tensor, ) elif op in torch_ops: return ( torch.ops.aten.clamp.default, - torch.ops.aten.full.default, + torch.ops.aten.full_like.default, torch.ops.aten.mul.Tensor, torch.ops.aten.add.Tensor, ) @@ -42,13 +45,13 @@ class DecomposeLeakyReLUPass(ArmPass): Example: %op1 = clamp(x,0,None) (equivalent to max(0,x)) %op2 = clamp(x,None,0) (equivalent to min(0,x)) - %op3 = full(x.shape,slope) + %op3 = full_like(x,slope) %op4 = mul(%op3,%op2) %op5 = add(%op1,%op4) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} def call_operator(self, op, args, kwargs, meta): if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta): @@ -56,9 +59,7 @@ def call_operator(self, op, args, kwargs, meta): x = args[0] slope = args[1] if len(args) > 1 else 0.01 - dtype = x.node.meta["val"].dtype - device = x.node.meta["val"].device - clamp, full, mul, add = _get_leaky_relu_ops(op) + clamp, full_like, mul, add = _get_leaky_relu_ops(op) op1 = super().call_operator( op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta ) @@ -66,9 +67,9 @@ def call_operator(self, op, args, kwargs, meta): op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta ) op3 = super().call_operator( - op=full, - args=(x.node.meta["val"].shape, slope), - kwargs={"dtype": dtype, "device": device}, + op=full_like, + args=(x, slope), + kwargs={}, meta=meta, ) op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)