diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py
index a3933ffb993..878e0ddb7e0 100644
--- a/devtools/inspector/_inspector_utils.py
+++ b/devtools/inspector/_inspector_utils.py
@@ -751,6 +751,168 @@ def dfs(node_id, component):
     return connected_components
+
+def _map_sequence_aot_output(
+    aot_intermediate_output: Sequence,
+    runtime_intermediate_output: Any,
+    negative_index: int,
+) -> Tuple[Tuple, Tuple]:
+    """
+    Handle the case when aot_intermediate_output is a Sequence.
+
+    Returns:
+        Tuple of (aot_intermediate_output as tuple, mapped runtime output as tuple)
+    """
+    if not isinstance(runtime_intermediate_output, Sequence):
+        raise TypeError(
+            "runtime intermediate output should be a sequence when aot intermediate output is a sequence"
+        )
+    last_element = runtime_intermediate_output[negative_index]
+    # TODO: this (last_element = list) is never really the case because the runtime never returns
+    # output as a list for the delegate case.
+    if isinstance(last_element, list) and all(
+        isinstance(t, torch.Tensor) for t in last_element
+    ):
+        # If the last element is a list of tensors (delegate case)
+        aot_mapped_runtime_intermediate_output = last_element
+    elif isinstance(last_element, torch.Tensor):
+        # If the last element is a tensor, as is always the case for the runtime.
+        # However, this is a strange condition where aot_intermediate_output is a list of tensors
+        # while runtime_intermediate_output is a single tensor, so we should never actually reach
+        # this branch. TODO: fix this
+        aot_mapped_runtime_intermediate_output = runtime_intermediate_output
+    else:
+        raise ValueError(
+            "The last element of the runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
+        )
+    # A list can't be used as a dictionary key, so convert to a tuple
+    return tuple(aot_intermediate_output), tuple(aot_mapped_runtime_intermediate_output)
+
+
+def _find_matching_runtime_output_by_shape_and_dtype(
+    aot_intermediate_output: torch.Tensor,
+    runtime_intermediate_output: Sequence,
+) -> Any:
+    """
+    Find the runtime output that matches the AOT output shape and dtype.
+    Used for multi-output operations (like native_layer_norm.out, native_dropout.out).
+
+    Returns:
+        The matching runtime output, or runtime_intermediate_output[-1] as a fallback.
+ """ + # Find all runtime outputs that match the AOT shape + matching_indices = [] + for idx, runtime_out in enumerate(runtime_intermediate_output): + if isinstance(runtime_out, torch.Tensor): + if runtime_out.shape == aot_intermediate_output.shape: + matching_indices.append(idx) + + if len(matching_indices) == 1: + # Exactly one shape match - use it (native multi-output case like layer_norm) + return runtime_intermediate_output[matching_indices[0]] + + if len(matching_indices) > 1: + # Multiple shape matches - try to distinguish by dtype + # For native_dropout, output is float and mask is bool; prefer matching dtype + dtype_matching_indices = [] + for idx in matching_indices: + runtime_out = runtime_intermediate_output[idx] + if isinstance(runtime_out, torch.Tensor): + if runtime_out.dtype == aot_intermediate_output.dtype: + dtype_matching_indices.append(idx) + + if len(dtype_matching_indices) == 1: + # Exactly one dtype match - use it (e.g., dropout case where mask is bool) + return runtime_intermediate_output[dtype_matching_indices[0]] + + # No unique match found, return the last element as fallback + return runtime_intermediate_output[-1] + + +def _map_non_sequence_aot_output( + aot_intermediate_output: Any, + runtime_intermediate_output: Any, + num_outputs: int, + negative_index: int, +) -> Any: + """ + Handle the case when aot_intermediate_output is NOT a Sequence. + + Returns: + The mapped runtime intermediate output. + """ + if not isinstance(runtime_intermediate_output, Sequence): + return runtime_intermediate_output + + # Use the last element of the runtime output as fallback if no match is found + aot_mapped_runtime_intermediate_output = runtime_intermediate_output[negative_index] + + # delegate runtime call and AOT intermediate is not a sequence. + # For multi-output operations (like native_layer_norm.out, native_dropout.out), + # the runtime captures all outputs but AOT only captures the primary output. + # We need to find the runtime output that matches the AOT output shape and dtype. + if ( + num_outputs == 1 + and len(runtime_intermediate_output) > 1 + and isinstance(aot_intermediate_output, torch.Tensor) + ): + aot_mapped_runtime_intermediate_output = ( + _find_matching_runtime_output_by_shape_and_dtype( + aot_intermediate_output, runtime_intermediate_output + ) + ) + + return aot_mapped_runtime_intermediate_output + + +def _process_single_runtime_output( + aot_list: List[Tuple[DebugHandle, Any]], + runtime_debug_handle: DebugHandle, + runtime_intermediate_output: Any, + num_outputs: int, + output_index: int, +) -> Optional[Tuple[Tuple[DebugHandle, Any], Tuple[DebugHandle, Any]]]: + """ + Process a single runtime output and map it to the corresponding AOT output. + + Returns: + A tuple of ((aot_debug_handle, aot_output), (runtime_debug_handle, runtime_output)) + or None if the mapping should be skipped. + """ + negative_index = -1 * (output_index + 1) + + # Combine aot debug handles into a single key + aot_combined_debug_handle, aot_intermediate_output = ( + _combine_aot_overlapped_intermediate_outputs( + aot_list, + (runtime_debug_handle, runtime_intermediate_output, num_outputs), + negative_index, + ) + ) + + if aot_combined_debug_handle == (-1,): + # Skip this mapping if the aot combined debug handle and runtime debug handle do not exact match. 
+        return None
+
+    if isinstance(aot_intermediate_output, Sequence):
+        aot_intermediate_output, aot_mapped_runtime_intermediate_output = (
+            _map_sequence_aot_output(
+                aot_intermediate_output, runtime_intermediate_output, negative_index
+            )
+        )
+    else:
+        aot_mapped_runtime_intermediate_output = _map_non_sequence_aot_output(
+            aot_intermediate_output,
+            runtime_intermediate_output,
+            num_outputs,
+            negative_index,
+        )
+
+    return (
+        (aot_combined_debug_handle, aot_intermediate_output),
+        (runtime_debug_handle, aot_mapped_runtime_intermediate_output),
+    )
+
+
 def map_runtime_aot_intermediate_outputs(
     aot_intermediate_outputs: Dict[DebugHandle, Any],
     runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]],
@@ -789,84 +951,36 @@ def map_runtime_aot_intermediate_outputs(
             if nodes[node_id].source == NodeSource.RUNTIME
         ]
 
-        # Map only if both AOT and runtime data are present.
-        if len(aot_list) != 0 and len(runtime_list) != 0:
-            # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
-            # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
-            # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
-            if len(runtime_list) != 1:
-                raise ValueError(
-                    f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
-                )
+        if len(aot_list) == 0 or len(runtime_list) == 0:
+            # Skip this mapping if either AOT or runtime data is missing.
+            continue
-            runtime_debug_handle, runtime_intermediate_output, num_outputs = (
-                runtime_list[0]
+        # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
+        # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_handles.
+        # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
+        if len(runtime_list) != 1:
+            raise ValueError(
+                f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
             )
-            # iterate through each of the output from runtime,
-            # get the corresponding debug handle
-            # and map it to the aot debug handle
-            # and create a dictionary that maps aot debug handle + aot output to
-            # runtime debug handle + runtime output
-            # Note this works only for delegate case for now.
-            for i in range(num_outputs):
-
-                negative_index = -1 * (i + 1)
-                aot_mapped_runtime_intermediate_output = runtime_intermediate_output
-                # Combine aot debug handles into a single key
-                aot_combined_debug_handle, aot_intermediate_output = (
-                    _combine_aot_overlapped_intermediate_outputs(
-                        aot_list, runtime_list[0], negative_index
-                    )
-                )
-
-                if aot_combined_debug_handle == (-1,):
-                    # Skip this mapping if the aot combined debug handle and runtime debug handle do not exact match.
-                    continue
-
-                if isinstance(aot_intermediate_output, Sequence):
-                    if not isinstance(runtime_intermediate_output, Sequence):
-                        raise TypeError(
-                            "runtime intermediate output should be a sequence when aot intermediate output is a sequence"
-                        )
-                    last_element = runtime_intermediate_output[negative_index]
-                    # TODO: this (last_element = list) is never really the case because runtime never returns output as a list
-                    # for delegate case.
-                    if isinstance(last_element, list) and all(
-                        isinstance(t, torch.Tensor) for t in last_element
-                    ):
-                        # If the last element is a list of tensors (delegate case)
-                        aot_mapped_runtime_intermediate_output = last_element
-                    elif isinstance(last_element, torch.Tensor):
-                        # If the last element is a tensor, as is always the case for runtime.
-                        # However, now we have a strange condition where aot_intermediate_output is a list of tensors
-                        # while runtime_intermediate_output is a single tensor. So we should never really come here.
-                        # TODO: fix this
-                        aot_mapped_runtime_intermediate_output = (
-                            runtime_intermediate_output
-                        )
-                    else:
-                        raise ValueError(
-                            "The last element of runtime argument list must be a tensor or a list of tensors when aot intermediate output is a sequence"
-                        )
-                    # List can't be used as a key, so convert to tuple
-                    aot_intermediate_output = tuple(aot_intermediate_output)
-                    aot_mapped_runtime_intermediate_output = tuple(
-                        aot_mapped_runtime_intermediate_output
-                    )
-                elif isinstance(runtime_intermediate_output, Sequence):
-                    # delegate runtime call and AOT intermediate is not a sequence, just take the last element from runtime list
-                    aot_mapped_runtime_intermediate_output = (
-                        runtime_intermediate_output[negative_index]
-                    )
-
-                # Create a mapping between runtime and aot
-                aot_runtime_mapping[
-                    (aot_combined_debug_handle, aot_intermediate_output)
-                ] = (
-                    runtime_debug_handle,
-                    aot_mapped_runtime_intermediate_output,
-                )
+        runtime_debug_handle, runtime_intermediate_output, num_outputs = runtime_list[0]
+        # Iterate through each of the outputs from the runtime,
+        # get the corresponding debug handle
+        # and map it to the aot debug handle
+        # and create a dictionary that maps aot debug handle + aot output to
+        # runtime debug handle + runtime output.
+        # Note this works only for the delegate case for now.
+        for i in range(num_outputs):
+            result = _process_single_runtime_output(
+                aot_list,
+                runtime_debug_handle,
+                runtime_intermediate_output,
+                num_outputs,
+                i,
+            )
+            if result is not None:
+                aot_key, runtime_value = result
+                aot_runtime_mapping[aot_key] = runtime_value
 
     return aot_runtime_mapping
diff --git a/devtools/inspector/tests/TARGETS b/devtools/inspector/tests/TARGETS
index 048c8f6f791..b50bd79d4d0 100644
--- a/devtools/inspector/tests/TARGETS
+++ b/devtools/inspector/tests/TARGETS
@@ -11,14 +11,17 @@ python_unittest(
         ci.buckconfig("executorch.event_tracer_enabled", "true"),
     ),
     deps = [
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/backends/xnnpack/utils:xnnpack_utils",
         "//executorch/devtools:lib",
         "//executorch/devtools/debug_format:et_schema",
         "//executorch/devtools/etdump:schema_flatcc",
         "//executorch/devtools/etrecord/tests:etrecord_test_library",
         "//executorch/devtools/inspector:inspector",
        "//executorch/devtools/inspector:lib",
-        "//executorch/exir:lib",
         "//executorch/devtools/inspector/tests:inspector_test_utils",
+        "//executorch/exir:lib",
+        "//executorch/runtime:runtime",
     ],
 )
diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py
index 93a74915e84..a42077394a4 100644
--- a/devtools/inspector/tests/inspector_test.py
+++ b/devtools/inspector/tests/inspector_test.py
@@ -795,7 +795,167 @@ def compare(self, a, b):
         # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0
         self.assertEqual(df.iloc[1]["gap"][0], 1.0)
 
-    @unittest.skip("ci config values are not propagated")
+    def test_transformer_block_xnnpack_numeric_gap_within_tolerance(self):
+        """
+        Test that the numeric gap between AOT and runtime intermediate outputs
+        for a single transformer block lowered to the XNNPACK delegate is within
+        acceptable tolerance.
+
+        This test verifies that when a single-block transformer model is exported
+        and lowered to XNNPACK, the intermediate outputs during runtime closely
+        match the expected AOT outputs, with gaps remaining within a small range.
+ """ + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) + from executorch.backends.xnnpack.utils.configs import ( + get_xnnpack_edge_compile_config, + ) + from executorch.runtime import Method, Program, Runtime, Verification + from torch import nn as nn + + class SingleBlockTransformer(nn.Module): + def __init__( + self, + vocab_size: int, + d_model: int = 256, + nhead: int = 8, + dim_feedforward: int = 1024, + max_len: int = 512, + dropout: float = 0.1, + ): + super().__init__() + self.d_model = d_model + + self.tok_emb = nn.Embedding(vocab_size, d_model) + self.pos_emb = nn.Embedding(max_len, d_model) + + # Single transformer encoder block + self.block = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + batch_first=True, # input: (B, T, C) + activation="gelu", + norm_first=True, + ) + + self.ln = nn.LayerNorm(d_model) + self.head = nn.Linear(d_model, vocab_size, bias=False) + + def forward( + self, input_ids: torch.Tensor, attn_mask: torch.Tensor | None = None + ): + """ + input_ids: (B, T) LongTensor + attn_mask (optional): (B, T) where 1/True = keep, 0/False = pad + """ + B, T = input_ids.shape + pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T) + + x = self.tok_emb(input_ids) + self.pos_emb(pos) # (B, T, d_model) + + # Convert padding mask to TransformerEncoderLayer's expected format: + # src_key_padding_mask: (B, T) with True = PAD (masked out) + src_key_padding_mask = None + if attn_mask is not None: + src_key_padding_mask = ~attn_mask.to(torch.bool) + + x = self.block( + x, src_key_padding_mask=src_key_padding_mask + ) # (B, T, d_model) + x = self.ln(x) + logits = self.head(x) # (B, T, vocab_size) + return logits + + vocab_size = 5000 + model = SingleBlockTransformer( + vocab_size=vocab_size, d_model=256, nhead=8, max_len=128 + ) + model_inputs = ( + torch.randint(0, vocab_size, (1, 32)), + torch.ones(1, 32, dtype=torch.bool), + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Export and lower model to XNNPACK delegate + aten_model: ExportedProgram = export(model, model_inputs, strict=True) + edge_program_manager = to_edge_transform_and_lower( + aten_model, + partitioner=[XnnpackPartitioner()], + compile_config=get_xnnpack_edge_compile_config(), + generate_etrecord=True, + ) + + et_program_manager: ExecutorchProgramManager = ( + edge_program_manager.to_executorch() + ) + + pte_path = os.path.join(temp_dir, "model.pte") + et_program_manager.save(pte_path) + + # Dump ETRecord containing debug info for export progress + etrecord = et_program_manager.get_etrecord() + + # Set the input for numerical discrepancy detection + etrecord.update_representative_inputs(model_inputs) + etrecord_path = os.path.join(temp_dir, "etrecord.bin") + etrecord.save(etrecord_path) + + # Load and run PTE through Runtime API + et_runtime: Runtime = Runtime.get() + program: Program = et_runtime.load_program( + pte_path, + verification=Verification.Minimal, + enable_etdump=True, + debug_buffer_size=1024 * 1024 * 1024, # 1GB + ) + + forward: Method = program.load_method("forward") + forward.execute(model_inputs) + + # Dump ETDump recording execution data + etdump_path = os.path.join(temp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(temp_dir, "debug_buffer.bin") + program.write_etdump_result_to_file(etdump_path, debug_buffer_path) + + # Create Inspector and calculate numeric gap + try: + inspector = Inspector( + etdump_path=etdump_path, + 
etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + except FileNotFoundError as e: + new_message = f"{e} You likely need to run the test with --config executorch.event_tracer_enabled=true" + raise RuntimeError(new_message) from e + + df: pd.DataFrame = inspector.calculate_numeric_gap("MSE") + + # Verify we got results + self.assertIsNotNone(df) + self.assertGreater(len(df), 0) + + # Define tolerance threshold for numeric gap + TOLERANCE = 1e-1 + + # Check that each gap value is within acceptable tolerance + for idx, row in df.iterrows(): + gap_value = row["gap"] + # Handle case where gap might be a list + if isinstance(gap_value, list): + gap_value = gap_value[0] if gap_value else 0.0 + + runtime_ops = row["runtime_ops"] + aot_ops = row["aot_ops"] + + self.assertLessEqual( + gap_value, + TOLERANCE, + f"Gap at index {idx} ( aot_ops: {aot_ops}, runtime_ops: {runtime_ops}) is {gap_value}, " + f"which exceeds tolerance {TOLERANCE}", + ) + def test_intermediate_tensor_comparison_with_torch_export(self): """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower."""
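
Reviewer note (illustrative only, not part of the patch): the shape-then-dtype matching heuristic introduced by _find_matching_runtime_output_by_shape_and_dtype can be sanity-checked in isolation. In the minimal sketch below, pick_matching_output is a hypothetical stand-in for the new helper, exercised against a layer_norm-style runtime capture where only the first of three outputs shares the AOT tensor's shape; it only assumes PyTorch.

    import torch

    def pick_matching_output(aot_out, runtime_outs):
        # Prefer the unique runtime tensor whose shape matches the AOT output.
        shape_matches = [
            t for t in runtime_outs
            if isinstance(t, torch.Tensor) and t.shape == aot_out.shape
        ]
        if len(shape_matches) == 1:
            return shape_matches[0]
        # Several shape matches (e.g. a dropout output and its bool mask): fall back to dtype.
        dtype_matches = [t for t in shape_matches if t.dtype == aot_out.dtype]
        if len(dtype_matches) == 1:
            return dtype_matches[0]
        # Otherwise keep the pre-existing behavior of taking the last runtime output.
        return runtime_outs[-1]

    # layer_norm-style capture: primary output plus per-row mean and rstd.
    aot = torch.randn(2, 8)
    runtime = [torch.randn(2, 8), torch.randn(2, 1), torch.randn(2, 1)]
    assert pick_matching_output(aot, runtime) is runtime[0]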