diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 9e35c3291fd..a48d88fa224 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -83,6 +83,7 @@
 )
 from executorch.exir.tensor import (
     AddressSpaceOverflowException,
+    dim_order_from_stride,
     layout_enum,
     make_allocation_info,
     make_tensor_value,
@@ -92,11 +93,9 @@
 )
 from executorch.exir.types import LeafValueSpec, ValueSpec
 from torch._subclasses.fake_tensor import FakeTensor
-
 from torch.export.exported_program import ExportedProgram, ExportGraphSignature
 from torch.fx.node import Node
 from torch.utils import _pytree as pytree
-
 from typing_extensions import TypeAlias
@@ -1994,14 +1993,20 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
 
         # assign the storage of the placeholder spec to the storage of the real tensor if there is one
         if real_tensor is not None:
-            # for non-contigous tensors, convert to a contiguous one
-            real_tensor = real_tensor.contiguous()
+            # For tensors that are neither contiguous nor channels-last, convert to contiguous format.
+            if not (
+                real_tensor.is_contiguous()
+                or real_tensor.is_contiguous(memory_format=torch.channels_last)
+            ):
+                real_tensor = real_tensor.contiguous()
+
             # Weights cannot be views during emission or serialization
             if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
                 real_tensor = real_tensor.clone()
             spec.storage = real_tensor.untyped_storage()
-
+            spec.stride = real_tensor.stride()
+            spec.dim_order = dim_order_from_stride(spec.stride)
         # User inputs and mutable buffers are not constants, other buffers or parameters are.
         if initialize_buffer and is_mutable_buffer:
             spec.const = True
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 3ed432c1872..e0b265f7bf3 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -13,7 +13,6 @@
 from typing import List, Optional, Tuple
 
 import executorch.exir as exir
-
 import executorch.exir.schema as schema
 import executorch.exir.tests.models as models
 import pytest
@@ -63,9 +62,7 @@
 )
 from executorch.runtime import Runtime
 from torch import nn
-
 from torch._higher_order_ops import cond as torch_cond, map as torch_map
-
 from torch.export import Dim, export
 from torch.export.experimental import _export_forward_backward
@@ -2403,3 +2400,121 @@ def forward(self, x):
         # Compare results
         self.assertTrue(expected.shape == et_result.shape)
         self.assertTrue(torch.allclose(expected, et_result))
+
+    def test_emit_channels_last_constant(self) -> None:
+        """Test that channels-last constant tensors are emitted correctly.
+
+        The dim_order and storage data must be consistent - if storage is in
+        channels-last physical layout, dim_order should reflect that, and vice versa.
+        """
+        import struct
+
+        class ChannelsLastConstant(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Create a constant tensor with channels-last memory format
+                self.constant = (
+                    torch.arange(2 * 3 * 4)
+                    .reshape(1, 2, 3, 4)
+                    .to(torch.float32)
+                    .to(memory_format=torch.channels_last)
+                )
+
+            def forward(self):
+                return self.constant
+
+        model = ChannelsLastConstant()
+        eager_out = model()
+
+        program = to_edge(export(model, (), strict=True)).to_executorch()
+        # Run the model and verify output matches eager.
+        et_module = _load_for_executorch_from_buffer(program.buffer)
+        self.assertTrue(torch.allclose(eager_out, et_module()[0]))
+
+        # Verify the dim_order is channels-last.
+        exec_plan = program.executorch_program.execution_plan[0]
+        output_idx = exec_plan.outputs[0]
+        tensor_val = exec_plan.values[output_idx].val
+        self.assertIsInstance(tensor_val, Tensor)
+        self.assertEqual(list(tensor_val.dim_order), [0, 2, 3, 1])
+
+        # Verify storage is in channels-last (NHWC) physical layout.
+        storage_bytes = program.executorch_program.constant_buffer[
+            tensor_val.data_buffer_idx
+        ].storage
+        num_floats = len(storage_bytes) // 4
+        storage_values = list(struct.unpack(f"{num_floats}f", storage_bytes))
+        expected_nhwc_storage = [
+            0,
+            12,
+            1,
+            13,
+            2,
+            14,
+            3,
+            15,
+            4,
+            16,
+            5,
+            17,
+            6,
+            18,
+            7,
+            19,
+            8,
+            20,
+            9,
+            21,
+            10,
+            22,
+            11,
+            23,
+        ]
+        self.assertEqual([int(v) for v in storage_values], expected_nhwc_storage)
+
+    def test_emit_custom_dimorder(self) -> None:
+        """Test that non-contiguous constant tensors are made contiguous during emit."""
+        import struct
+
+        class TransposedConstant(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Original: shape (2, 16), strides (16, 1), contiguous
+                # After transpose: shape (16, 2), strides (1, 16), non-contiguous
+                self.constant = torch.arange(32).reshape(2, 16).float().transpose(1, 0)
+
+            def forward(self):
+                return self.constant
+
+        model = TransposedConstant()
+        eager_out = model()
+
+        # Verify the tensor is not contiguous.
+        self.assertFalse(model.constant.is_contiguous())
+
+        program = to_edge(export(model, (), strict=True)).to_executorch()
+        # Run the model and verify output matches eager.
+        et_module = _load_for_executorch_from_buffer(program.buffer)
+        self.assertTrue(torch.allclose(eager_out, et_module()[0]))
+
+        # Check that tensor is now contiguous.
+        exec_plan = program.executorch_program.execution_plan[0]
+        output_idx = exec_plan.outputs[0]
+        tensor_val = exec_plan.values[output_idx].val
+        self.assertIsInstance(tensor_val, Tensor)
+        self.assertEqual(list(tensor_val.dim_order), [0, 1])
+
+        # Verify storage is contiguous in physical memory.
+        storage_bytes = program.executorch_program.constant_buffer[
+            tensor_val.data_buffer_idx
+        ].storage
+        num_floats = len(storage_bytes) // 4
+        storage_values = list(struct.unpack(f"{num_floats}f", storage_bytes))
+
+        # The transposed tensor has shape (16, 2), so contiguous storage
+        # iterates row by row: [0, 16, 1, 17, 2, 18, ...].
+        expected_storage = []
+        for i in range(16):
+            for j in range(2):
+                expected_storage.append(j * 16 + i)
+        self.assertEqual([int(v) for v in storage_values], expected_storage)
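Note on the emitter change: the new `spec.dim_order = dim_order_from_stride(spec.stride)` line derives the serialized dim order from the (possibly channels-last) strides instead of always assuming contiguous layout. The sketch below illustrates the idea, assuming a dim order is simply the dimensions sorted by decreasing stride; it is an illustration only, not the actual `executorch.exir.tensor.dim_order_from_stride` implementation (which also has to deal with ties and size-1 dimensions), and `dim_order_from_stride_sketch` is a hypothetical helper name.

    import torch

    def dim_order_from_stride_sketch(stride):
        # Order dimensions from largest to smallest stride, i.e. from the
        # outermost to the innermost dimension in physical memory.
        return tuple(sorted(range(len(stride)), key=lambda d: -stride[d]))

    # A channels-last NCHW tensor stores its data as NHWC, so the derived
    # dim order is (0, 2, 3, 1), matching test_emit_channels_last_constant.
    t = torch.empty(1, 2, 3, 4).to(memory_format=torch.channels_last)
    print(t.stride())                                # (24, 1, 8, 2)
    print(dim_order_from_stride_sketch(t.stride()))  # (0, 2, 3, 1)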