Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def _tosa_pipeline(
[
ReplaceScalarWithTensorByProfilePass(),
RewriteLeLtToGeGtPass(),
DecomposeLeakyReLUPass(), # Emits full_like so before ConvertFullLikeToFullPass
ConvertFullLikeToFullPass(),
MatchArgDtypePass(),
UnsqueezeScalarPlaceholdersPass(exported_program),
Expand All @@ -340,7 +341,6 @@ def _tosa_pipeline(
FuseBatchNorm2dPass(exported_program),
ConvertMmToBmmPass(),
DecomposeGluPass(),
DecomposeLeakyReLUPass(),
DecomposeDivPass(),
# _safe_softmax results in a ReduceMax
# which is not currently supported by TOSA in U55
Expand Down
21 changes: 11 additions & 10 deletions backends/arm/_passes/decompose_leaky_relu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand All @@ -20,14 +23,14 @@ def _get_leaky_relu_ops(op) -> tuple:
if op in edge_ops:
return (
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.add.Tensor,
)
elif op in torch_ops:
return (
torch.ops.aten.clamp.default,
torch.ops.aten.full.default,
torch.ops.aten.full_like.default,
torch.ops.aten.mul.Tensor,
torch.ops.aten.add.Tensor,
)
Expand All @@ -42,33 +45,31 @@ class DecomposeLeakyReLUPass(ArmPass):
Example:
%op1 = clamp(x,0,None) (equivalent to max(0,x))
%op2 = clamp(x,None,0) (equivalent to min(0,x))
%op3 = full(x.shape,slope)
%op3 = full_like(x,slope)
%op4 = mul(%op3,%op2)
%op5 = add(%op1,%op4)

"""

_passes_required_after: Set[Type[ExportPass]] = set()
_passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass}

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_ops + torch_ops) or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)

x = args[0]
slope = args[1] if len(args) > 1 else 0.01
dtype = x.node.meta["val"].dtype
device = x.node.meta["val"].device
clamp, full, mul, add = _get_leaky_relu_ops(op)
clamp, full_like, mul, add = _get_leaky_relu_ops(op)
op1 = super().call_operator(
op=clamp, args=(x, 0, None), kwargs=kwargs, meta=meta
)
op2 = super().call_operator(
op=clamp, args=(x, None, 0), kwargs=kwargs, meta=meta
)
op3 = super().call_operator(
op=full,
args=(x.node.meta["val"].shape, slope),
kwargs={"dtype": dtype, "device": device},
op=full_like,
args=(x, slope),
kwargs={},
meta=meta,
)
op4 = super().call_operator(op=mul, args=(op3, op2), kwargs=kwargs, meta=meta)
Expand Down
Loading