10 changes: 7 additions & 3 deletions exir/emit/_emitter.py
@@ -83,6 +83,7 @@
 )
 from executorch.exir.tensor import (
     AddressSpaceOverflowException,
+    dim_order_from_stride,
     layout_enum,
     make_allocation_info,
     make_tensor_value,
@@ -1994,14 +1995,17 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
         # assign the storage of the placeholder spec to the storage of the real tensor if there is one
         if real_tensor is not None:
-            # for non-contigous tensors, convert to a contiguous one
-            real_tensor = real_tensor.contiguous()
+            # Preserve contiguous and channels-last layouts; convert anything else to contiguous.
+            if not (real_tensor.is_contiguous() or real_tensor.is_contiguous(memory_format=torch.channels_last)):
+                real_tensor = real_tensor.contiguous()
 
             # Weights cannot be views during emission or serialization
             if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
                 real_tensor = real_tensor.clone()
 
             spec.storage = real_tensor.untyped_storage()
 
+            spec.stride = real_tensor.stride()
+            spec.dim_order = dim_order_from_stride(spec.stride)
             # User inputs and mutable buffers are not constants, other buffers or parameters are.
             if initialize_buffer and is_mutable_buffer:
                 spec.const = True
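For intuition about the `dim_order_from_stride` call added above (the real helper lives in `executorch.exir.tensor`), here is a minimal standalone sketch of the idea, assuming a simple argsort-by-stride rule; the helper name and tie-breaking below are illustrative, not the library's exact logic:

```python
import torch

def dim_order_sketch(stride):
    # Illustrative only: order dimensions from largest stride (outermost)
    # to smallest (innermost), breaking ties by dimension index. A production
    # implementation also needs care with size-1 and zero strides.
    return tuple(sorted(range(len(stride)), key=lambda d: (-stride[d], d)))

t = torch.empty(1, 2, 3, 4, memory_format=torch.channels_last)
print(t.stride())                    # (24, 1, 8, 2)
print(dim_order_sketch(t.stride()))  # (0, 2, 3, 1), i.e. NHWC
```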
50 changes: 50 additions & 0 deletions exir/emit/test/test_emit.py
@@ -2403,3 +2403,53 @@ def forward(self, x):
         # Compare results
         self.assertTrue(expected.shape == et_result.shape)
         self.assertTrue(torch.allclose(expected, et_result))
+
+    def test_emit_channels_last_constant(self) -> None:
+        """Test that channels-last constant tensors are emitted correctly.
+
+        The dim_order and storage data must be consistent: if storage is in
+        channels-last physical layout, dim_order should reflect that, and vice versa.
+        """
+        import struct
+
+        class ChannelsLastConstant(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Create a constant tensor with channels-last memory format
+                self.constant = (
+                    torch.arange(2 * 3 * 4)
+                    .reshape(1, 2, 3, 4)
+                    .to(torch.float32)
+                    .to(memory_format=torch.channels_last)
+                )
+
+            def forward(self):
+                return self.constant
+
+        model = ChannelsLastConstant()
+        eager_out = model()
+
+        program = to_edge(
+            export(model, (), strict=True)
+        ).to_executorch()
+
+        # Run the model and verify output matches eager.
+        et_module = _load_for_executorch_from_buffer(program.buffer)
+        self.assertTrue(torch.allclose(eager_out, et_module()[0]))
+
+        # Verify the dim_order is channels-last.
+        exec_plan = program.executorch_program.execution_plan[0]
+        output_idx = exec_plan.outputs[0]
+        tensor_val = exec_plan.values[output_idx].val
+        self.assertIsInstance(tensor_val, Tensor)
+        self.assertEqual(list(tensor_val.dim_order), [0, 2, 3, 1])
+
+        # Verify storage is in channels-last (NHWC) physical layout.
+        storage_bytes = program.executorch_program.constant_buffer[tensor_val.data_buffer_idx].storage
+        num_floats = len(storage_bytes) // 4
+        storage_values = list(struct.unpack(f"{num_floats}f", storage_bytes))
+        expected_nhwc_storage = [
+            0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17,
+            6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23,
+        ]
+        self.assertEqual([int(v) for v in storage_values], expected_nhwc_storage)
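As a cross-check on the hand-written `expected_nhwc_storage` list, the same flat order can be derived mechanically; a small snippet (not part of the PR) using the test's 1x2x3x4 shape:

```python
import torch

nchw = torch.arange(2 * 3 * 4).reshape(1, 2, 3, 4)
# channels_last storage enumerates elements in (n, h, w, c) order, i.e. the
# same sequence as permuting NCHW -> NHWC and flattening.
print(nchw.permute(0, 2, 3, 1).flatten().tolist())
# [0, 12, 1, 13, 2, 14, 3, 15, ..., 10, 22, 11, 23]
```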