diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py index d54b549409f..ed45998dc56 100644 --- a/backends/transforms/replace_scalar_with_tensor.py +++ b/backends/transforms/replace_scalar_with_tensor.py @@ -11,6 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass +from torch._ops import OpOverload class ReplaceScalarWithTensorArgPass(ExportPass): @@ -46,6 +47,11 @@ def __init__( super().__init__() def get_replacement(self, op, args, kwargs, meta): + if isinstance(op, OpOverload): + full_op = torch.ops.aten.full.default + else: + full_op = exir_ops.edge.aten.full.default + return super().call_operator( # Replace with .Tensor variant. op=self.scalar_to_tensor_ops[op], @@ -54,10 +60,10 @@ def get_replacement(self, op, args, kwargs, meta): args[0], # Scalar arg - replace with aten.full tensor. super().call_operator( - exir_ops.edge.aten.full.default, + full_op, args=( (1,), - args[1], + float(args[1]), ), kwargs={ "dtype": args[0].to_tensor().dtype,