From 4283c98c864c51f735a399a6fcd3d60c4d5bc10e Mon Sep 17 00:00:00 2001 From: Will Guo Date: Sun, 1 Feb 2026 22:43:57 +0000 Subject: [PATCH 01/14] Integrate Automated QDQ placement tool - part 3.2 Signed-off-by: Will Guo --- .../onnx/quantization/autotune/__init__.py | 8 + .../onnx/quantization/autotune/autotuner.py | 1092 +++++++++++++++++ .../autotune/autotune/test_autotuner.py | 345 ++++++ 3 files changed, 1445 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/autotuner.py create mode 100644 tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py index 91e86889f..a722cabbb 100644 --- a/modelopt/onnx/quantization/autotune/__init__.py +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -22,11 +22,15 @@ # Core data structures from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark +from .autotuner import QDQAutotuner from .common import ( AutotunerError, AutotunerNotInitializedError, + Config, InsertionScheme, InvalidSchemeError, + PatternCache, + PatternSchemes, Region, RegionType, ) @@ -45,9 +49,13 @@ "ChildRegionInputInsertionPoint", "ChildRegionOutputInsertionPoint", "CombinedRegionSearch", + "Config", "InsertionScheme", "InvalidSchemeError", "NodeInputInsertionPoint", + "PatternCache", + "PatternSchemes", + "QDQAutotuner", "Region", "RegionPattern", "RegionType", diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py new file mode 100644 index 000000000..9eb8724dc --- /dev/null +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -0,0 +1,1092 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" + +import copy +import os +import random +from collections import deque +from datetime import datetime, timezone + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import ( + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + merge_resolved_insertion_points, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch +from modelopt.onnx.quantization.fp8 import int8_to_fp8 +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class QDQAutotunerBase: + """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" + + def __init__(self, model: onnx.ModelProto | gs.Graph): + """Initialize the autotuner with an ONNX model.""" + if isinstance(model, onnx.ModelProto): + self.onnx_model = model + elif isinstance(model, gs.Graph): + self.onnx_model = gs.export_onnx(model) + else: + raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}") + + self.graph = self._copy_graph() + self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.regions: list[Region] = [] + 
self.current_profile_region: Region | None = None + self.profiled_patterns: list[PatternSchemes] = [] + self.current_profile_pattern_schemes: PatternSchemes | None = None + self.current_insertion_scheme_index: int | None = None + self.config = Config() + self.initialized = False + self.baseline_latency_ms: float | None = None + self.pattern_cache: PatternCache | None = None + + logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuning session with configuration and pattern cache.""" + if config is not None: + self.config = config + + if pattern_cache is None: + pattern_cache = PatternCache( + minimum_distance=self.config.pattern_cache_minimum_distance, + max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern, + ) + self.pattern_cache = pattern_cache + + logger.debug( + f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and " + f"{pattern_cache.total_schemes} schemes" + ) + + self.initialized = False + self.baseline_latency_ms = None + self.profiled_patterns.clear() + self.regions.clear() + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + logger.info("Initializing autotuner") + logger.debug( + f"Configuration: q_scale={self.config.default_q_scale}, " + f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}" + ) + + self.initialized = True + + def set_profile_region(self, region: Region | None, commit: bool = True) -> None: + """Set the target region for profiling and scheme generation.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + if commit: + if self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + if commit or region is None: + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + if region is None: + return + + if region not in self.regions: + raise ValueError(f"Region {region.id} not found in regions") + + region_pattern = RegionPattern.from_region(region, self.graph) + + if self._is_region_profiled(region): + logger.info(f"Skipping region {region.id} (pattern already profiled)") + logger.debug(f"Pattern signature: {region_pattern.signature}") + return + + pattern_schemes = None + num_seeded = 0 + + if self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) + + if cache_schemes is not None and len(cache_schemes.schemes) > 0: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + + 
logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + else: + logger.debug("No pattern cache entries for this region") + + if pattern_schemes is None: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + logger.debug("Initialized with empty scheme collection") + + self.current_profile_region = region + self.current_profile_pattern_schemes = pattern_schemes + + mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" + logger.info( + f"Profiling region {region.id} [pattern mode, level {region.level}, " + f"size {region.get_size_of_region_and_descendants()}, {mode_info}]" + ) + logger.debug(f"Pattern signature: {region_pattern.signature}") + + def generate(self) -> int: + """Generate a new Q/DQ insertion scheme for the current pattern or region.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + elif self.current_profile_pattern_schemes is None: + raise InvalidSchemeError("No region selected. 
Call set_profile_region() first.") + + pattern_schemes = self.current_profile_pattern_schemes + cached_schemes = [ + (idx, scheme) + for idx, scheme in enumerate(pattern_schemes.schemes) + if not scheme.is_profiled + ] + + if cached_schemes: + scheme_index, cached_scheme_data = cached_schemes[0] + num_node_points = len(cached_scheme_data.node_inputs) + num_region_composite_points = len(cached_scheme_data.child_region_inputs) + num_region_output_points = len(cached_scheme_data.region_outputs) + total_points = num_node_points + num_region_composite_points + num_region_output_points + + logger.info( + f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Cached scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + known_schemes = {scheme.hash for scheme in pattern_schemes.schemes} + max_attempts = getattr(self.config, "maximum_generation_attempts", 100) + + logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)") + + for attempts in range(max_attempts): + new_scheme = self._generate_next_insertion_sample() + if new_scheme.hash not in known_schemes and not new_scheme.error: + pattern_schemes.schemes.append(new_scheme) + scheme_index = len(pattern_schemes.schemes) - 1 + num_node_points = len(new_scheme.node_inputs) + num_region_composite_points = len(new_scheme.child_region_inputs) + num_region_output_points = len(new_scheme.region_outputs) + total_points = ( + num_node_points + num_region_composite_points + num_region_output_points + ) + + logger.info( + f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Scheme breakdown: {num_node_points} node input, " + 
f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points " + f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") + return -1 + + def export_onnx( + self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False + ) -> bytes: + """Export ONNX model with Q/DQ nodes inserted according to tested schemes.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + output_desc = output_path if output_path is not None else "" + original_quant_type = self.config.default_quant_type + needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" + resolved_insertion_points = set() + + logger.debug( + f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, " + f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})" + ) + + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + self.config.default_quant_type = "int8" + + if insert_qdq: + matched_regions = 0 + + logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") + + for region in self.regions: + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) + current_scheme = matched.best_scheme if matched else None + + if matched: + if current_scheme: + logger.debug( + f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug(" → Matched profiled pattern but no valid schemes") + + if current_scheme is None: + current_scheme = 
self.current_profile_pattern_schemes + if current_scheme is None or pattern != current_scheme.pattern: + pass + elif best: + current_scheme = current_scheme.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is not None: + assert scheme_index < len(current_scheme.schemes), ( + f"Invalid scheme index: {scheme_index}" + ) + current_scheme = current_scheme.schemes[scheme_index] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if pattern_schemes is not None: + schemes = pattern_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + continue + + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) + resolved_insertion_points.difference_update(all_region_ips) + excluded_tensors = all_region_ips - resolved_insertion_points + if excluded_tensors: + logger.debug( + f" → Excluded {len(excluded_tensors)} overlapping insertion points" + ) + + new_ips = pattern.matches(region, self.graph, current_scheme) + if new_ips: + resolved_insertion_points.update(new_ips) + matched_regions += 1 + logger.debug(f" → Added {len(new_ips)} insertion points") + + logger.debug( + f"Matched {matched_regions}/{len(self.regions)} regions, " + f"total {len(resolved_insertion_points)} unique insertion points" + ) + + graph_copy = self._copy_graph() + unique_tensors = len(resolved_insertion_points) + + logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph") + + if insert_qdq and 
resolved_insertion_points: + self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) + + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) + + if insert_qdq and resolved_insertion_points: + self._fix_zero_point_initializers(model) + + if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + + self.config.default_quant_type = original_quant_type + model_bytes = model.SerializeToString() + quant_type_str = "baseline" + output_dest = "" + + if insert_qdq: + quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8" + + if output_path is not None: + onnx.save(model, output_path) + output_dest = f" → {output_path}" + + logger.info( + f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs {output_dest}" + ) + return model_bytes + + def submit(self, latency_ms: float, success: bool = True) -> None: + """Submit performance measurement for the most recently generated scheme.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + if self.baseline_latency_ms is None: + self.baseline_latency_ms = latency_ms + logger.info(f"Baseline latency: {latency_ms:.3f} ms") + return + + if self.current_profile_pattern_schemes is None: + raise InvalidSchemeError( + "No pattern or region selected. Call set_profile_region() first." + ) + + schemes_collection = self.current_profile_pattern_schemes + if not schemes_collection.schemes: + raise InvalidSchemeError("No schemes available. 
Call generate() first.") + + pattern_schemes = schemes_collection + + if self.current_insertion_scheme_index is not None: + scheme_index = self.current_insertion_scheme_index + if scheme_index >= len(pattern_schemes.schemes): + raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}") + scheme = pattern_schemes.schemes[scheme_index] + else: + scheme = pattern_schemes.schemes[-1] + scheme_index = len(pattern_schemes.schemes) - 1 + + scheme.latency_ms = latency_ms + scheme.error = not success + scheme.profile_timestamp = datetime.now(timezone.utc).isoformat() + display_index = scheme_index + 1 + + if not success: + logger.warning( + f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)" + ) + logger.debug("Marking scheme with error flag") + return + + speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0 + + logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)") + logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms") + + old_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + pattern_schemes.schemes.sort( + key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf") + ) + new_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + + if new_best < old_best: + new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0 + logger.info(f" ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)") + logger.debug(f"Previous best: {old_best:.3f} ms") + + if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + logger.debug( + f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def save_state(self, output_path: str) -> None: + """Save complete autotuner state to a YAML file for later reuse.""" + 
current_pattern_sig = None + if self.current_profile_pattern_schemes is not None: + current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature + + state = { + "baseline_latency_ms": self.baseline_latency_ms, + "current_profile_pattern_schemes_signature": current_pattern_sig, + "config": { + "default_q_scale": self.config.default_q_scale, + "default_q_zero_point": self.config.default_q_zero_point, + "default_quant_type": self.config.default_quant_type, + "verbose": self.config.verbose, + }, + "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], + } + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + num_patterns = len(self.profiled_patterns) + total_schemes = sum(len(p.schemes) for p in self.profiled_patterns) + + logger.info( + f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)" + ) + logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms") + + if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0: + base_path, ext = os.path.splitext(output_path) + cache_path = f"{base_path}_pattern_cache{ext}" + self.pattern_cache.save(cache_path) + + logger.info(f"Saved pattern cache → {cache_path}") + logger.debug( + f"Cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def load_state(self, input_path: str) -> None: + """Load autotuner state from a previously saved YAML file.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + with open(input_path) as f: + state = yaml.safe_load(f) + + if state.get("baseline_latency_ms") is not None: + self.baseline_latency_ms = state["baseline_latency_ms"] + logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms") + + if "config" in state: + config_data = state["config"] + if "default_q_scale" in config_data: + self.config.default_q_scale = config_data["default_q_scale"] + if "default_q_zero_point" in config_data: + self.config.default_q_zero_point = config_data["default_q_zero_point"] + if "default_quant_type" in config_data: + self.config.default_quant_type = config_data["default_quant_type"] + if "verbose" in config_data: + self.config.verbose = config_data["verbose"] + logger.debug(f"Config merged: quant_type={self.config.default_quant_type}") + + if "patterns" in state: + num_loaded_patterns = 0 + num_loaded_schemes = 0 + + for pattern_data in state["patterns"]: + try: + pattern_schemes = PatternSchemes.from_dict(pattern_data) + + if pattern_schemes.schemes: + self.profiled_patterns.append(pattern_schemes) + num_loaded_patterns += 1 + num_loaded_schemes += len(pattern_schemes.schemes) + else: + logger.debug( + f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..." 
+ ) + + except Exception as e: # noqa: PERF203 + logger.warning(f"Failed to load pattern: {e}") + continue + + logger.info( + f"Loaded state from {input_path} ({num_loaded_patterns} patterns, " + f"{num_loaded_schemes} schemes)" + ) + + base_path, ext = os.path.splitext(input_path) + cache_path = f"{base_path}_pattern_cache{ext}" + + if os.path.exists(cache_path): + try: + loaded_cache = PatternCache.load(cache_path) + + if self.pattern_cache is not None: + for pattern_schemes in loaded_cache.pattern_schemes: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + else: + self.pattern_cache = loaded_cache + logger.info( + f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, " + f"{loaded_cache.total_schemes} schemes)" + ) + except Exception as e: + logger.warning(f"Failed to load pattern cache: {e}") + else: + logger.debug(f"No pattern cache file at {cache_path}") + + def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + if isinstance(quantized_tensors, list): + quantized_tensors = set(quantized_tensors) + + logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors") + logger.debug(f"Processing {len(self.regions)} regions") + + if self.pattern_cache is None: + logger.warning("Pattern cache not initialized, skipping import") + return + + patterns_before = self.pattern_cache.num_patterns + schemes_before = self.pattern_cache.total_schemes + + for region in self.regions: + self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors) + + patterns_added = self.pattern_cache.num_patterns - patterns_before + schemes_added = self.pattern_cache.total_schemes - schemes_before + + logger.info( + f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache" + ) + logger.debug( + f"Total cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def _compute_convergence_metrics( + self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None + ) -> tuple[int | None, float | None]: + """Compute convergence metrics for a collection of schemes.""" + samples_before_best = None + time_to_best = None + + if not best_scheme or not best_scheme.profile_timestamp: + return samples_before_best, time_to_best + + schemes_with_time = [s for s in schemes if s.profile_timestamp is not None] + + if not schemes_with_time: + return samples_before_best, time_to_best + + schemes_with_time.sort(key=lambda s: s.profile_timestamp or "") + + try: + best_position = next( + i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash + ) + samples_before_best = best_position + + first_ts = schemes_with_time[0].profile_timestamp + best_ts = best_scheme.profile_timestamp + assert first_ts is not None and best_ts is not None + first_timestamp = datetime.fromisoformat(first_ts) + best_timestamp = datetime.fromisoformat(best_ts) + time_to_best = (best_timestamp - 
first_timestamp).total_seconds() + except (StopIteration, ValueError): + pass + + return samples_before_best, time_to_best + + def _is_region_profiled(self, region: Region) -> bool: + """Check if a region's pattern has already been fully profiled.""" + + def match_pattern(pattern: PatternSchemes, region: Region) -> bool: + """Check if a pattern matches a region.""" + if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): + return False + return not any(not scheme.is_profiled for scheme in pattern.schemes) + + return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + + def _mutate_insertion_points( + self, base_points, all_points, point_type: str, max_mutations: int + ) -> list: + """Mutate a set of insertion points by adding, removing, or both.""" + key_fn = { + "node input points": lambda p: (p.node_index, p.input_index), + "region composite points": lambda p: (p.region_index, p.input_index), + "region output points": lambda p: (p.region_index, p.node_index, p.output_index), + }.get(point_type) + + if not key_fn: + return [] + + current_points = set(base_points) + initial_count = len(current_points) + mutation_type = random.choice(["add", "remove", "both"]) + + if mutation_type in ["add", "both"] and len(current_points) < len(all_points): + all_keys = {key_fn(p) for p in all_points} + available_keys = all_keys - current_points + if available_keys: + max_add = min(max_mutations, len(available_keys)) + num_to_add = random.randint(1, max_add) + to_add = random.sample(list(available_keys), num_to_add) + current_points.update(to_add) + + if mutation_type in ["remove", "both"] and current_points: + max_remove = min(max_mutations, len(current_points)) + num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1 + num_to_remove = min(num_to_remove, len(current_points)) + to_remove = random.sample(list(current_points), num_to_remove) + for p in to_remove: + current_points.discard(p) + + logger.debug( + 
f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})" + ) + + return [p for p in all_points if key_fn(p) in current_points] + + def _generate_next_insertion_sample(self) -> InsertionScheme: + """Generate a new insertion scheme by mutating top performers.""" + if self.current_profile_region is None: + return InsertionScheme() + + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + return InsertionScheme() + + region = self.current_profile_region + pattern_schemes = schemes_collection + + if not isinstance(schemes_collection, PatternSchemes) or schemes_collection.pattern is None: + return InsertionScheme() + pattern = schemes_collection.pattern + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + + logger.debug( + f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, " + f"{len(full_insertion_scheme.child_region_inputs)} region composite, " + f"{len(full_insertion_scheme.region_outputs)} region output" + ) + + top_percent = getattr(self.config, "top_percent_to_mutate", 0.1) + minimum_schemes = getattr(self.config, "minimum_schemes_to_mutate", 1) + + measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error] + measured_schemes.sort(key=lambda s: s.latency_ms) + + num_top_schemes = max( + int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes)) + ) + top_schemes = measured_schemes[:num_top_schemes] + + if len(top_schemes) == 0: + logger.debug("No measured schemes yet, generating baseline (empty) scheme") + return InsertionScheme() + + base_scheme = random.choice(top_schemes) + total_base_points = ( + len(base_scheme.node_inputs) + + len(base_scheme.child_region_inputs) + + len(base_scheme.region_outputs) + ) + logger.debug( + f"Mutating from top {len(top_schemes)} schemes: " + f"selected base with {total_base_points} points 
(latency={base_scheme.latency_ms:.3f} ms)" + ) + + max_mutations = getattr(self.config, "maximum_mutations", 3) + + scheme = InsertionScheme() + base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} + scheme.node_inputs = self._mutate_insertion_points( + base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations + ) + + base_region_composite_points = { + (p.region_index, p.input_index) for p in base_scheme.child_region_inputs + } + scheme.child_region_inputs = self._mutate_insertion_points( + base_region_composite_points, + full_insertion_scheme.child_region_inputs, + "region composite points", + max_mutations, + ) + + base_region_output_points = { + (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs + } + scheme.region_outputs = self._mutate_insertion_points( + base_region_output_points, + full_insertion_scheme.region_outputs, + "region output points", + max_mutations, + ) + + return scheme + + def _copy_graph(self) -> gs.Graph: + """Create an independent copy of the computation graph.""" + new_graph = gs.import_onnx(self.onnx_model) + new_graph.toposort() + return new_graph + + def _get_quant_dtype(self, quant_type: str) -> np.dtype: + """Get numpy dtype for quantization type.""" + if quant_type == "fp8": + try: + return np.dtype(np.float8_e4m3fn) + except (AttributeError, TypeError): + logger.warning( + "FP8 dtype not available (requires numpy >= 2.0), " + "using uint8 as placeholder. Note: This may not produce " + "correct results without proper FP8 support." 
+ ) + return np.uint8 + + dtype_map = { + "int8": np.int8, + "uint8": np.uint8, + } + + if quant_type not in dtype_map: + logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") + return np.int8 + + return dtype_map[quant_type] + + def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: + """Convert DQ dtype string to numpy dtype.""" + dtype_map = { + "float16": np.float16, + "float32": np.float32, + } + + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + + if dtype_str not in dtype_map: + logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") + return np.float32 + + return dtype_map[dtype_str] + + def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: + """Build mapping from tensor names to tensor objects.""" + tensor_map = {} + + for node in graph.nodes: + for output in node.outputs: + if hasattr(output, "name") and output.name: + tensor_map[output.name] = output + + for input_tensor in graph.inputs: + if hasattr(input_tensor, "name") and input_tensor.name: + tensor_map[input_tensor.name] = input_tensor + + for node in graph.nodes: + for input_tensor in node.inputs: + if ( + isinstance(input_tensor, gs.Constant) + and hasattr(input_tensor, "name") + and input_tensor.name + ): + tensor_map[input_tensor.name] = input_tensor + + return tensor_map + + def _get_tensor_metadata( + self, tensor: gs.Tensor, is_constant: bool + ) -> tuple[tuple | None, np.dtype]: + """Extract shape and dtype metadata from a tensor.""" + default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + + if is_constant and hasattr(tensor, "values") and tensor.values is not None: + return tensor.values.shape, tensor.values.dtype + elif hasattr(tensor, "shape"): + dtype = ( + tensor.dtype + if hasattr(tensor, "dtype") and tensor.dtype is not None + else default_dtype + ) + return tensor.shape, dtype + return None, default_dtype + + def _fix_zero_point_initializers(self, model: onnx.ModelProto) -> None: + 
"""Fix INT8 zero_point initializers to use int32_data instead of raw_data.""" + fixed_count = 0 + + for initializer in model.graph.initializer: + if ( + "_zp_" in initializer.name + and initializer.data_type == onnx.TensorProto.INT8 + and len(initializer.raw_data) > 0 + and len(initializer.int32_data) == 0 + ): + np_array = onnx.numpy_helper.to_array(initializer) + int32_values = np_array.astype(np.int32).flatten().tolist() + + new_tensor = onnx.helper.make_tensor( + initializer.name, + onnx.TensorProto.INT8, + list(initializer.dims), + int32_values, + ) + initializer.CopyFrom(new_tensor) + fixed_count += 1 + + if fixed_count > 0: + logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)") + + def _create_qdq_nodes( + self, + tensor_name: str, + qdq_input: gs.Tensor, + output_shape: tuple | None, + output_dtype: np.dtype, + quant_dtype: np.dtype, + q_scale: float, + ) -> tuple[gs.Node, gs.Node]: + """Create QuantizeLinear and DequantizeLinear node pair.""" + # Create unique names for Q/DQ nodes + q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") + dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") + # Determine scale dtype from output_dtype (fp16/tf32/fp32) + # Scale should match the precision of the original I/O tensor + dtype_map = {"float16": np.float16, "float32": np.float32} + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + scale_dtype = dtype_map.get(np.dtype(output_dtype).name, np.float32) + + logger.debug( + f"Creating Q/DQ pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})" + ) + + q_scale_values = np.array([q_scale], dtype=scale_dtype) + q_zp_values = np.array([0], dtype=quant_dtype) + q_inputs = [ + qdq_input, + gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), + gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values), + ] + q_node = gs.Node( + op="QuantizeLinear", + name=q_name, + inputs=q_inputs, + outputs=[ + 
gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape) + ], + ) + + dq_scale_values = np.array([q_scale], dtype=scale_dtype) + dq_zp_values = np.array([0], dtype=quant_dtype) + dq_inputs = [ + q_node.outputs[0], + gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), + gs.Constant(f"dq_zp_{tensor_name}", values=dq_zp_values), + ] + dq_node = gs.Node( + op="DequantizeLinear", + name=dq_name, + inputs=dq_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape) + ], + ) + + return q_node, dq_node + + def _insert_qdq_at_tensors( + self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] + ) -> None: + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.""" + q_scale = self.config.default_q_scale + quant_type = self.config.default_quant_type + quant_dtype = self._get_quant_dtype(quant_type) + + logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") + + resolved_insertion_points = merge_resolved_insertion_points( + graph, resolved_insertion_points + ) + + tensor_map = self._build_tensor_map(graph) + tensor_users_map = get_tensor_consumer_node_indices(graph) + logger.debug( + f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users" + ) + + for insertion_point in resolved_insertion_points: + tensor_name = insertion_point.tensor_name + node_index = insertion_point.node_index + input_index = insertion_point.input_index + + original_tensor = tensor_map[tensor_name] + if node_index is not None: + assert node_index < len(graph.nodes), "Node index out of range" + target_node = graph.nodes[node_index] + assert input_index is not None, "Input index must be set when node index is set" + assert input_index < len(target_node.inputs), ( + f"Input index out of range for node {target_node.name}" + ) + original_tensor = target_node.inputs[input_index] + assert tensor_name == original_tensor.name, ( + f"Tensor name 
mismatch for node {target_node.name} input {input_index}" + ) + else: + assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map" + assert input_index is None, "Input index must be None when node index is None" + + is_constant = isinstance(original_tensor, gs.Constant) + output_shape, output_dtype = self._get_tensor_metadata(original_tensor, is_constant) + + unique_suffix = "qdq" + if node_index is not None: + unique_suffix = f"n{node_index}_i{input_index}" + unique_tensor_name = f"{tensor_name}_{unique_suffix}" + + q_node, dq_node = self._create_qdq_nodes( + unique_tensor_name, + original_tensor, + output_shape, + output_dtype, + quant_dtype, + q_scale, + ) + + graph.nodes.extend([q_node, dq_node]) + + if node_index is not None: + target_node.inputs[input_index] = dq_node.outputs[0] + logger.debug( + f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} " + f"({target_node.name}) input #{input_index}" + ) + else: + users = tensor_users_map[tensor_name] + for user_index in users: + user_node = graph.nodes[user_index] + for i, input_tensor in enumerate(user_node.inputs): + if hasattr(input_tensor, "name") and input_tensor.name == tensor_name: + user_node.inputs[i] = dq_node.outputs[0] + break + logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users") + + logger.debug("Running graph cleanup and topological sort") + try: + graph.cleanup().toposort() + logger.debug("Graph cleanup completed") + except Exception as e: + logger.error(f"Graph cleanup failed: {e}") + raise RuntimeError(f"Graph cleanup failed after Q/DQ insertion: {e}") from e + + +class QDQAutotuner(QDQAutotunerBase): + """Q/DQ autotuner with automatic region discovery around compute-intensive ops.""" + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuner and discover optimization regions automatically.""" + super().initialize(config, pattern_cache) + self._search_regions() 
+ + def _visit_region_recursively(self, region: Region) -> list[Region]: + """Recursively traverse region hierarchy and collect all regions.""" + regions = [region] + + for child in region.get_children(): + regions.extend(self._visit_region_recursively(child)) + + return regions + + def _reassign_region_ids(self, regions: list[Region]) -> None: + """Reassign sequential IDs to regions in breadth-first order.""" + region_id = 0 + + queue = deque(regions) + + while queue: + region = queue.popleft() + region.id = region_id + region_id += 1 + queue.extend(region.get_children()) + + def _search_regions(self) -> None: + """Discover and organize optimization regions automatically.""" + logger.info("Discovering optimization regions") + search = CombinedRegionSearch( + self.graph, + maximum_sequence_region_size=self.config.maximum_sequence_region_size, + minimum_topdown_search_size=self.config.minimum_topdown_search_size, + ) + self.regions = search.search_regions() + + self._reassign_region_ids(self.regions) + logger.debug(f"Found {len(self.regions)} top-level regions") + + all_regions = [] + for region in self.regions: + all_regions.extend(self._visit_region_recursively(region)) + + logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") + + leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] + other_regions = [region for region in all_regions if region.type != RegionType.LEAF] + + all_regions = leaf_regions + other_regions + self.regions = all_regions + + num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) + num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) + num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) + + logger.info( + f"Discovery complete: {len(self.regions)} regions " + f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + ) + logger.debug("Regions prioritized: LEAF regions first for profiling") diff --git 
a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py new file mode 100644 index 000000000..fe4240047 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for QDQAutotuner class. + +Tests the main autotuner class public API. +Note: Full integration tests with TensorRT benchmarking should be in separate integration test files. +""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern +from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType + + +def create_simple_conv_model(): + """ + Create a simple ONNX model: Input -> Conv -> Relu -> Output. + + This is a minimal model for testing autotuner initialization. 
+ """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + return model + + +class TestQDQAutotuner(unittest.TestCase): + """Test QDQAutotuner functionality.""" + + @staticmethod + def _create_test_config(): + """ + Create a reasonable config for testing. 
+ + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, + ) + + def test_creation_with_onnx_model(self): + """Test creating autotuner with ONNX ModelProto.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + assert autotuner is not None + assert autotuner.onnx_model is not None + assert autotuner.graph is not None + + def test_creation_with_gs_graph(self): + """Test creating autotuner with GraphSurgeon graph.""" + model = create_simple_conv_model() + gs_graph = gs.import_onnx(model) + + autotuner = QDQAutotuner(gs_graph) + + assert autotuner is not None + assert autotuner.graph is not None + + def test_initialize_with_default_config(self): + """Test initialization with default test config.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should have provided config + assert autotuner.config is not None + assert autotuner.config.maximum_sequence_region_size == 50 + + # Should have discovered regions + assert len(autotuner.regions) > 0 + + def test_initialize_with_config(self): + """Test initialization with custom config (different from default).""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + # Create custom config 
with different values + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + minimum_topdown_search_size=5, + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=5, + maximum_mutations=5, + maximum_generation_attempts=50, + pattern_cache_minimum_distance=2, + pattern_cache_max_entries_per_pattern=16, + ) + autotuner.initialize(config) + + # Should use provided custom config values + assert autotuner.config.verbose + assert autotuner.config.default_q_scale == 0.05 + assert autotuner.config.default_q_zero_point == 128 + assert autotuner.config.default_quant_type == "fp8" + assert autotuner.config.maximum_sequence_region_size == 20 + assert autotuner.config.minimum_topdown_search_size == 5 + assert autotuner.config.top_percent_to_mutate == 0.2 + assert autotuner.config.minimum_schemes_to_mutate == 5 + assert autotuner.config.maximum_mutations == 5 + assert autotuner.config.maximum_generation_attempts == 50 + assert autotuner.config.pattern_cache_minimum_distance == 2 + assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 + + def test_initialize_with_pattern_cache(self): + """Test initialization with pattern cache.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + pattern_cache = PatternCache() + autotuner.initialize(config, pattern_cache=pattern_cache) + + assert autotuner.pattern_cache is not None + + def test_region_discovery(self): + """Test that regions are automatically discovered.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should discover at least one region + assert len(autotuner.regions) > 0 + + # Regions should be valid + for region in autotuner.regions: + assert region.get_id() is not None + assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] 
+ + def test_export_baseline_model(self): + """Test exporting baseline model without Q/DQ.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + + try: + # Export baseline without Q/DQ insertion + autotuner.export_onnx(output_path, insert_qdq=False) + # Verify file was created + assert os.path.exists(output_path) + # Verify it's a valid ONNX model + exported_model = onnx.load(output_path) + assert exported_model is not None + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def test_set_profile_region(self): + """Test setting a region for profiling.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + # Should set current profile region + assert autotuner.current_profile_region == region + assert autotuner.current_profile_pattern_schemes is not None + else: + self.skipTest("No regions discovered") + + def test_generate_scheme(self): + """Test generating an insertion scheme.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + # Generate a scheme + scheme_idx = autotuner.generate() + # Should return a valid index (>= 0) or -1 if no more unique schemes + assert isinstance(scheme_idx, int) + else: + self.skipTest("No regions discovered") + + def test_submit_latency(self): + """Test submitting performance measurement.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + # Submit 
baseline latency + autotuner.submit(10.5) + # Baseline should be recorded + assert autotuner.baseline_latency_ms == 10.5 + + def test_save_and_load_state(self): + """Test saving and loading autotuner state.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit some results + autotuner.submit(10.5) # baseline + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + state_path = f.name + + try: + # Save state + autotuner.save_state(state_path) + assert os.path.exists(state_path) + + # Create new autotuner and load state + autotuner2 = QDQAutotuner(model) + config2 = self._create_test_config() + autotuner2.initialize(config2) + autotuner2.load_state(state_path) + + # Baseline should match + assert autotuner2.baseline_latency_ms == 10.5 + finally: + if os.path.exists(state_path): + os.unlink(state_path) + + def test_regions_prioritization(self): + """Test that LEAF regions are prioritized.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Check that LEAF regions come before non-LEAF + leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF + ] + non_leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF + ] + + if leaf_indices and non_leaf_indices: + # All LEAF should come before non-LEAF + assert max(leaf_indices) < min(non_leaf_indices) + + def test_profiled_patterns_tracking(self): + """Test that profiled patterns are tracked.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + autotuner.submit(10.0) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + scheme_idx = autotuner.generate() + if scheme_idx >= 0: + 
autotuner.submit(12.0) + autotuner.set_profile_region(None, commit=True) + pattern_sig = RegionPattern.from_region(region, autotuner.graph).signature + profiled_patterns = [p.pattern.signature for p in autotuner.profiled_patterns] + assert pattern_sig in profiled_patterns + else: + self.skipTest("No regions discovered") From 034d69af12d86e2ebeb635e8431aeb8e29283e77 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 9 Feb 2026 08:36:26 +0000 Subject: [PATCH 02/14] pick back docstrings Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 261 ++++++++++++++++-- 1 file changed, 242 insertions(+), 19 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index 9eb8724dc..86074ed15 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -51,7 +51,20 @@ class QDQAutotunerBase: """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" def __init__(self, model: onnx.ModelProto | gs.Graph): - """Initialize the autotuner with an ONNX model.""" + """Initialize the autotuner with an ONNX model. + + Creates a clean copy of the model graph and initializes internal state. + After construction, call initialize() to configure the autotuner, then + use a subclass strategy to populate regions (e.g., QDQAutotuner does this + automatically during initialize()). + + Args: + model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize. + A clean copy is created internally, leaving the original unchanged. 
+ + Raises: + TypeError: If model is neither onnx.ModelProto nor gs.Graph + """ if isinstance(model, onnx.ModelProto): self.onnx_model = model elif isinstance(model, gs.Graph): @@ -76,7 +89,22 @@ def __init__(self, model: onnx.ModelProto | gs.Graph): def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: - """Initialize autotuning session with configuration and pattern cache.""" + """Initialize autotuning session with configuration and pattern cache. + + Prepares the autotuner for profiling by setting configuration parameters + and optionally loading pattern cache data. This base method resets all profiling + state and sets up the pattern cache storage. + + Args: + config: Autotuning configuration parameters. If None, uses default Config(). + Controls Q/DQ parameters, performance thresholds, and scheme generation. + pattern_cache: Optional PatternCache object for seeding with known-good schemes. + If None, creates a new empty pattern cache for tracking best schemes. + If provided, uses existing schemes to warm-start optimization. + + Raises: + None (safe to call multiple times - will reset state each time) + """ if config is not None: self.config = config @@ -109,7 +137,24 @@ def initialize( self.initialized = True def set_profile_region(self, region: Region | None, commit: bool = True) -> None: - """Set the target region for profiling and scheme generation.""" + """Set the target region for profiling and scheme generation. + + This method manages the profiling workflow: + 1. If commit=True: Saves current schemes to profiled_patterns + 2. Creates a RegionPattern from the new region's structure + 3. For pattern-based: tries to seed schemes from pattern cache if available + 4. Sets as current for generate() and submit() calls + + Pass region=None to clear the current profile target without setting a new one. 
+
+        Args:
+            region: The region to profile next (None to clear current target)
+            commit: If True, commit current schemes to profiled_patterns
+                before switching. Set to False during initialization.
+
+        Raises:
+            AutotunerNotInitializedError: If initialize() hasn't been called
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -185,13 +230,24 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None
         mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh"
 
         logger.info(
-            f"Profiling region {region.id} [pattern mode, level {region.level}, "
-            f"size {region.get_size_of_region_and_descendants()}, {mode_info}]"
+            f"Profiling region {region.id} [level {region.level}, size "
+            f"{region.get_size_of_region_and_descendants()}, {mode_info}]"
         )
         logger.debug(f"Pattern signature: {region_pattern.signature}")
 
     def generate(self) -> int:
-        """Generate a new Q/DQ insertion scheme for the current pattern or region."""
+        """Generate a new Q/DQ insertion scheme for the current pattern or region.
+
+        Creates a new InsertionScheme by mutating the top-performing schemes:
+        1. Checks if there are any cached schemes (error=False, latency_ms=inf)
+        2. If cached schemes exist, picks one to re-profile
+        3. Otherwise, generates a new scheme by mutation
+        4. Selects a random scheme from the top 10 performers
+        5. Mutates it by adding/removing insertion points
+        6. Ensures the new scheme is unique (different from existing schemes)
+        7. Adds the scheme to current_profile_pattern_schemes
+
+        """
         if not self.initialized:
             raise AutotunerNotInitializedError(
                 "QDQAutotunerBase not initialized. Call initialize() first."
@@ -261,7 +317,28 @@ def generate(self) -> int: def export_onnx( self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False ) -> bytes: - """Export ONNX model with Q/DQ nodes inserted according to tested schemes.""" + """Export ONNX model with Q/DQ nodes inserted according to tested schemes. + + This method creates a modified version of the model by: + 1. For each region, finding the matching pattern + 2. Applying the best scheme for profiled patterns + 3. Applying the current scheme for the active profile pattern + 4. Resolving pattern-relative insertion points to actual tensor names + 5. Inserting Q/DQ pairs at the resolved locations + 6. Converting to FP8 if needed (always creates INT8 first, then converts) + + Args: + output_path: Optional file path where the modified ONNX model will be saved. + If None, the model is not saved to disk and only bytes are returned. + insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model + (useful for baseline measurements) + + Returns: + bytes: Serialized ONNX model as bytes + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -387,7 +464,19 @@ def export_onnx( return model_bytes def submit(self, latency_ms: float, success: bool = True) -> None: - """Submit performance measurement for the most recently generated scheme.""" + """Submit performance measurement for the most recently generated scheme. + + This method records the measured latency and manages the optimization state: + + Args: + latency_ms: Measured latency in milliseconds (must be > 0) + success: Whether the measurement succeeded. If False, sets scheme.error=True, + logs a warning, and skips speedup calculation. 
+ + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + InvalidSchemeError: If no pattern or region is set, or no schemes have been generated + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -458,7 +547,19 @@ def submit(self, latency_ms: float, success: bool = True) -> None: ) def save_state(self, output_path: str) -> None: - """Save complete autotuner state to a YAML file for later reuse.""" + """Save complete autotuner state to a YAML file for later reuse. + + Serializes all optimization results including: + - Baseline latency measurement + - All profiled patterns with their signatures + - All generated schemes with insertion points and latencies + - Configuration parameters + - Current profiling state + + Args: + output_path: File path where the YAML state file will be written. + Pattern cache will be saved to _pattern_cache.yaml + """ current_pattern_sig = None if self.current_profile_pattern_schemes is not None: current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature @@ -498,7 +599,20 @@ def save_state(self, output_path: str) -> None: ) def load_state(self, input_path: str) -> None: - """Load autotuner state from a previously saved YAML file.""" + """Load autotuner state from a previously saved YAML file. + + Restores optimization results from a previous session: + 1. Matches saved patterns to current model's patterns by signature + 2. Loads all schemes with their insertion points and latencies (including unmeasured ones) + 3. Restores baseline latency and configuration + + Args: + input_path: File path to the YAML state file to load + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + FileNotFoundError: If the input_path doesn't exist + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." 
@@ -571,7 +685,20 @@ def load_state(self, input_path: str) -> None: logger.debug(f"No pattern cache file at {cache_path}") def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: - """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.""" + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. + + Analyzes the current model's regions against the provided quantized tensors + to extract Q/DQ insertion patterns. For each region, creates a pattern cache + entry that captures which insertion points correspond to the quantized tensors. + These cached patterns can then be used as seeds for future autotuning sessions. + + Args: + quantized_tensors: Set or list of tensor names that are quantized + (i.e., tensors that have Q/DQ nodes applied to them) + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + """ if not self.initialized: raise AutotunerNotInitializedError( "QDQAutotunerBase not initialized. Call initialize() first." @@ -607,7 +734,22 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> No def _compute_convergence_metrics( self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None ) -> tuple[int | None, float | None]: - """Compute convergence metrics for a collection of schemes.""" + """Compute convergence metrics for a collection of schemes. + + Analyzes when the best scheme was discovered during the profiling process + by sorting schemes by their profile timestamps and finding the position + of the best scheme. 
+ + Args: + schemes: List of insertion schemes with profile timestamps + best_scheme: The best performing scheme (lowest latency) + + Returns: + Tuple of (samples_before_best, time_to_best) where: + - samples_before_best: Number of samples tested before finding best (0-based index) + - time_to_best: Time in seconds from first sample to best sample + Both values are None if metrics cannot be computed (e.g., missing timestamps) + """ samples_before_best = None time_to_best = None @@ -690,7 +832,29 @@ def _mutate_insertion_points( return [p for p in all_points if key_fn(p) in current_points] def _generate_next_insertion_sample(self) -> InsertionScheme: - """Generate a new insertion scheme by mutating top performers.""" + """Generate a new insertion scheme by mutating top performers. + + This is the core scheme generation algorithm: + 1. Identifies top schemes by latency + 2. Randomly selects one as the base + 3. Mutates node input insertion points (add, remove, or both) + 4. Mutates region composite insertion points (child boundaries) + 5. Mutates region output insertion points + 6. Returns new unique scheme + + **Mutation Strategy:** + - Node input points: Add/remove 1-3 insertion points + - Region composite points: Add/remove 1-3 boundary points + - Region output points: Add/remove 1-3 output points + - Mutation type chosen randomly: 'add', 'remove', or 'both' + + **Baseline Case:** + If no schemes exist yet, returns an empty baseline scheme. + + Returns: + New InsertionScheme with mutated insertion points. + Returns empty scheme if no region is set or no candidates exist. + """ if self.current_profile_region is None: return InsertionScheme() @@ -891,7 +1055,20 @@ def _create_qdq_nodes( quant_dtype: np.dtype, q_scale: float, ) -> tuple[gs.Node, gs.Node]: - """Create QuantizeLinear and DequantizeLinear node pair.""" + """Create QuantizeLinear and DequantizeLinear node pair. 
+ + Args: + tensor_name: Name of the tensor being quantized + qdq_input: Input tensor to the Q node + output_shape: Shape for Q/DQ outputs (may be None) + output_dtype: Dtype for DQ output (also used for scale dtype) + quant_dtype: Dtype for quantized values + quant_type: Quantization type string + q_scale: Quantization scale + + Returns: + Tuple of (q_node, dq_node) + """ # Create unique names for Q/DQ nodes q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") @@ -943,7 +1120,17 @@ def _create_qdq_nodes( def _insert_qdq_at_tensors( self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] ) -> None: - """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.""" + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations. + + This is the main entry point for Q/DQ insertion. It: + 1. Builds tensor map and tensor-to-users map for efficient lookup + 2. Processes each resolved insertion point to insert Q/DQ nodes + 3. Handles two insertion modes based on node_index + + Args: + graph: Graph to modify in-place + resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ + """ q_scale = self.config.default_q_scale quant_type = self.config.default_quant_type quant_dtype = self._get_quant_dtype(quant_type) @@ -1031,12 +1218,28 @@ class QDQAutotuner(QDQAutotunerBase): def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: - """Initialize autotuner and discover optimization regions automatically.""" + """Initialize autotuner and discover optimization regions automatically. + + Extends base class initialization by automatically searching for regions + after configuration is set up. Regions are discovered using pattern-based + search around compute-intensive operations. 
+ """ super().initialize(config, pattern_cache) self._search_regions() def _visit_region_recursively(self, region: Region) -> list[Region]: - """Recursively traverse region hierarchy and collect all regions.""" + """Recursively traverse region hierarchy and collect all regions. + + Performs depth-first traversal of the region tree starting from a given + region. Collects the root region and all descendant regions (children, + grandchildren, etc.) into a flat list. + + Args: + region: Root region to start traversal from + + Returns: + List of all regions in the subtree (including root), in pre-order DFS. + """ regions = [region] for child in region.get_children(): @@ -1045,7 +1248,15 @@ def _visit_region_recursively(self, region: Region) -> list[Region]: return regions def _reassign_region_ids(self, regions: list[Region]) -> None: - """Reassign sequential IDs to regions in breadth-first order.""" + """Reassign sequential IDs to regions in breadth-first order. + + Traverses the region hierarchy (including children) and assigns new + sequential IDs starting from 0. This ensures clean, predictable region + numbering after region discovery and manipulation. + + Args: + regions: List of top-level regions (children will be processed too) + """ region_id = 0 queue = deque(regions) @@ -1057,7 +1268,19 @@ def _reassign_region_ids(self, regions: list[Region]) -> None: queue.extend(region.get_children()) def _search_regions(self) -> None: - """Discover and organize optimization regions automatically.""" + """Discover and organize optimization regions automatically. + + This is the core region discovery method that: + 1. Runs automatic region search to find optimization targets + 2. Flattens hierarchical structure into a list + 3. Prioritizes LEAF regions (contain actual nodes) + 4. 
Reassigns IDs for clean indexing + + **Search Strategy:** + Uses CombinedRegionSearch which performs: + - Phase 1: Bottom-up partitioning based on divergence/convergence + - Phase 2: Top-down refinement creating hierarchical structure + """ logger.info("Discovering optimization regions") search = CombinedRegionSearch( self.graph, From dac6a8415a83ae6479347f69da2c1368c0fa39a9 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 10 Feb 2026 03:13:53 +0000 Subject: [PATCH 03/14] resolve comments Signed-off-by: Will Guo --- .../unit/onnx/quantization/autotune/autotune/test_autotuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py index fe4240047..17a8dd4cc 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From 21120c37b276b0630a3f916c0b9e6239cf3128ac Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 11 Feb 2026 13:28:18 +0000 Subject: [PATCH 04/14] resolve comments Signed-off-by: Will Guo --- .../quantization/autotune/autotune/models.py | 47 ++++ .../autotune/autotune/test_autotuner.py | 204 +++++++----------- 2 files changed, 126 insertions(+), 125 deletions(-) create mode 100644 tests/unit/onnx/quantization/autotune/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/autotune/models.py b/tests/unit/onnx/quantization/autotune/autotune/models.py new file mode 100644 index 000000000..4090cfef3 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/autotune/models.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shared test ONNX models for autotuner unit tests. + +Model creation functions live here; tests import and call them directly. 
+""" + +import onnx +from onnx import helper + + +def _create_simple_conv_onnx_model(): + """Build ONNX model: Input -> Conv -> Relu -> Output (minimal for autotuner tests).""" + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + return helper.make_model(graph, producer_name="test") diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py index 17a8dd4cc..ef49e53b5 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py @@ -23,118 +23,82 @@ import os import sys import tempfile -import unittest -# Add parent directory to path -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# Add parent and current directory to path +_test_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.dirname(_test_dir)) +sys.path.insert(0, _test_dir) +import models as _test_models import onnx import onnx_graphsurgeon as gs -from onnx import helper +import pytest from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType -def create_simple_conv_model(): - """ - Create a simple ONNX model: Input -> Conv -> Relu -> Output. 
- - This is a minimal model for testing autotuner initialization. - """ - # Input - input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) +@pytest.fixture +def simple_conv_model(): + """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py.""" + return _test_models._create_simple_conv_onnx_model() - # Output - output_tensor = helper.make_tensor_value_info( - "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] - ) - # Conv node - conv_node = helper.make_node( - "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" - ) +def _create_test_config(): + """ + Create a reasonable config for testing. - # Relu node - relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") - - # Create graph - graph = helper.make_graph( - [conv_node, relu_node], - "simple_conv", - [input_tensor], - [output_tensor], - initializer=[ - helper.make_tensor( - "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) - ) - ], + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, ) - # Create model - model = helper.make_model(graph, producer_name="test") - return model - -class TestQDQAutotuner(unittest.TestCase): +class TestQDQAutotuner: """Test QDQAutotuner 
functionality.""" - @staticmethod - def _create_test_config(): - """ - Create a reasonable config for testing. - - Uses sensible defaults suitable for unit tests: - - verbose=False: Keep test output clean - - maximum_sequence_region_size=50: Allow larger test regions - - Other parameters: Match Config defaults for typical behavior - """ - return Config( - # Logging - verbose=False, - # Performance Requirements - # Quantization Parameters - default_q_scale=0.1, - default_q_zero_point=0, - default_quant_type="int8", - # Region Builder Settings - maximum_sequence_region_size=50, - minimum_topdown_search_size=10, - # Scheme Generation Settings - top_percent_to_mutate=0.1, - minimum_schemes_to_mutate=10, - maximum_mutations=3, - maximum_generation_attempts=100, - # Pattern Cache Settings - pattern_cache_minimum_distance=4, - pattern_cache_max_entries_per_pattern=32, - ) - - def test_creation_with_onnx_model(self): + def test_creation_with_onnx_model(self, simple_conv_model): """Test creating autotuner with ONNX ModelProto.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) assert autotuner is not None assert autotuner.onnx_model is not None assert autotuner.graph is not None - def test_creation_with_gs_graph(self): + def test_creation_with_gs_graph(self, simple_conv_model): """Test creating autotuner with GraphSurgeon graph.""" - model = create_simple_conv_model() - gs_graph = gs.import_onnx(model) - + gs_graph = gs.import_onnx(simple_conv_model) autotuner = QDQAutotuner(gs_graph) assert autotuner is not None assert autotuner.graph is not None - def test_initialize_with_default_config(self): + def test_initialize_with_default_config(self, simple_conv_model): """Test initialization with default test config.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() 
autotuner.initialize(config) # Should have provided config @@ -144,10 +108,9 @@ def test_initialize_with_default_config(self): # Should have discovered regions assert len(autotuner.regions) > 0 - def test_initialize_with_config(self): + def test_initialize_with_config(self, simple_conv_model): """Test initialization with custom config (different from default).""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) # Create custom config with different values config = Config( @@ -180,23 +143,21 @@ def test_initialize_with_config(self): assert autotuner.config.pattern_cache_minimum_distance == 2 assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 - def test_initialize_with_pattern_cache(self): + def test_initialize_with_pattern_cache(self, simple_conv_model): """Test initialization with pattern cache.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() pattern_cache = PatternCache() autotuner.initialize(config, pattern_cache=pattern_cache) assert autotuner.pattern_cache is not None - def test_region_discovery(self): + def test_region_discovery(self, simple_conv_model): """Test that regions are automatically discovered.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) + autotuner = QDQAutotuner(simple_conv_model) - config = self._create_test_config() + config = _create_test_config() autotuner.initialize(config) # Should discover at least one region @@ -207,11 +168,10 @@ def test_region_discovery(self): assert region.get_id() is not None assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] - def test_export_baseline_model(self): + def test_export_baseline_model(self, simple_conv_model): """Test exporting baseline model without Q/DQ.""" - model = create_simple_conv_model() - autotuner = 
QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: @@ -229,11 +189,10 @@ def test_export_baseline_model(self): if os.path.exists(output_path): os.unlink(output_path) - def test_set_profile_region(self): + def test_set_profile_region(self, simple_conv_model): """Test setting a region for profiling.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) if len(autotuner.regions) > 0: @@ -243,13 +202,12 @@ def test_set_profile_region(self): assert autotuner.current_profile_region == region assert autotuner.current_profile_pattern_schemes is not None else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") - def test_generate_scheme(self): + def test_generate_scheme(self, simple_conv_model): """Test generating an insertion scheme.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) if len(autotuner.regions) > 0: @@ -260,24 +218,22 @@ def test_generate_scheme(self): # Should return a valid index (>= 0) or -1 if no more unique schemes assert isinstance(scheme_idx, int) else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") - def test_submit_latency(self): + def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Submit baseline latency 
autotuner.submit(10.5) # Baseline should be recorded assert autotuner.baseline_latency_ms == 10.5 - def test_save_and_load_state(self): + def test_save_and_load_state(self, simple_conv_model): """Test saving and loading autotuner state.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Submit some results @@ -292,8 +248,8 @@ def test_save_and_load_state(self): assert os.path.exists(state_path) # Create new autotuner and load state - autotuner2 = QDQAutotuner(model) - config2 = self._create_test_config() + autotuner2 = QDQAutotuner(simple_conv_model) + config2 = _create_test_config() autotuner2.initialize(config2) autotuner2.load_state(state_path) @@ -303,11 +259,10 @@ def test_save_and_load_state(self): if os.path.exists(state_path): os.unlink(state_path) - def test_regions_prioritization(self): + def test_regions_prioritization(self, simple_conv_model): """Test that LEAF regions are prioritized.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) # Check that LEAF regions come before non-LEAF @@ -322,11 +277,10 @@ def test_regions_prioritization(self): # All LEAF should come before non-LEAF assert max(leaf_indices) < min(non_leaf_indices) - def test_profiled_patterns_tracking(self): + def test_profiled_patterns_tracking(self, simple_conv_model): """Test that profiled patterns are tracked.""" - model = create_simple_conv_model() - autotuner = QDQAutotuner(model) - config = self._create_test_config() + autotuner = QDQAutotuner(simple_conv_model) + config = _create_test_config() autotuner.initialize(config) autotuner.submit(10.0) @@ -342,4 +296,4 @@ def test_profiled_patterns_tracking(self): profiled_patterns = [p.pattern.signature 
for p in autotuner.profiled_patterns] assert pattern_sig in profiled_patterns else: - self.skipTest("No regions discovered") + pytest.skip("No regions discovered") From 85fe30f563754ede389930302c2d37d49fefa89c Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 11 Feb 2026 13:43:28 +0000 Subject: [PATCH 05/14] resolve comments Signed-off-by: Will Guo --- .../autotune/{autotune => }/models.py | 0 .../autotune/{autotune => }/test_autotuner.py | 49 +++++++++++++------ 2 files changed, 35 insertions(+), 14 deletions(-) rename tests/unit/onnx/quantization/autotune/{autotune => }/models.py (100%) rename tests/unit/onnx/quantization/autotune/{autotune => }/test_autotuner.py (88%) diff --git a/tests/unit/onnx/quantization/autotune/autotune/models.py b/tests/unit/onnx/quantization/autotune/models.py similarity index 100% rename from tests/unit/onnx/quantization/autotune/autotune/models.py rename to tests/unit/onnx/quantization/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py similarity index 88% rename from tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py rename to tests/unit/onnx/quantization/autotune/test_autotuner.py index ef49e53b5..f6a920bec 100644 --- a/tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -21,14 +21,8 @@ """ import os -import sys import tempfile -# Add parent and current directory to path -_test_dir = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, os.path.dirname(_test_dir)) -sys.path.insert(0, _test_dir) - import models as _test_models import onnx import onnx_graphsurgeon as gs @@ -205,20 +199,47 @@ def test_set_profile_region(self, simple_conv_model): pytest.skip("No regions discovered") def test_generate_scheme(self, simple_conv_model): - """Test generating an insertion scheme.""" + """Test generating multiple schemes and that Q/DQ nodes 
appear in exported model.""" autotuner = QDQAutotuner(simple_conv_model) config = _create_test_config() autotuner.initialize(config) - if len(autotuner.regions) > 0: - region = autotuner.regions[0] - autotuner.set_profile_region(region) - # Generate a scheme + if len(autotuner.regions) == 0: + pytest.skip("No regions discovered") + + autotuner.submit(10.0) # baseline + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate multiple schemes and submit a latency for each + num_generated = 0 + while True: scheme_idx = autotuner.generate() - # Should return a valid index (>= 0) or -1 if no more unique schemes + if scheme_idx < 0: + break assert isinstance(scheme_idx, int) - else: - pytest.skip("No regions discovered") + autotuner.submit(10.0 + num_generated * 0.1) # dummy latency + num_generated += 1 + if num_generated >= 5: # cap iterations + break + + assert num_generated > 0, "Expected at least one scheme to be generated" + autotuner.set_profile_region(None, commit=True) + + # Export with Q/DQ and verify Q/DQ nodes are in the model + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + try: + autotuner.export_onnx(output_path, insert_qdq=True) + exported = onnx.load(output_path) + node_ops = [n.op_type for n in exported.graph.node] + assert "QuantizeLinear" in node_ops, "Expected QuantizeLinear nodes in exported model" + assert "DequantizeLinear" in node_ops, ( + "Expected DequantizeLinear nodes in exported model" + ) + finally: + if os.path.exists(output_path): + os.unlink(output_path) def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" From 4e5f1671915ae6165f9e3548d0b1d308d1940ef7 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 16 Feb 2026 23:46:36 +0000 Subject: [PATCH 06/14] resolve comments Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 432 +++++++++--------- 1 file changed, 211 insertions(+), 221 deletions(-) diff 
--git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index 86074ed15..e633f30a1 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -16,9 +16,10 @@ """Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" import copy +import functools import os import random -from collections import deque +from collections import Counter, deque from datetime import datetime, timezone import numpy as np @@ -46,10 +47,45 @@ from modelopt.onnx.quantization.fp8 import int8_to_fp8 from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices +_MUTATION_SPECS = [ + ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)), + ( + "region_composite_inputs", + "region composite points", + lambda p: (p.region_index, p.input_index), + ), + ( + "region_output_points", + "region output points", + lambda p: (p.region_index, p.node_index, p.output_index), + ), +] + + +def _requires_init(method): + """Decorator that raises AutotunerNotInitializedError if initialize() has not been called.""" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + return method(self, *args, **kwargs) + + return wrapper + class QDQAutotunerBase: """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" + _DTYPE_MAP = { + "int8": np.int8, + "uint8": np.uint8, + "float16": np.float16, + "float32": np.float32, + } + def __init__(self, model: onnx.ModelProto | gs.Graph): """Initialize the autotuner with an ONNX model. 
@@ -86,6 +122,8 @@ def __init__(self, model: onnx.ModelProto | gs.Graph): logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") + requires_init = _requires_init + def initialize( self, config: Config | None = None, pattern_cache: PatternCache | None = None ) -> None: @@ -136,6 +174,54 @@ def initialize( self.initialized = True + def _commit_current_pattern(self, save: bool = True) -> None: + """Save current pattern schemes to profiled_patterns (if save) and clear current state.""" + if save and self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + def _seed_from_cache(self, pattern: RegionPattern) -> tuple[PatternSchemes | None, int]: + """Seed PatternSchemes from pattern cache for the given pattern. 
Returns (schemes, num_seeded).""" + if self.pattern_cache is None: + return None, 0 + cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if cache_schemes is None or len(cache_schemes.schemes) == 0: + logger.debug("No pattern cache entries for this region") + return None, 0 + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = pattern + num_seeded = 0 + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + return pattern_schemes, num_seeded + + @_requires_init def set_profile_region(self, region: Region | None, commit: bool = True) -> None: """Set the target region for profiling and scheme generation. @@ -155,37 +241,8 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." 
- ) - - if commit: - if self.current_profile_pattern_schemes is not None: - num_schemes = len(self.current_profile_pattern_schemes.schemes) - best_scheme = self.current_profile_pattern_schemes.best_scheme - best_latency = best_scheme.latency_ms if best_scheme else float("inf") - - samples_before_best, time_to_best = self._compute_convergence_metrics( - self.current_profile_pattern_schemes.schemes, best_scheme - ) - - logger.info( - f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" - ) - logger.debug( - f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" - ) - if samples_before_best is not None: - logger.debug(f"Convergence: best found at sample {samples_before_best}") - if time_to_best is not None: - logger.debug(f"Time to best: {time_to_best:.2f}s") - self.profiled_patterns.append(self.current_profile_pattern_schemes) - if commit or region is None: - self.current_profile_region = None - self.current_profile_pattern_schemes = None - self.current_insertion_scheme_index = None + self._commit_current_pattern(save=commit) if region is None: return @@ -199,27 +256,7 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None logger.debug(f"Pattern signature: {region_pattern.signature}") return - pattern_schemes = None - num_seeded = 0 - - if self.pattern_cache is not None: - cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) - - if cache_schemes is not None and len(cache_schemes.schemes) > 0: - pattern_schemes = PatternSchemes() - pattern_schemes.pattern = region_pattern - - for cached_scheme in cache_schemes.schemes: - scheme_copy = copy.deepcopy(cached_scheme) - scheme_copy.latency_ms = float("inf") - scheme_copy.error = False - pattern_schemes.schemes.append(scheme_copy) - num_seeded += 1 - - logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") - else: - logger.debug("No pattern cache entries for this region") - + pattern_schemes, 
num_seeded = self._seed_from_cache(region_pattern) if pattern_schemes is None: pattern_schemes = PatternSchemes() pattern_schemes.pattern = region_pattern @@ -235,6 +272,7 @@ def set_profile_region(self, region: Region | None, commit: bool = True) -> None ) logger.debug(f"Pattern signature: {region_pattern.signature}") + @_requires_init def generate(self) -> int: """Generate a new Q/DQ insertion scheme for the current pattern or region. @@ -248,11 +286,7 @@ def generate(self) -> int: 7. Adds the scheme to current_profile_pattern_schemes """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - elif self.current_profile_pattern_schemes is None: + if self.current_profile_pattern_schemes is None: raise InvalidSchemeError("No region selected. Call set_profile_region() first.") pattern_schemes = self.current_profile_pattern_schemes @@ -314,6 +348,77 @@ def generate(self) -> int: logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") return -1 + def _resolve_scheme_for_region( + self, region: Region, best: bool + ) -> tuple[InsertionScheme | None, RegionPattern]: + """Resolve the insertion scheme to use for a region from profiled/current/cache. 
+ + Args: + region: The region to resolve the scheme for + best: If True, return the best scheme for the region + + Returns: + tuple[InsertionScheme | None, RegionPattern]: The scheme and pattern for the region + """ + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) + current_scheme = matched.best_scheme if matched else None + + if matched: + if current_scheme: + logger.debug( + f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug(" → Matched profiled pattern but no valid schemes") + + if current_scheme is None: + pattern_schemes = self.current_profile_pattern_schemes + if pattern_schemes is None or pattern != pattern_schemes.pattern: + pass + elif best: + current_scheme = pattern_schemes.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is not None: + assert scheme_index < len(pattern_schemes.schemes), ( + f"Invalid scheme index: {scheme_index}" + ) + current_scheme = pattern_schemes.schemes[scheme_index] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if cache_schemes is not None: + schemes = cache_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + + return current_scheme, pattern + + def _exclude_overlapping_insertion_points( + self, + resolved_insertion_points: set[ResolvedInsertionPoint], + region: Region, + pattern: RegionPattern, + ) -> None: + """Remove this region's full insertion 
points from resolved set so they can be replaced.""" + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) + resolved_insertion_points.difference_update(all_region_ips) + if all_region_ips: + logger.debug(f" → Excluded {len(all_region_ips)} overlapping insertion points") + + @_requires_init def export_onnx( self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False ) -> bytes: @@ -339,11 +444,6 @@ def export_onnx( Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - output_desc = output_path if output_path is not None else "" original_quant_type = self.config.default_quant_type needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" @@ -364,58 +464,13 @@ def export_onnx( logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") for region in self.regions: - pattern = RegionPattern.from_region(region, self.graph) - logger.debug(f"Region {region.id} (level {region.level})") - logger.debug(f" → Pattern signature: {pattern.signature}") - - matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) - current_scheme = matched.best_scheme if matched else None - - if matched: - if current_scheme: - logger.debug( - f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" - ) - else: - logger.debug(" → Matched profiled pattern but no valid schemes") - - if current_scheme is None: - current_scheme = self.current_profile_pattern_schemes - if current_scheme is None or pattern != current_scheme.pattern: - pass - elif best: - current_scheme = current_scheme.best_scheme - else: - scheme_index = self.current_insertion_scheme_index - if 
scheme_index is not None: - assert scheme_index < len(current_scheme.schemes), ( - f"Invalid scheme index: {scheme_index}" - ) - current_scheme = current_scheme.schemes[scheme_index] - logger.debug(f" → Using current pattern scheme #{scheme_index}") - - if current_scheme is None and self.pattern_cache is not None: - pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) - if pattern_schemes is not None: - schemes = pattern_schemes.schemes - if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: - current_scheme = schemes[0] - logger.debug(" → Using imported pattern from cache") - + current_scheme, pattern = self._resolve_scheme_for_region(region, best) if current_scheme is None: - logger.debug(" → No scheme available, skipping") continue - full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) - assert full_insertion_scheme is not None - all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) - assert isinstance(all_region_ips, set) - resolved_insertion_points.difference_update(all_region_ips) - excluded_tensors = all_region_ips - resolved_insertion_points - if excluded_tensors: - logger.debug( - f" → Excluded {len(excluded_tensors)} overlapping insertion points" - ) + self._exclude_overlapping_insertion_points( + resolved_insertion_points, region, pattern + ) new_ips = pattern.matches(region, self.graph, current_scheme) if new_ips: @@ -463,6 +518,7 @@ def export_onnx( ) return model_bytes + @_requires_init def submit(self, latency_ms: float, success: bool = True) -> None: """Submit performance measurement for the most recently generated scheme. 
@@ -477,11 +533,6 @@ def submit(self, latency_ms: float, success: bool = True) -> None: AutotunerNotInitializedError: If initialize() hasn't been called InvalidSchemeError: If no pattern or region is set, or no schemes have been generated """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - if self.baseline_latency_ms is None: self.baseline_latency_ms = latency_ms logger.info(f"Baseline latency: {latency_ms:.3f} ms") @@ -598,6 +649,7 @@ def save_state(self, output_path: str) -> None: f"{self.pattern_cache.total_schemes} schemes" ) + @_requires_init def load_state(self, input_path: str) -> None: """Load autotuner state from a previously saved YAML file. @@ -613,11 +665,6 @@ def load_state(self, input_path: str) -> None: AutotunerNotInitializedError: If initialize() hasn't been called FileNotFoundError: If the input_path doesn't exist """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - with open(input_path) as f: state = yaml.safe_load(f) @@ -684,6 +731,7 @@ def load_state(self, input_path: str) -> None: else: logger.debug(f"No pattern cache file at {cache_path}") + @_requires_init def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. @@ -699,11 +747,6 @@ def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> No Raises: AutotunerNotInitializedError: If initialize() hasn't been called """ - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." 
- ) - if isinstance(quantized_tensors, list): quantized_tensors = set(quantized_tensors) @@ -782,14 +825,12 @@ def _compute_convergence_metrics( def _is_region_profiled(self, region: Region) -> bool: """Check if a region's pattern has already been fully profiled.""" - - def match_pattern(pattern: PatternSchemes, region: Region) -> bool: - """Check if a pattern matches a region.""" - if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): - return False - return not any(not scheme.is_profiled for scheme in pattern.schemes) - - return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + return any( + p.pattern is not None + and p.pattern.matches(region, self.graph) + and all(s.is_profiled for s in p.schemes) + for p in self.profiled_patterns + ) def _mutate_insertion_points( self, base_points, all_points, point_type: str, max_mutations: int @@ -904,32 +945,20 @@ def _generate_next_insertion_sample(self) -> InsertionScheme: ) max_mutations = getattr(self.config, "maximum_mutations", 3) - scheme = InsertionScheme() - base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} - scheme.node_inputs = self._mutate_insertion_points( - base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations - ) - - base_region_composite_points = { - (p.region_index, p.input_index) for p in base_scheme.child_region_inputs - } - scheme.child_region_inputs = self._mutate_insertion_points( - base_region_composite_points, - full_insertion_scheme.child_region_inputs, - "region composite points", - max_mutations, - ) - base_region_output_points = { - (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs - } - scheme.region_outputs = self._mutate_insertion_points( - base_region_output_points, - full_insertion_scheme.region_outputs, - "region output points", - max_mutations, - ) + for attr, point_type, key_fn in _MUTATION_SPECS: + base_points = {key_fn(p) for p in 
getattr(base_scheme, attr)} + setattr( + scheme, + attr, + self._mutate_insertion_points( + base_points, + getattr(full_insertion_scheme, attr), + point_type, + max_mutations, + ), + ) return scheme @@ -939,9 +968,9 @@ def _copy_graph(self) -> gs.Graph: new_graph.toposort() return new_graph - def _get_quant_dtype(self, quant_type: str) -> np.dtype: - """Get numpy dtype for quantization type.""" - if quant_type == "fp8": + def _resolve_dtype(self, dtype_str: str, default: np.dtype = np.int8) -> np.dtype: + """Resolve a dtype string (quant or DQ output) to a numpy dtype.""" + if dtype_str == "fp8": try: return np.dtype(np.float8_e4m3fn) except (AttributeError, TypeError): @@ -951,63 +980,30 @@ def _get_quant_dtype(self, quant_type: str) -> np.dtype: "correct results without proper FP8 support." ) return np.uint8 - - dtype_map = { - "int8": np.int8, - "uint8": np.uint8, - } - - if quant_type not in dtype_map: - logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") - return np.int8 - - return dtype_map[quant_type] - - def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: - """Convert DQ dtype string to numpy dtype.""" - dtype_map = { - "float16": np.float16, - "float32": np.float32, - } - - if hasattr(np, "bfloat16"): - dtype_map["bfloat16"] = np.bfloat16 - - if dtype_str not in dtype_map: - logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") - return np.float32 - - return dtype_map[dtype_str] + if hasattr(np, "bfloat16") and dtype_str == "bfloat16": + return np.bfloat16 + if dtype_str in self._DTYPE_MAP: + return self._DTYPE_MAP[dtype_str] + logger.warning(f"Unknown dtype '{dtype_str}', using default {default}") + return default def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: """Build mapping from tensor names to tensor objects.""" - tensor_map = {} - + tensor_map = {t.name: t for t in graph.inputs if hasattr(t, "name") and t.name} for node in graph.nodes: - for output in node.outputs: - if 
hasattr(output, "name") and output.name: - tensor_map[output.name] = output - - for input_tensor in graph.inputs: - if hasattr(input_tensor, "name") and input_tensor.name: - tensor_map[input_tensor.name] = input_tensor - - for node in graph.nodes: - for input_tensor in node.inputs: - if ( - isinstance(input_tensor, gs.Constant) - and hasattr(input_tensor, "name") - and input_tensor.name - ): - tensor_map[input_tensor.name] = input_tensor - + for t in node.inputs: + if hasattr(t, "name") and t.name: + tensor_map[t.name] = t + for t in node.outputs: + if isinstance(t, gs.Constant) and hasattr(t, "name") and t.name: + tensor_map[t.name] = t return tensor_map def _get_tensor_metadata( self, tensor: gs.Tensor, is_constant: bool ) -> tuple[tuple | None, np.dtype]: """Extract shape and dtype metadata from a tensor.""" - default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + default_dtype = self._resolve_dtype(self.config.default_dq_dtype, np.float32) if is_constant and hasattr(tensor, "values") and tensor.values is not None: return tensor.values.shape, tensor.values.dtype @@ -1133,7 +1129,7 @@ def _insert_qdq_at_tensors( """ q_scale = self.config.default_q_scale quant_type = self.config.default_quant_type - quant_dtype = self._get_quant_dtype(quant_type) + quant_dtype = self._resolve_dtype(quant_type, np.int8) logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") @@ -1227,6 +1223,7 @@ def initialize( super().initialize(config, pattern_cache) self._search_regions() + @staticmethod def _visit_region_recursively(self, region: Region) -> list[Region]: """Recursively traverse region hierarchy and collect all regions. 
@@ -1288,28 +1285,21 @@ def _search_regions(self) -> None: minimum_topdown_search_size=self.config.minimum_topdown_search_size, ) self.regions = search.search_regions() - self._reassign_region_ids(self.regions) logger.debug(f"Found {len(self.regions)} top-level regions") + # Flatten the hierarchy into a list of all regions all_regions = [] for region in self.regions: all_regions.extend(self._visit_region_recursively(region)) - logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") - - leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] - other_regions = [region for region in all_regions if region.type != RegionType.LEAF] - - all_regions = leaf_regions + other_regions + all_regions.sort(key=lambda r: r.type != RegionType.LEAF) self.regions = all_regions - num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) - num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) - num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) - + type_counts = Counter(r.type for r in self.regions) logger.info( f"Discovery complete: {len(self.regions)} regions " - f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + f"({type_counts[RegionType.LEAF]} LEAF, {type_counts[RegionType.COMPOSITE]} COMPOSITE, " + f"{type_counts[RegionType.ROOT]} ROOT)" ) logger.debug("Regions prioritized: LEAF regions first for profiling") From 4ad510bfd276d9bb01b483c034a80bca687f5564 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 23 Feb 2026 02:16:20 +0000 Subject: [PATCH 07/14] fix test failures Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 10 +- modelopt/onnx/quantization/autotune/common.py | 547 +++++++++++++++++- .../quantization/autotune/test_autotuner.py | 60 +- 3 files changed, 576 insertions(+), 41 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index e633f30a1..f4f1adc63 100644 --- 
a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -50,12 +50,12 @@ _MUTATION_SPECS = [ ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)), ( - "region_composite_inputs", + "child_region_inputs", "region composite points", lambda p: (p.region_index, p.input_index), ), ( - "region_output_points", + "region_outputs", "region output points", lambda p: (p.region_index, p.node_index, p.output_index), ), @@ -1224,7 +1224,7 @@ def initialize( self._search_regions() @staticmethod - def _visit_region_recursively(self, region: Region) -> list[Region]: + def _visit_region_recursively(region: Region) -> list[Region]: """Recursively traverse region hierarchy and collect all regions. Performs depth-first traversal of the region tree starting from a given @@ -1240,7 +1240,7 @@ def _visit_region_recursively(self, region: Region) -> list[Region]: regions = [region] for child in region.get_children(): - regions.extend(self._visit_region_recursively(child)) + regions.extend(QDQAutotuner._visit_region_recursively(child)) return regions @@ -1291,7 +1291,7 @@ def _search_regions(self) -> None: # Flatten the hierarchy into a list of all regions all_regions = [] for region in self.regions: - all_regions.extend(self._visit_region_recursively(region)) + all_regions.extend(QDQAutotuner._visit_region_recursively(region)) all_regions.sort(key=lambda r: r.type != RegionType.LEAF) self.regions = all_regions diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index a8929315a..922ab09eb 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -18,15 +18,22 @@ import hashlib from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any, Optional + +import onnx_graphsurgeon as gs +import yaml from modelopt.onnx.logging_config import logger 
from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, + ResolvedInsertionPoint, ) +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + class AutotunerError(Exception): """Base exception for autotuner-related errors.""" @@ -315,3 +322,541 @@ def __str__(self) -> str: f"region_output_insertions={len(self.region_outputs)}, " f"latency={self.latency_ms:.3f}ms{error_str})" ) + + +@dataclass +class PatternSchemes: + """Collection of Q/DQ insertion schemes for a single pattern. + + Manages multiple InsertionScheme candidates for a region pattern, tracking + their performance and identifying the best-performing configuration. This + enables pattern-based optimization where all regions with the same structure + use the same Q/DQ insertion strategy. + + **Workflow:** + 1. Pattern is identified from region structure + 2. Multiple schemes are generated and tested + 3. Each scheme is measured (latency_ms) + 4. Best scheme is selected (lowest latency) + 5. 
Best scheme is applied to all matching regions + + **Best Scheme Selection:** + - Automatically identifies scheme with lowest latency + - Excludes schemes with errors (error=True) + - Schemes with latency_ms = inf are considered unmeasured + - best_scheme property provides easy access to optimal configuration + + **Attributes:** + pattern: RegionPattern defining the structural signature + schemes: List of InsertionScheme candidates with measurements + """ + + pattern: Optional["RegionPattern"] = None # Structural pattern signature + schemes: list[InsertionScheme] = field(default_factory=list) # Candidate schemes + + @property + def pattern_signature(self) -> str: + """Get the pattern signature string.""" + return self.pattern.signature if self.pattern else "" + + @property + def pattern_size(self) -> int: + """Get the pattern size (total node count).""" + return self.pattern.size if self.pattern else 0 + + @property + def best_scheme(self) -> InsertionScheme | None: + """Get the best performing scheme (lowest latency). + + Scans all schemes to find the one with minimum latency_ms, + excluding schemes with errors. + + Returns: + InsertionScheme with lowest latency (excluding error schemes), + or None if no valid schemes exist + """ + if len(self.schemes) == 0: + return None + min_idx, min_latency = -1, float("inf") + for idx, scheme in enumerate(self.schemes): + if not scheme.error and scheme.latency_ms < min_latency: + min_idx = idx + min_latency = scheme.latency_ms + if min_idx < 0: + return None + return self.schemes[min_idx] + + @property + def num_schemes(self) -> int: + """Get total number of schemes.""" + return len(self.schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Note: Excludes runtime objects like pattern (RegionPattern). + Only serializes metadata and schemes. 
+ """ + return { + "pattern_signature": self.pattern_signature, + "pattern_size": self.pattern_size, + "schemes": [scheme.to_dict() for scheme in self.schemes], + } + + @classmethod + def from_dict( + cls, data: dict[str, Any], pattern: Optional["RegionPattern"] = None + ) -> "PatternSchemes": + """Create PatternSchemes from serialized dictionary. + + Reconstructs the pattern schemes collection from saved data. The + RegionPattern object must be provided separately since it's not + serialized (it's a runtime object computed from the graph). + + If no pattern is provided, creates a minimal RegionPattern from the + saved signature and size for signature matching purposes. + + Args: + data: Dictionary containing 'pattern_signature', 'pattern_size', + and 'schemes' keys + pattern: RegionPattern object to associate (must match signature). + If None, creates minimal pattern from saved data. + + Returns: + Reconstructed PatternSchemes instance + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + ps = cls() + + # If no pattern provided, create minimal one from saved data + if pattern is None and "pattern_signature" in data: + pattern = RegionPattern( + signature=data["pattern_signature"], size=data.get("pattern_size", 0) + ) + + ps.pattern = pattern + + ps.schemes = [ + InsertionScheme.from_dict(scheme_data) for scheme_data in data.get("schemes", []) + ] + + return ps + + def __str__(self) -> str: + """String representation for debugging.""" + best_latency = self.best_scheme.latency_ms if self.best_scheme else 0.0 + return ( + f"PatternSchemes(pattern='{self.pattern_signature[:40]}...', " + f"schemes={self.num_schemes}, best_latency={best_latency:.3f}ms)" + ) + + +@dataclass +class PatternCache: + """Pattern cache containing best-performing schemes for patterns with automatic eviction. + + Stores a collection of PatternSchemes that can be used as seeds for autotuning. 
+ Each PatternSchemes contains high-performing insertion schemes for a specific + pattern signature. The cache automatically evicts non-performant schemes based on: + - Error status (schemes with errors are evicted) + - Duplicate schemes (only better-performing duplicate is kept) + - Similarity (similar schemes where only better-performing one is kept) + - Count limit (only top N best schemes are kept per pattern) + """ + + pattern_schemes: list[PatternSchemes] = field(default_factory=list) + minimum_distance: int = 4 # Minimum distance between schemes in cache + max_entries_per_pattern: int = 32 # Maximum number of schemes per pattern (0 = no limit) + + def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: + """Add PatternSchemes to pattern cache with automatic eviction of non-performant entries. + + Merges new schemes with existing schemes for the same pattern, automatically + evicting schemes that are non-performant based on multiple criteria. + + Args: + pattern_schemes: PatternSchemes to add to the cache + """ + if not pattern_schemes or not pattern_schemes.pattern: + return + + pattern_sig = pattern_schemes.pattern_signature + + # Find existing PatternSchemes for this pattern + existing_idx = None + for idx, ps in enumerate(self.pattern_schemes): + if ps.pattern_signature == pattern_sig: + existing_idx = idx + break + + # Collect all schemes (existing + new) + all_schemes = list(pattern_schemes.schemes) + if existing_idx is not None: + all_schemes.extend(self.pattern_schemes[existing_idx].schemes) + + # Filter out schemes with errors and deduplicate by hash + valid_schemes = [s for s in all_schemes if not s.error] + unique_schemes = {} + for scheme in valid_schemes: + scheme_hash = scheme.hash + if ( + scheme_hash not in unique_schemes + or scheme.latency_ms < unique_schemes[scheme_hash].latency_ms + ): + unique_schemes[scheme_hash] = scheme + + # Sort by latency to get best schemes + sorted_schemes = sorted(unique_schemes.values(), 
key=lambda s: s.latency_ms) + + # Apply distance-based filtering if minimum_distance > 0 + if self.minimum_distance > 0: + filtered_schemes = [] + for scheme in sorted_schemes: + # Check if this scheme is too similar to any already-filtered scheme + too_similar = False + for existing_scheme in filtered_schemes: + distance = scheme.distance(existing_scheme) + if distance < self.minimum_distance: + # Schemes are too similar, keep the better one + if scheme.latency_ms < existing_scheme.latency_ms: + # New scheme is better, remove existing and add new + filtered_schemes.remove(existing_scheme) + break + else: + # Existing scheme is better, skip new one + too_similar = True + break + + if not too_similar: + filtered_schemes.append(scheme) + + sorted_schemes = filtered_schemes + + # Apply count limit if max_entries_per_pattern > 0 + # Keep only the top N best-performing schemes per pattern + if self.max_entries_per_pattern > 0: + sorted_schemes = sorted_schemes[: self.max_entries_per_pattern] + + # Create PatternSchemes with all schemes that passed the eviction criteria + result = PatternSchemes(pattern=pattern_schemes.pattern) + result.schemes = sorted_schemes + + # Replace existing or append new + if existing_idx is not None: + self.pattern_schemes[existing_idx] = result + else: + self.pattern_schemes.append(result) + + def get_pattern_schemes(self, pattern_signature: str) -> PatternSchemes | None: + """Get PatternSchemes for a specific pattern signature. + + Args: + pattern_signature: Pattern signature to lookup + + Returns: + PatternSchemes if found, None otherwise + """ + for ps in self.pattern_schemes: + if ps.pattern_signature == pattern_signature: + return ps + return None + + def has_pattern(self, pattern_signature: str) -> bool: + """Check if pattern cache contains a specific pattern. 
+ + Args: + pattern_signature: Pattern signature to check + + Returns: + True if pattern exists in pattern cache + """ + return any(ps.pattern_signature == pattern_signature for ps in self.pattern_schemes) + + def add_pattern_from_region( + self, region: Region, graph: gs.Graph, quantized_tensors: set[str] + ) -> None: + """Build and add a pattern cache entry from a region in a quantized model. + + Analyzes a region from an already-quantized model to extract its Q/DQ + insertion scheme. This allows capturing known-good quantization strategies + from existing models and using them as seeds for autotuning. + + Args: + region: Region from the quantized model to analyze + graph: ONNX graph containing the region + quantized_tensors: Set of tensor names that have Q/DQ nodes + + Example: + >>> cache = PatternCache() + >>> for region in all_regions: + ... cache.add_pattern_from_region(region, graph, quantized_tensors) + >>> cache.save("learned_patterns.yaml") + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + # Track insertion points + scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + # Analyze node inputs + full_insertion_scheme = pattern.get_full_insertion_scheme(region, graph) + for point in full_insertion_scheme.node_inputs: + temp_scheme = InsertionScheme( + node_inputs=[point], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points: list[ResolvedInsertionPoint] = pattern.matches( + region, graph, temp_scheme + ) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.node_inputs.append(point) + # Analyze region boundaries (for COMPOSITE regions) + 
if region.type == RegionType.COMPOSITE: + for child_point in full_insertion_scheme.child_region_inputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[child_point], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.child_region_inputs.append(child_point) + # Analyze region outputs + for output_point in full_insertion_scheme.region_outputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[output_point], + latency_ms=float("inf"), + error=False, + ) + temp_insertion_points = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_insertion_points} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.region_outputs.append(output_point) + # Add pattern and scheme to pattern cache + pattern_schemes = PatternSchemes(pattern=pattern, schemes=[scheme]) + self.add_pattern_schemes(pattern_schemes) + num_points = ( + len(scheme.node_inputs) + len(scheme.child_region_inputs) + len(scheme.region_outputs) + ) + logger.debug(f"Added pattern from region {region.id} with {num_points} insertion points") + # Add patterns from child regions + if region.type == RegionType.COMPOSITE: + for child_region in region.get_children(): + self.add_pattern_from_region(child_region, graph, quantized_tensors) + + @property + def num_patterns(self) -> int: + """Get number of patterns in pattern cache.""" + return len(self.pattern_schemes) + + @property + def total_schemes(self) -> int: + """Get total number of schemes across all patterns.""" + return sum(ps.num_schemes for ps in self.pattern_schemes) + + def merge(self, other: "PatternCache", prefer_existing: bool = True) -> None: + """Merge another PatternCache into this 
one. + + Args: + other: PatternCache to merge + prefer_existing: If True, keep existing patterns when there's a conflict. + If False, overwrite with other's patterns. + """ + for schemes in other.pattern_schemes: + if not self.has_pattern(schemes.pattern_signature) or not prefer_existing: + self.add_pattern_schemes(schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Returns: + Dictionary with 'minimum_distance', 'max_entries_per_pattern', and 'pattern_schemes' keys + """ + return { + "minimum_distance": self.minimum_distance, + "max_entries_per_pattern": self.max_entries_per_pattern, + "pattern_schemes": [ps.to_dict() for ps in self.pattern_schemes], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PatternCache": + """Create PatternCache from serialized dictionary. + + Note: RegionPattern objects are not restored (they're runtime objects). + Only pattern signatures and scheme data are loaded. + + Args: + data: Dictionary containing pattern cache data + + Returns: + Reconstructed PatternCache instance + """ + cache = cls( + minimum_distance=data.get("minimum_distance", 4), + max_entries_per_pattern=data.get("max_entries_per_pattern", 32), + ) + + for ps_data in data.get("pattern_schemes", []): + # Create PatternSchemes without pattern object (pattern=None) + ps = PatternSchemes.from_dict(ps_data, pattern=None) + cache.pattern_schemes.append(ps) + + return cache + + def save(self, output_path: str) -> None: + """Save pattern cache to a YAML file. + + Serializes all pattern schemes and their insertion points to a YAML file + that can be loaded later for seeded autotuning. The format matches the + autotuner state file format for consistency. 
+ + Args: + output_path: File path where the YAML pattern cache file will be written + """ + state = self.to_dict() + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + logger.info( + f"Saved pattern cache → {output_path} ({self.num_patterns} patterns, " + f"{self.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={self.minimum_distance}, " + f"max_per_pattern={self.max_entries_per_pattern}" + ) + + @classmethod + def load(cls, input_path: str) -> "PatternCache": + """Load pattern cache from a YAML file. + + Reads a previously saved pattern cache file and reconstructs all pattern + schemes. The loaded pattern cache can be used to seed autotuning with + known-good insertion schemes. + + Args: + input_path: File path to the YAML pattern cache file to load + + Returns: + PatternCache instance with all pattern schemes loaded + + Raises: + FileNotFoundError: If the input_path doesn't exist + """ + with open(input_path) as f: + state = yaml.safe_load(f) + + cache = cls.from_dict(state) + + logger.info( + f"Loaded pattern cache from {input_path} ({cache.num_patterns} patterns, " + f"{cache.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={cache.minimum_distance}, " + f"max_per_pattern={cache.max_entries_per_pattern}" + ) + + return cache + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"PatternCache(patterns={self.num_patterns}, " + f"schemes={self.total_schemes}, " + f"minimum_distance={self.minimum_distance}, " + f"max_entries_per_pattern={self.max_entries_per_pattern})" + ) + + +@dataclass +class Config: + """Configuration parameters for QDQ autotuning. + + Controls the autotuning process including performance requirements, quantization + parameters, region building, scheme generation, and finetuning behavior. 
+ + Attributes: + # Logging + verbose: Enable detailed logging of autotuning progress (default: False) + + # Performance Requirements + performance_threshold: Minimum speedup ratio to accept a scheme. + 1.0 = no improvement required, 1.02 = 2% improvement (default: 1.02) + + # Quantization Parameters + default_q_scale: Default scale parameter for Q/DQ nodes. Controls quantization + granularity. Typical range: 0.01-0.1 (default: 0.1) + default_q_zero_point: Default zero-point for Q/DQ nodes. Use 0 for signed int8, + 128 for unsigned uint8 (default: 0) + default_quant_type: Quantization type for Q/DQ nodes. Options: "int8" (default), "fp8" + + # Region Builder Settings + maximum_sequence_region_size: Maximum number of nodes in a sequence region during + top-down refinement. Prevents overly large merged regions (default: 10) + minimum_topdown_search_size: Minimum number of nodes in a region to trigger + top-down search during region building (default: 10) + + # Scheme Generation Settings + top_percent_to_mutate: Top percentage of best schemes to use as mutation seeds + during scheme generation. Range: 0.0-1.0 (default: 0.1 = top 10%) + minimum_schemes_to_mutate: Minimum number of schemes to keep as mutation seeds, + even if top_percent_to_mutate results in fewer (default: 10) + maximum_mutations: Maximum number of mutations to apply to a single scheme + during generation (default: 3) + maximum_generation_attempts: Maximum attempts to generate a unique new scheme + before giving up (default: 100) + + # Pattern Cache Settings + pattern_cache_minimum_distance: Minimum edit distance required between schemes in cache. + When adding schemes, if a scheme is too similar (distance < minimum_distance) + to an existing scheme, only the better-performing one is kept (default: 4) + pattern_cache_max_entries_per_pattern: Maximum number of schemes to keep per pattern + in pattern cache. Only the top N best-performing schemes are kept for each pattern. 
+ Use 0 to keep all schemes (default: 32) + """ + + # Logging + verbose: bool = False + + # Performance Requirements + performance_threshold: float = 1.02 + + # Quantization Parameters + default_q_scale: float = 0.1 + default_q_zero_point: int = 0 + default_quant_type: str = "int8" + default_dq_dtype: str = "float32" + + # Region Builder Settings + maximum_sequence_region_size: int = 10 + minimum_topdown_search_size: int = 10 + + # Scheme Generation Settings + top_percent_to_mutate: float = 0.1 + minimum_schemes_to_mutate: int = 10 + maximum_mutations: int = 3 + maximum_generation_attempts: int = 100 + + # Pattern Cache Settings + pattern_cache_minimum_distance: int = 4 + pattern_cache_max_entries_per_pattern: int = 32 diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py index f6a920bec..b64cb23b1 100644 --- a/tests/unit/onnx/quantization/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -159,8 +159,8 @@ def test_region_discovery(self, simple_conv_model): # Regions should be valid for region in autotuner.regions: - assert region.get_id() is not None - assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] + assert region.id is not None + assert region.type in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] def test_export_baseline_model(self, simple_conv_model): """Test exporting baseline model without Q/DQ.""" @@ -207,39 +207,33 @@ def test_generate_scheme(self, simple_conv_model): if len(autotuner.regions) == 0: pytest.skip("No regions discovered") - autotuner.submit(10.0) # baseline + autotuner.submit(10.0) region = autotuner.regions[0] autotuner.set_profile_region(region) - # Generate multiple schemes and submit a latency for each - num_generated = 0 - while True: - scheme_idx = autotuner.generate() - if scheme_idx < 0: - break - assert isinstance(scheme_idx, int) - autotuner.submit(10.0 + num_generated * 
0.1) # dummy latency - num_generated += 1 - if num_generated >= 5: # cap iterations - break - - assert num_generated > 0, "Expected at least one scheme to be generated" - autotuner.set_profile_region(None, commit=True) - - # Export with Q/DQ and verify Q/DQ nodes are in the model with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: output_path = f.name - try: - autotuner.export_onnx(output_path, insert_qdq=True) - exported = onnx.load(output_path) - node_ops = [n.op_type for n in exported.graph.node] - assert "QuantizeLinear" in node_ops, "Expected QuantizeLinear nodes in exported model" - assert "DequantizeLinear" in node_ops, ( - "Expected DequantizeLinear nodes in exported model" + + has_q = False + has_dq = False + for _ in range(5): + scheme_idx = autotuner.generate() + assert isinstance(scheme_idx, int) + autotuner.submit(10.0 + _ * 0.1) + + autotuner.export_onnx(output_path, insert_qdq=True) + exported = onnx.load(output_path) + node_ops = [n.op_type for n in exported.graph.node] + for node_op in node_ops: + if node_op == "QuantizeLinear": + has_q = True + if node_op == "DequantizeLinear": + has_dq = True + if has_q and has_dq: + break + assert has_q and has_dq, ( + "Expected QuantizeLinear and DequantizeLinear nodes in exported model" ) - finally: - if os.path.exists(output_path): - os.unlink(output_path) def test_submit_latency(self, simple_conv_model): """Test submitting performance measurement.""" @@ -287,12 +281,8 @@ def test_regions_prioritization(self, simple_conv_model): autotuner.initialize(config) # Check that LEAF regions come before non-LEAF - leaf_indices = [ - i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF - ] - non_leaf_indices = [ - i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF - ] + leaf_indices = [i for i, r in enumerate(autotuner.regions) if r.type == RegionType.LEAF] + non_leaf_indices = [i for i, r in enumerate(autotuner.regions) if r.type != RegionType.LEAF] 
if leaf_indices and non_leaf_indices: # All LEAF should come before non-LEAF From d29abac37f7a181a727215544e55f1c54bac7a77 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 23 Feb 2026 02:41:56 +0000 Subject: [PATCH 08/14] move models to utils Signed-off-by: Will Guo --- .../unit/onnx/quantization/autotune/models.py | 47 ------------------- .../quantization/autotune/test_autotuner.py | 6 +-- 2 files changed, 3 insertions(+), 50 deletions(-) delete mode 100644 tests/unit/onnx/quantization/autotune/models.py diff --git a/tests/unit/onnx/quantization/autotune/models.py b/tests/unit/onnx/quantization/autotune/models.py deleted file mode 100644 index 4090cfef3..000000000 --- a/tests/unit/onnx/quantization/autotune/models.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Shared test ONNX models for autotuner unit tests. - -Model creation functions live here; tests import and call them directly. 
-""" - -import onnx -from onnx import helper - - -def _create_simple_conv_onnx_model(): - """Build ONNX model: Input -> Conv -> Relu -> Output (minimal for autotuner tests).""" - input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) - output_tensor = helper.make_tensor_value_info( - "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] - ) - conv_node = helper.make_node( - "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" - ) - relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") - graph = helper.make_graph( - [conv_node, relu_node], - "simple_conv", - [input_tensor], - [output_tensor], - initializer=[ - helper.make_tensor( - "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) - ) - ], - ) - return helper.make_model(graph, producer_name="test") diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py index b64cb23b1..26e390a23 100644 --- a/tests/unit/onnx/quantization/autotune/test_autotuner.py +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -23,10 +23,10 @@ import os import tempfile -import models as _test_models import onnx import onnx_graphsurgeon as gs import pytest +from _test_utils.onnx.quantization.autotune.models import _create_simple_conv_onnx_model from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType @@ -34,8 +34,8 @@ @pytest.fixture def simple_conv_model(): - """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py.""" - return _test_models._create_simple_conv_onnx_model() + """Simple ONNX model: Input -> Conv -> Relu -> Output. 
Created via _test_utils models.""" + return _create_simple_conv_onnx_model() def _create_test_config(): From f656ceef84b16011a5adb2196d53fe8986efbbfc Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 24 Feb 2026 07:28:47 +0000 Subject: [PATCH 09/14] fix conv weight discard bug Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/autotuner.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index f4f1adc63..b864dc185 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -413,6 +413,22 @@ def _exclude_overlapping_insertion_points( full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) assert full_insertion_scheme is not None all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + for ip in all_region_ips: + node = self.graph.nodes[ip.node_index] + # Conv/ConvTranspose inputs and weights must be excluded together + if ( + node.op in ["Conv", "ConvTranspose"] + and ip.input_index == 0 + and len(node.inputs) >= 2 + ): + resolved_insertion_points.discard(ip) + resolved_insertion_points.discard( + ResolvedInsertionPoint( + tensor_name=node.inputs[1].name, + node_index=ip.node_index, + input_index=1, + ) + ) assert isinstance(all_region_ips, set) resolved_insertion_points.difference_update(all_region_ips) if all_region_ips: From 9320792e00a066bab59d5f87744a9945e61439f9 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 24 Feb 2026 14:32:46 +0000 Subject: [PATCH 10/14] quantize gemm/matmul weights together with input Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/autotuner.py | 9 +++------ .../onnx/quantization/autotune/insertion_points.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py 
index b864dc185..e1e71e036 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -28,6 +28,7 @@ import yaml from modelopt.onnx.logging_config import logger +from modelopt.onnx.op_types import is_linear_op from modelopt.onnx.quantization.autotune.common import ( AutotunerNotInitializedError, Config, @@ -415,12 +416,8 @@ def _exclude_overlapping_insertion_points( all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) for ip in all_region_ips: node = self.graph.nodes[ip.node_index] - # Conv/ConvTranspose inputs and weights must be excluded together - if ( - node.op in ["Conv", "ConvTranspose"] - and ip.input_index == 0 - and len(node.inputs) >= 2 - ): + # Conv/ConvTranspose/Gemm/MatMul inputs and weights must be excluded together + if is_linear_op(node.op) and ip.input_index == 0 and len(node.inputs) >= 2: resolved_insertion_points.discard(ip) resolved_insertion_points.discard( ResolvedInsertionPoint( diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py index dd01848dd..393071a65 100644 --- a/modelopt/onnx/quantization/autotune/insertion_points.py +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -40,6 +40,7 @@ get_set_ops, get_value_check_ops, is_fusible_reduction_op, + is_linear_op, ) from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices @@ -127,13 +128,15 @@ def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoi assert self.input_index < len(node.inputs), "Input index out of range" resolved_ips = set() - # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + # Determine which input indices to resolve (include weights for Conv/ConvTranspose/Gemm/MatMul) input_indices = [self.input_index] - if node.op in ["Conv", "ConvTranspose"]: + if is_linear_op(node.op): assert self.input_index == 0, ( - 
"Conv/ConvTranspose inputs and weights must be quantized together" + "Conv/ConvTranspose/Gemm/MatMul inputs and weights must be quantized together" + ) + assert len(node.inputs) >= 2, ( + "Conv/ConvTranspose/Gemm/MatMul should have at least 2 inputs" ) - assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" input_indices.append(1) for idx in input_indices: @@ -345,7 +348,7 @@ def skip_invalid_insertion_points( for input_idx, inp in enumerate(node.inputs): if hasattr(inp, "name") and inp.name == tensor_name: # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time - if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + if is_linear_op(node.op) and input_idx >= 1: return True # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax if node.op in ["Relu", "Softmax"]: From 9a94dc42b66c42113069144baf188c7fa149c9be Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 25 Feb 2026 09:15:59 +0000 Subject: [PATCH 11/14] resolve copilot comments Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index e1e71e036..89721cc5a 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -16,6 +16,7 @@ """Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" import copy +import dataclasses import functools import os import random @@ -217,6 +218,8 @@ def _seed_from_cache(self, pattern: RegionPattern) -> tuple[PatternSchemes | Non scheme_copy = copy.deepcopy(cached_scheme) scheme_copy.latency_ms = float("inf") scheme_copy.error = False + if hasattr(scheme_copy, "profile_timestamp"): + scheme_copy.profile_timestamp = None pattern_schemes.schemes.append(scheme_copy) num_seeded += 1 logger.debug(f"Seeded 
{num_seeded} scheme(s) from pattern cache") @@ -458,8 +461,6 @@ def export_onnx( AutotunerNotInitializedError: If initialize() hasn't been called """ output_desc = output_path if output_path is not None else "" - original_quant_type = self.config.default_quant_type - needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" resolved_insertion_points = set() logger.debug( @@ -467,10 +468,6 @@ def export_onnx( f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})" ) - if needs_fp8_conversion: - logger.debug("FP8 conversion: creating INT8 model first") - self.config.default_quant_type = "int8" - if insert_qdq: matched_regions = 0 @@ -501,20 +498,28 @@ def export_onnx( logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph") - if insert_qdq and resolved_insertion_points: - self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) + original_quant_type = self.config.default_quant_type - logger.debug("Serializing to ONNX format") - model = gs.export_onnx(graph_copy) + try: + needs_fp8_conversion = original_quant_type == "fp8" + if insert_qdq and resolved_insertion_points: + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + self.config.default_quant_type = "int8" + self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) - if insert_qdq and resolved_insertion_points: - self._fix_zero_point_initializers(model) + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) - if needs_fp8_conversion: - logger.debug("Converting INT8 to FP8") - model = int8_to_fp8(model) + if insert_qdq and resolved_insertion_points: + self._fix_zero_point_initializers(model) + + if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + finally: + self.config.default_quant_type = original_quant_type - self.config.default_quant_type = original_quant_type model_bytes = model.SerializeToString() quant_type_str = "baseline" output_dest = "" 
@@ -631,12 +636,7 @@ def save_state(self, output_path: str) -> None: state = { "baseline_latency_ms": self.baseline_latency_ms, "current_profile_pattern_schemes_signature": current_pattern_sig, - "config": { - "default_q_scale": self.config.default_q_scale, - "default_q_zero_point": self.config.default_q_zero_point, - "default_quant_type": self.config.default_quant_type, - "verbose": self.config.verbose, - }, + "config": dataclasses.asdict(self.config), "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], } @@ -687,15 +687,11 @@ def load_state(self, input_path: str) -> None: if "config" in state: config_data = state["config"] - if "default_q_scale" in config_data: - self.config.default_q_scale = config_data["default_q_scale"] - if "default_q_zero_point" in config_data: - self.config.default_q_zero_point = config_data["default_q_zero_point"] - if "default_quant_type" in config_data: - self.config.default_quant_type = config_data["default_quant_type"] - if "verbose" in config_data: - self.config.verbose = config_data["verbose"] - logger.debug(f"Config merged: quant_type={self.config.default_quant_type}") + if isinstance(config_data, dict): + default_dict = dataclasses.asdict(Config()) + default_dict.update({k: v for k, v in config_data.items() if k in default_dict}) + self.config = Config(**default_dict) + logger.debug(f"Config restored: quant_type={self.config.default_quant_type}") if "patterns" in state: num_loaded_patterns = 0 @@ -1063,6 +1059,7 @@ def _create_qdq_nodes( output_dtype: np.dtype, quant_dtype: np.dtype, q_scale: float, + q_zero_point: int, ) -> tuple[gs.Node, gs.Node]: """Create QuantizeLinear and DequantizeLinear node pair. 
@@ -1074,6 +1071,7 @@ def _create_qdq_nodes( quant_dtype: Dtype for quantized values quant_type: Quantization type string q_scale: Quantization scale + q_zero_point: Quantization zero point Returns: Tuple of (q_node, dq_node) @@ -1093,7 +1091,7 @@ def _create_qdq_nodes( ) q_scale_values = np.array([q_scale], dtype=scale_dtype) - q_zp_values = np.array([0], dtype=quant_dtype) + q_zp_values = np.array([q_zero_point], dtype=quant_dtype) q_inputs = [ qdq_input, gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), @@ -1109,7 +1107,7 @@ def _create_qdq_nodes( ) dq_scale_values = np.array([q_scale], dtype=scale_dtype) - dq_zp_values = np.array([0], dtype=quant_dtype) + dq_zp_values = np.array([q_zero_point], dtype=quant_dtype) dq_inputs = [ q_node.outputs[0], gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), @@ -1141,10 +1139,13 @@ def _insert_qdq_at_tensors( resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ """ q_scale = self.config.default_q_scale + q_zero_point = self.config.default_q_zero_point quant_type = self.config.default_quant_type quant_dtype = self._resolve_dtype(quant_type, np.int8) - logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") + logger.debug( + f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point={q_zero_point}" + ) resolved_insertion_points = merge_resolved_insertion_points( graph, resolved_insertion_points @@ -1192,6 +1193,7 @@ def _insert_qdq_at_tensors( output_dtype, quant_dtype, q_scale, + q_zero_point, ) graph.nodes.extend([q_node, dq_node]) From c99edab3a0b3f46af4f7267eecad4567f04b5947 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Thu, 26 Feb 2026 02:27:20 +0000 Subject: [PATCH 12/14] fix copilot comments Signed-off-by: Will Guo --- .../onnx/quantization/autotune/autotuner.py | 1204 +---------------- .../quantization/autotune/autotuner_base.py | 966 +++++++++++++ .../quantization/autotune/export_utils.py | 360 +++++ 3 files 
changed, 1328 insertions(+), 1202 deletions(-) create mode 100644 modelopt/onnx/quantization/autotune/autotuner_base.py create mode 100644 modelopt/onnx/quantization/autotune/export_utils.py diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py index 89721cc5a..69038c59a 100644 --- a/modelopt/onnx/quantization/autotune/autotuner.py +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -15,1212 +15,12 @@ """Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" -import copy -import dataclasses -import functools -import os -import random from collections import Counter, deque -from datetime import datetime, timezone - -import numpy as np -import onnx -import onnx_graphsurgeon as gs -import yaml from modelopt.onnx.logging_config import logger -from modelopt.onnx.op_types import is_linear_op -from modelopt.onnx.quantization.autotune.common import ( - AutotunerNotInitializedError, - Config, - InsertionScheme, - InvalidSchemeError, - PatternCache, - PatternSchemes, - Region, - RegionType, -) -from modelopt.onnx.quantization.autotune.insertion_points import ( - ResolvedInsertionPoint, - merge_resolved_insertion_points, -) -from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.autotune.autotuner_base import QDQAutotunerBase +from modelopt.onnx.quantization.autotune.common import Config, PatternCache, Region, RegionType from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch -from modelopt.onnx.quantization.fp8 import int8_to_fp8 -from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices - -_MUTATION_SPECS = [ - ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)), - ( - "child_region_inputs", - "region composite points", - lambda p: (p.region_index, p.input_index), - ), - ( - "region_outputs", - "region output points", - lambda p: (p.region_index, 
p.node_index, p.output_index), - ), -] - - -def _requires_init(method): - """Decorator that raises AutotunerNotInitializedError if initialize() has not been called.""" - - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - if not self.initialized: - raise AutotunerNotInitializedError( - "QDQAutotunerBase not initialized. Call initialize() first." - ) - return method(self, *args, **kwargs) - - return wrapper - - -class QDQAutotunerBase: - """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" - - _DTYPE_MAP = { - "int8": np.int8, - "uint8": np.uint8, - "float16": np.float16, - "float32": np.float32, - } - - def __init__(self, model: onnx.ModelProto | gs.Graph): - """Initialize the autotuner with an ONNX model. - - Creates a clean copy of the model graph and initializes internal state. - After construction, call initialize() to configure the autotuner, then - use a subclass strategy to populate regions (e.g., QDQAutotuner does this - automatically during initialize()). - - Args: - model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize. - A clean copy is created internally, leaving the original unchanged. 
- - Raises: - TypeError: If model is neither onnx.ModelProto nor gs.Graph - """ - if isinstance(model, onnx.ModelProto): - self.onnx_model = model - elif isinstance(model, gs.Graph): - self.onnx_model = gs.export_onnx(model) - else: - raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}") - - self.graph = self._copy_graph() - self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph) - self.regions: list[Region] = [] - self.current_profile_region: Region | None = None - self.profiled_patterns: list[PatternSchemes] = [] - self.current_profile_pattern_schemes: PatternSchemes | None = None - self.current_insertion_scheme_index: int | None = None - self.config = Config() - self.initialized = False - self.baseline_latency_ms: float | None = None - self.pattern_cache: PatternCache | None = None - - logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") - - requires_init = _requires_init - - def initialize( - self, config: Config | None = None, pattern_cache: PatternCache | None = None - ) -> None: - """Initialize autotuning session with configuration and pattern cache. - - Prepares the autotuner for profiling by setting configuration parameters - and optionally loading pattern cache data. This base method resets all profiling - state and sets up the pattern cache storage. - - Args: - config: Autotuning configuration parameters. If None, uses default Config(). - Controls Q/DQ parameters, performance thresholds, and scheme generation. - pattern_cache: Optional PatternCache object for seeding with known-good schemes. - If None, creates a new empty pattern cache for tracking best schemes. - If provided, uses existing schemes to warm-start optimization. 
- - Raises: - None (safe to call multiple times - will reset state each time) - """ - if config is not None: - self.config = config - - if pattern_cache is None: - pattern_cache = PatternCache( - minimum_distance=self.config.pattern_cache_minimum_distance, - max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern, - ) - self.pattern_cache = pattern_cache - - logger.debug( - f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and " - f"{pattern_cache.total_schemes} schemes" - ) - - self.initialized = False - self.baseline_latency_ms = None - self.profiled_patterns.clear() - self.regions.clear() - self.current_profile_region = None - self.current_profile_pattern_schemes = None - self.current_insertion_scheme_index = None - - logger.info("Initializing autotuner") - logger.debug( - f"Configuration: q_scale={self.config.default_q_scale}, " - f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}" - ) - - self.initialized = True - - def _commit_current_pattern(self, save: bool = True) -> None: - """Save current pattern schemes to profiled_patterns (if save) and clear current state.""" - if save and self.current_profile_pattern_schemes is not None: - num_schemes = len(self.current_profile_pattern_schemes.schemes) - best_scheme = self.current_profile_pattern_schemes.best_scheme - best_latency = best_scheme.latency_ms if best_scheme else float("inf") - - samples_before_best, time_to_best = self._compute_convergence_metrics( - self.current_profile_pattern_schemes.schemes, best_scheme - ) - - logger.info( - f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" - ) - logger.debug( - f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" - ) - if samples_before_best is not None: - logger.debug(f"Convergence: best found at sample {samples_before_best}") - if time_to_best is not None: - logger.debug(f"Time to best: {time_to_best:.2f}s") - 
            self.profiled_patterns.append(self.current_profile_pattern_schemes)

        # Always clear the in-flight state, even when nothing was saved.
        self.current_profile_region = None
        self.current_profile_pattern_schemes = None
        self.current_insertion_scheme_index = None

    def _seed_from_cache(self, pattern: RegionPattern) -> tuple[PatternSchemes | None, int]:
        """Seed PatternSchemes from pattern cache for the given pattern. Returns (schemes, num_seeded)."""
        if self.pattern_cache is None:
            return None, 0
        cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature)
        if cache_schemes is None or len(cache_schemes.schemes) == 0:
            logger.debug("No pattern cache entries for this region")
            return None, 0
        pattern_schemes = PatternSchemes()
        pattern_schemes.pattern = pattern
        num_seeded = 0
        for cached_scheme in cache_schemes.schemes:
            # Deep-copy so re-profiling never mutates the cached entry; reset
            # measurement fields so the seeded scheme is treated as unprofiled.
            scheme_copy = copy.deepcopy(cached_scheme)
            scheme_copy.latency_ms = float("inf")
            scheme_copy.error = False
            if hasattr(scheme_copy, "profile_timestamp"):
                scheme_copy.profile_timestamp = None
            pattern_schemes.schemes.append(scheme_copy)
            num_seeded += 1
        logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache")
        return pattern_schemes, num_seeded

    @_requires_init
    def set_profile_region(self, region: Region | None, commit: bool = True) -> None:
        """Set the target region for profiling and scheme generation.

        This method manages the profiling workflow:
        1. If commit=True: Saves current schemes to profiled_patterns
        2. Creates a RegionPattern from the new region's structure
        3. For pattern-based: tries to seed schemes from pattern cache if available
        4. Sets as current for generate() and submit() calls

        Pass region=None to clear the current profile target without setting a new one.

        Args:
            region: The region to profile next (None to clear current target)
            commit: If True, commit current schemes to profiled_patterns
                before switching. Set to False during initialization.
- - Raises: - AutotunerNotInitializedError: If initialize() hasn't been called - """ - if commit or region is None: - self._commit_current_pattern(save=commit) - if region is None: - return - - if region not in self.regions: - raise ValueError(f"Region {region.id} not found in regions") - - region_pattern = RegionPattern.from_region(region, self.graph) - - if self._is_region_profiled(region): - logger.info(f"Skipping region {region.id} (pattern already profiled)") - logger.debug(f"Pattern signature: {region_pattern.signature}") - return - - pattern_schemes, num_seeded = self._seed_from_cache(region_pattern) - if pattern_schemes is None: - pattern_schemes = PatternSchemes() - pattern_schemes.pattern = region_pattern - logger.debug("Initialized with empty scheme collection") - - self.current_profile_region = region - self.current_profile_pattern_schemes = pattern_schemes - - mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" - logger.info( - f"Profiling region {region.id} [level {region.level}, size" - f"{region.get_size_of_region_and_descendants()}, {mode_info}]" - ) - logger.debug(f"Pattern signature: {region_pattern.signature}") - - @_requires_init - def generate(self) -> int: - """Generate a new Q/DQ insertion scheme for the current pattern or region. - - Creates a new InsertionScheme by mutating the top-performing schemes: - 1. Checks if there are any cached schemes (error=False, latency_ms=inf) - 2. If cached schemes exist, picks one to re-profile - 3. Otherwise, generates a new scheme by mutation - 4. Selects a random scheme from the top 10 performers - 5. Mutates it by adding/removing insertion points - 6. Ensures the new scheme is unique (different from existing schemes) - 7. Adds the scheme to current_profile_pattern_schemes - - """ - if self.current_profile_pattern_schemes is None: - raise InvalidSchemeError("No region selected. 
Call set_profile_region() first.") - - pattern_schemes = self.current_profile_pattern_schemes - cached_schemes = [ - (idx, scheme) - for idx, scheme in enumerate(pattern_schemes.schemes) - if not scheme.is_profiled - ] - - if cached_schemes: - scheme_index, cached_scheme_data = cached_schemes[0] - num_node_points = len(cached_scheme_data.node_inputs) - num_region_composite_points = len(cached_scheme_data.child_region_inputs) - num_region_output_points = len(cached_scheme_data.region_outputs) - total_points = num_node_points + num_region_composite_points + num_region_output_points - - logger.info( - f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)" - ) - logger.debug( - f"Cached scheme breakdown: {num_node_points} node input, " - f"{num_region_composite_points} region composite, " - f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)" - ) - - self.current_insertion_scheme_index = scheme_index - return self.current_insertion_scheme_index - - known_schemes = {scheme.hash for scheme in pattern_schemes.schemes} - max_attempts = getattr(self.config, "maximum_generation_attempts", 100) - - logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)") - - for attempts in range(max_attempts): - new_scheme = self._generate_next_insertion_sample() - if new_scheme.hash not in known_schemes and not new_scheme.error: - pattern_schemes.schemes.append(new_scheme) - scheme_index = len(pattern_schemes.schemes) - 1 - num_node_points = len(new_scheme.node_inputs) - num_region_composite_points = len(new_scheme.child_region_inputs) - num_region_output_points = len(new_scheme.region_outputs) - total_points = ( - num_node_points + num_region_composite_points + num_region_output_points - ) - - logger.info( - f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)" - ) - logger.debug( - f"Scheme breakdown: {num_node_points} node input, " - 
f"{num_region_composite_points} region composite, " - f"{num_region_output_points} region output points " - f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})" - ) - - self.current_insertion_scheme_index = scheme_index - return self.current_insertion_scheme_index - - logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") - return -1 - - def _resolve_scheme_for_region( - self, region: Region, best: bool - ) -> tuple[InsertionScheme | None, RegionPattern]: - """Resolve the insertion scheme to use for a region from profiled/current/cache. - - Args: - region: The region to resolve the scheme for - best: If True, return the best scheme for the region - - Returns: - tuple[InsertionScheme | None, RegionPattern]: The scheme and pattern for the region - """ - pattern = RegionPattern.from_region(region, self.graph) - logger.debug(f"Region {region.id} (level {region.level})") - logger.debug(f" → Pattern signature: {pattern.signature}") - - matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) - current_scheme = matched.best_scheme if matched else None - - if matched: - if current_scheme: - logger.debug( - f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" - ) - else: - logger.debug(" → Matched profiled pattern but no valid schemes") - - if current_scheme is None: - pattern_schemes = self.current_profile_pattern_schemes - if pattern_schemes is None or pattern != pattern_schemes.pattern: - pass - elif best: - current_scheme = pattern_schemes.best_scheme - else: - scheme_index = self.current_insertion_scheme_index - if scheme_index is not None: - assert scheme_index < len(pattern_schemes.schemes), ( - f"Invalid scheme index: {scheme_index}" - ) - current_scheme = pattern_schemes.schemes[scheme_index] - logger.debug(f" → Using current pattern scheme #{scheme_index}") - - if current_scheme is None and self.pattern_cache is not None: - cache_schemes = 
self.pattern_cache.get_pattern_schemes(pattern.signature) - if cache_schemes is not None: - schemes = cache_schemes.schemes - if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: - current_scheme = schemes[0] - logger.debug(" → Using imported pattern from cache") - - if current_scheme is None: - logger.debug(" → No scheme available, skipping") - - return current_scheme, pattern - - def _exclude_overlapping_insertion_points( - self, - resolved_insertion_points: set[ResolvedInsertionPoint], - region: Region, - pattern: RegionPattern, - ) -> None: - """Remove this region's full insertion points from resolved set so they can be replaced.""" - full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) - assert full_insertion_scheme is not None - all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) - for ip in all_region_ips: - node = self.graph.nodes[ip.node_index] - # Conv/ConvTranspose/Gemm/MatMul inputs and weights must be excluded together - if is_linear_op(node.op) and ip.input_index == 0 and len(node.inputs) >= 2: - resolved_insertion_points.discard(ip) - resolved_insertion_points.discard( - ResolvedInsertionPoint( - tensor_name=node.inputs[1].name, - node_index=ip.node_index, - input_index=1, - ) - ) - assert isinstance(all_region_ips, set) - resolved_insertion_points.difference_update(all_region_ips) - if all_region_ips: - logger.debug(f" → Excluded {len(all_region_ips)} overlapping insertion points") - - @_requires_init - def export_onnx( - self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False - ) -> bytes: - """Export ONNX model with Q/DQ nodes inserted according to tested schemes. - - This method creates a modified version of the model by: - 1. For each region, finding the matching pattern - 2. Applying the best scheme for profiled patterns - 3. Applying the current scheme for the active profile pattern - 4. 
        Resolving pattern-relative insertion points to actual tensor names
        5. Inserting Q/DQ pairs at the resolved locations
        6. Converting to FP8 if needed (always creates INT8 first, then converts)

        Args:
            output_path: Optional file path where the modified ONNX model will be saved.
                If None, the model is not saved to disk and only bytes are returned.
            insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model
                (useful for baseline measurements)

        Returns:
            bytes: Serialized ONNX model as bytes

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
        """
        output_desc = output_path if output_path is not None else ""
        resolved_insertion_points = set()

        logger.debug(
            f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, "
            f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})"
        )

        if insert_qdq:
            matched_regions = 0

            logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions")

            for region in self.regions:
                current_scheme, pattern = self._resolve_scheme_for_region(region, best)
                if current_scheme is None:
                    continue

                # Drop this region's earlier contributions before adding the
                # scheme selected now, so schemes replace rather than accumulate.
                self._exclude_overlapping_insertion_points(
                    resolved_insertion_points, region, pattern
                )

                new_ips = pattern.matches(region, self.graph, current_scheme)
                if new_ips:
                    resolved_insertion_points.update(new_ips)
                    matched_regions += 1
                    logger.debug(f" → Added {len(new_ips)} insertion points")

            logger.debug(
                f"Matched {matched_regions}/{len(self.regions)} regions, "
                f"total {len(resolved_insertion_points)} unique insertion points"
            )

        graph_copy = self._copy_graph()
        unique_tensors = len(resolved_insertion_points)

        logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph")

        original_quant_type = self.config.default_quant_type

        try:
            # FP8 export path: build an INT8 model first, then convert; the
            # config quant type is temporarily overridden and restored in finally.
            needs_fp8_conversion = original_quant_type == "fp8"
            if insert_qdq and resolved_insertion_points:
                if needs_fp8_conversion:
                    logger.debug("FP8 conversion: creating INT8 model first")
                    self.config.default_quant_type = "int8"
                self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points)

            logger.debug("Serializing to ONNX format")
            model = gs.export_onnx(graph_copy)

            if insert_qdq and resolved_insertion_points:
                self._fix_zero_point_initializers(model)

                if needs_fp8_conversion:
                    logger.debug("Converting INT8 to FP8")
                    model = int8_to_fp8(model)
        finally:
            self.config.default_quant_type = original_quant_type

        model_bytes = model.SerializeToString()
        quant_type_str = "baseline"
        output_dest = ""

        if insert_qdq:
            quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8"

        if output_path is not None:
            onnx.save(model, output_path)
            output_dest = f" → {output_path}"

        logger.info(
            f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs {output_dest}"
        )
        return model_bytes

    @_requires_init
    def submit(self, latency_ms: float, success: bool = True) -> None:
        """Submit performance measurement for the most recently generated scheme.

        This method records the measured latency and manages the optimization state:

        Args:
            latency_ms: Measured latency in milliseconds (must be > 0)
            success: Whether the measurement succeeded. If False, sets scheme.error=True,
                logs a warning, and skips speedup calculation.

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
            InvalidSchemeError: If no pattern or region is set, or no schemes have been generated
        """
        # The first submission is interpreted as the unquantized baseline.
        if self.baseline_latency_ms is None:
            self.baseline_latency_ms = latency_ms
            logger.info(f"Baseline latency: {latency_ms:.3f} ms")
            return

        if self.current_profile_pattern_schemes is None:
            raise InvalidSchemeError(
                "No pattern or region selected. Call set_profile_region() first."
            )

        schemes_collection = self.current_profile_pattern_schemes
        if not schemes_collection.schemes:
            raise InvalidSchemeError("No schemes available. Call generate() first.")

        pattern_schemes = schemes_collection

        if self.current_insertion_scheme_index is not None:
            scheme_index = self.current_insertion_scheme_index
            if scheme_index >= len(pattern_schemes.schemes):
                raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}")
            scheme = pattern_schemes.schemes[scheme_index]
        else:
            # No explicit index: attribute the measurement to the newest scheme.
            scheme = pattern_schemes.schemes[-1]
            scheme_index = len(pattern_schemes.schemes) - 1

        scheme.latency_ms = latency_ms
        scheme.error = not success
        scheme.profile_timestamp = datetime.now(timezone.utc).isoformat()
        display_index = scheme_index + 1

        if not success:
            logger.warning(
                f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)"
            )
            logger.debug("Marking scheme with error flag")
            return

        speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0

        logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)")
        logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms")

        # Keep schemes sorted best-first; unmeasured (<=0) sort last.
        old_best = (
            pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf")
        )
        pattern_schemes.schemes.sort(
            key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf")
        )
        new_best = (
            pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf")
        )

        if new_best < old_best:
            new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0
            logger.info(f" ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)")
            logger.debug(f"Previous best: {old_best:.3f} ms")

        if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None:
            self.pattern_cache.add_pattern_schemes(pattern_schemes)
            logger.debug(
                f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, "
                f"{self.pattern_cache.total_schemes} schemes"
            )

    def save_state(self, output_path: str) -> None:
        """Save complete autotuner state to a YAML file for later reuse.
- - Serializes all optimization results including: - - Baseline latency measurement - - All profiled patterns with their signatures - - All generated schemes with insertion points and latencies - - Configuration parameters - - Current profiling state - - Args: - output_path: File path where the YAML state file will be written. - Pattern cache will be saved to _pattern_cache.yaml - """ - current_pattern_sig = None - if self.current_profile_pattern_schemes is not None: - current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature - - state = { - "baseline_latency_ms": self.baseline_latency_ms, - "current_profile_pattern_schemes_signature": current_pattern_sig, - "config": dataclasses.asdict(self.config), - "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], - } - - with open(output_path, "w") as f: - yaml.dump(state, f, default_flow_style=False, sort_keys=False) - - num_patterns = len(self.profiled_patterns) - total_schemes = sum(len(p.schemes) for p in self.profiled_patterns) - - logger.info( - f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)" - ) - logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms") - - if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0: - base_path, ext = os.path.splitext(output_path) - cache_path = f"{base_path}_pattern_cache{ext}" - self.pattern_cache.save(cache_path) - - logger.info(f"Saved pattern cache → {cache_path}") - logger.debug( - f"Cache: {self.pattern_cache.num_patterns} patterns, " - f"{self.pattern_cache.total_schemes} schemes" - ) - - @_requires_init - def load_state(self, input_path: str) -> None: - """Load autotuner state from a previously saved YAML file. - - Restores optimization results from a previous session: - 1. Matches saved patterns to current model's patterns by signature - 2. Loads all schemes with their insertion points and latencies (including unmeasured ones) - 3. 
Restores baseline latency and configuration - - Args: - input_path: File path to the YAML state file to load - - Raises: - AutotunerNotInitializedError: If initialize() hasn't been called - FileNotFoundError: If the input_path doesn't exist - """ - with open(input_path) as f: - state = yaml.safe_load(f) - - if state.get("baseline_latency_ms") is not None: - self.baseline_latency_ms = state["baseline_latency_ms"] - logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms") - - if "config" in state: - config_data = state["config"] - if isinstance(config_data, dict): - default_dict = dataclasses.asdict(Config()) - default_dict.update({k: v for k, v in config_data.items() if k in default_dict}) - self.config = Config(**default_dict) - logger.debug(f"Config restored: quant_type={self.config.default_quant_type}") - - if "patterns" in state: - num_loaded_patterns = 0 - num_loaded_schemes = 0 - - for pattern_data in state["patterns"]: - try: - pattern_schemes = PatternSchemes.from_dict(pattern_data) - - if pattern_schemes.schemes: - self.profiled_patterns.append(pattern_schemes) - num_loaded_patterns += 1 - num_loaded_schemes += len(pattern_schemes.schemes) - else: - logger.debug( - f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..." 
- ) - - except Exception as e: # noqa: PERF203 - logger.warning(f"Failed to load pattern: {e}") - continue - - logger.info( - f"Loaded state from {input_path} ({num_loaded_patterns} patterns, " - f"{num_loaded_schemes} schemes)" - ) - - base_path, ext = os.path.splitext(input_path) - cache_path = f"{base_path}_pattern_cache{ext}" - - if os.path.exists(cache_path): - try: - loaded_cache = PatternCache.load(cache_path) - - if self.pattern_cache is not None: - for pattern_schemes in loaded_cache.pattern_schemes: - self.pattern_cache.add_pattern_schemes(pattern_schemes) - else: - self.pattern_cache = loaded_cache - logger.info( - f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, " - f"{loaded_cache.total_schemes} schemes)" - ) - except Exception as e: - logger.warning(f"Failed to load pattern cache: {e}") - else: - logger.debug(f"No pattern cache file at {cache_path}") - - @_requires_init - def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: - """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. - - Analyzes the current model's regions against the provided quantized tensors - to extract Q/DQ insertion patterns. For each region, creates a pattern cache - entry that captures which insertion points correspond to the quantized tensors. - These cached patterns can then be used as seeds for future autotuning sessions. 
        Args:
            quantized_tensors: Set or list of tensor names that are quantized
                (i.e., tensors that have Q/DQ nodes applied to them)

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
        """
        if isinstance(quantized_tensors, list):
            quantized_tensors = set(quantized_tensors)

        logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors")
        logger.debug(f"Processing {len(self.regions)} regions")

        if self.pattern_cache is None:
            logger.warning("Pattern cache not initialized, skipping import")
            return

        patterns_before = self.pattern_cache.num_patterns
        schemes_before = self.pattern_cache.total_schemes

        for region in self.regions:
            self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors)

        patterns_added = self.pattern_cache.num_patterns - patterns_before
        schemes_added = self.pattern_cache.total_schemes - schemes_before

        logger.info(
            f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache"
        )
        logger.debug(
            f"Total cache: {self.pattern_cache.num_patterns} patterns, "
            f"{self.pattern_cache.total_schemes} schemes"
        )

    def _compute_convergence_metrics(
        self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None
    ) -> tuple[int | None, float | None]:
        """Compute convergence metrics for a collection of schemes.

        Analyzes when the best scheme was discovered during the profiling process
        by sorting schemes by their profile timestamps and finding the position
        of the best scheme.

        Args:
            schemes: List of insertion schemes with profile timestamps
            best_scheme: The best performing scheme (lowest latency)

        Returns:
            Tuple of (samples_before_best, time_to_best) where:
            - samples_before_best: Number of samples tested before finding best (0-based index)
            - time_to_best: Time in seconds from first sample to best sample
            Both values are None if metrics cannot be computed (e.g., missing timestamps)
        """
        samples_before_best = None
        time_to_best = None

        if not best_scheme or not best_scheme.profile_timestamp:
            return samples_before_best, time_to_best

        schemes_with_time = [s for s in schemes if s.profile_timestamp is not None]

        if not schemes_with_time:
            return samples_before_best, time_to_best

        # ISO-8601 timestamps sort correctly as strings.
        schemes_with_time.sort(key=lambda s: s.profile_timestamp or "")

        try:
            best_position = next(
                i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash
            )
            samples_before_best = best_position

            first_ts = schemes_with_time[0].profile_timestamp
            best_ts = best_scheme.profile_timestamp
            assert first_ts is not None and best_ts is not None
            first_timestamp = datetime.fromisoformat(first_ts)
            best_timestamp = datetime.fromisoformat(best_ts)
            time_to_best = (best_timestamp - first_timestamp).total_seconds()
        except (StopIteration, ValueError):
            # Best scheme not found among timestamped schemes, or bad timestamp.
            pass

        return samples_before_best, time_to_best

    def _is_region_profiled(self, region: Region) -> bool:
        """Check if a region's pattern has already been fully profiled."""
        return any(
            p.pattern is not None
            and p.pattern.matches(region, self.graph)
            and all(s.is_profiled for s in p.schemes)
            for p in self.profiled_patterns
        )

    def _mutate_insertion_points(
        self, base_points, all_points, point_type: str, max_mutations: int
    ) -> list:
        """Mutate a set of insertion points by adding, removing, or both."""
        # base_points holds the key tuples of the base scheme's points;
        # all_points holds the full candidate point objects.
        key_fn = {
            "node input points": lambda p: (p.node_index, p.input_index),
            "region composite points": lambda p: (p.region_index, p.input_index),
            "region output points": lambda p: (p.region_index, p.node_index, p.output_index),
        }.get(point_type)

        if not key_fn:
            return []

        current_points = set(base_points)
        initial_count = len(current_points)
        mutation_type = random.choice(["add", "remove", "both"])

        if mutation_type in ["add", "both"] and len(current_points) < len(all_points):
            all_keys = {key_fn(p) for p in all_points}
            available_keys = all_keys - current_points
            if available_keys:
                max_add = min(max_mutations, len(available_keys))
                num_to_add = random.randint(1, max_add)
                to_add = random.sample(list(available_keys), num_to_add)
                current_points.update(to_add)

        if mutation_type in ["remove", "both"] and current_points:
            max_remove = min(max_mutations, len(current_points))
            num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1
            num_to_remove = min(num_to_remove, len(current_points))
            to_remove = random.sample(list(current_points), num_to_remove)
            for p in to_remove:
                current_points.discard(p)

        logger.debug(
            f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})"
        )

        # Materialize back to point objects, preserving all_points order.
        return [p for p in all_points if key_fn(p) in current_points]

    def _generate_next_insertion_sample(self) -> InsertionScheme:
        """Generate a new insertion scheme by mutating top performers.

        This is the core scheme generation algorithm:
        1. Identifies top schemes by latency
        2. Randomly selects one as the base
        3. Mutates node input insertion points (add, remove, or both)
        4. Mutates region composite insertion points (child boundaries)
        5. Mutates region output insertion points
        6.
        Returns new unique scheme

        **Mutation Strategy:**
        - Node input points: Add/remove 1-3 insertion points
        - Region composite points: Add/remove 1-3 boundary points
        - Region output points: Add/remove 1-3 output points
        - Mutation type chosen randomly: 'add', 'remove', or 'both'

        **Baseline Case:**
        If no schemes exist yet, returns an empty baseline scheme.

        Returns:
            New InsertionScheme with mutated insertion points.
            Returns empty scheme if no region is set or no candidates exist.
        """
        if self.current_profile_region is None:
            return InsertionScheme()

        if self.current_profile_pattern_schemes is not None:
            schemes_collection = self.current_profile_pattern_schemes
        else:
            return InsertionScheme()

        region = self.current_profile_region
        pattern_schemes = schemes_collection

        if not isinstance(schemes_collection, PatternSchemes) or schemes_collection.pattern is None:
            return InsertionScheme()
        pattern = schemes_collection.pattern
        full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph)

        logger.debug(
            f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, "
            f"{len(full_insertion_scheme.child_region_inputs)} region composite, "
            f"{len(full_insertion_scheme.region_outputs)} region output"
        )

        top_percent = getattr(self.config, "top_percent_to_mutate", 0.1)
        minimum_schemes = getattr(self.config, "minimum_schemes_to_mutate", 1)

        # Only successfully measured schemes are mutation candidates.
        measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error]
        measured_schemes.sort(key=lambda s: s.latency_ms)

        num_top_schemes = max(
            int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes))
        )
        top_schemes = measured_schemes[:num_top_schemes]

        if len(top_schemes) == 0:
            logger.debug("No measured schemes yet, generating baseline (empty) scheme")
            return InsertionScheme()

        base_scheme = random.choice(top_schemes)
        total_base_points = (
            len(base_scheme.node_inputs)
            + len(base_scheme.child_region_inputs)
            + len(base_scheme.region_outputs)
        )
        logger.debug(
            f"Mutating from top {len(top_schemes)} schemes: "
            f"selected base with {total_base_points} points (latency={base_scheme.latency_ms:.3f} ms)"
        )

        max_mutations = getattr(self.config, "maximum_mutations", 3)
        scheme = InsertionScheme()

        # NOTE(review): _MUTATION_SPECS (attr name, point-type label, key fn) is
        # defined elsewhere in this module — confirm.
        for attr, point_type, key_fn in _MUTATION_SPECS:
            base_points = {key_fn(p) for p in getattr(base_scheme, attr)}
            setattr(
                scheme,
                attr,
                self._mutate_insertion_points(
                    base_points,
                    getattr(full_insertion_scheme, attr),
                    point_type,
                    max_mutations,
                ),
            )

        return scheme

    def _copy_graph(self) -> gs.Graph:
        """Create an independent copy of the computation graph."""
        new_graph = gs.import_onnx(self.onnx_model)
        new_graph.toposort()
        return new_graph

    def _resolve_dtype(self, dtype_str: str, default: np.dtype = np.int8) -> np.dtype:
        """Resolve a dtype string (quant or DQ output) to a numpy dtype."""
        if dtype_str == "fp8":
            try:
                return np.dtype(np.float8_e4m3fn)
            except (AttributeError, TypeError):
                logger.warning(
                    "FP8 dtype not available (requires numpy >= 2.0), "
                    "using uint8 as placeholder. Note: This may not produce "
                    "correct results without proper FP8 support."
                )
                return np.uint8
        if hasattr(np, "bfloat16") and dtype_str == "bfloat16":
            return np.bfloat16
        # NOTE(review): _DTYPE_MAP is a class-level mapping defined elsewhere — confirm.
        if dtype_str in self._DTYPE_MAP:
            return self._DTYPE_MAP[dtype_str]
        logger.warning(f"Unknown dtype '{dtype_str}', using default {default}")
        return default

    def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]:
        """Build mapping from tensor names to tensor objects."""
        tensor_map = {t.name: t for t in graph.inputs if hasattr(t, "name") and t.name}
        for node in graph.nodes:
            for t in node.inputs:
                if hasattr(t, "name") and t.name:
                    tensor_map[t.name] = t
            # Only Constant outputs are registered here; variable outputs are
            # presumably covered as some downstream node's input — TODO confirm.
            for t in node.outputs:
                if isinstance(t, gs.Constant) and hasattr(t, "name") and t.name:
                    tensor_map[t.name] = t
        return tensor_map

    def _get_tensor_metadata(
        self, tensor: gs.Tensor, is_constant: bool
    ) -> tuple[tuple | None, np.dtype]:
        """Extract shape and dtype metadata from a tensor."""
        default_dtype = self._resolve_dtype(self.config.default_dq_dtype, np.float32)

        if is_constant and hasattr(tensor, "values") and tensor.values is not None:
            return tensor.values.shape, tensor.values.dtype
        elif hasattr(tensor, "shape"):
            dtype = (
                tensor.dtype
                if hasattr(tensor, "dtype") and tensor.dtype is not None
                else default_dtype
            )
            return tensor.shape, dtype
        return None, default_dtype

    def _fix_zero_point_initializers(self, model: onnx.ModelProto) -> None:
        """Fix INT8 zero_point initializers to use int32_data instead of raw_data."""
        fixed_count = 0

        for initializer in model.graph.initializer:
            if (
                "_zp_" in initializer.name
                and initializer.data_type == onnx.TensorProto.INT8
                and len(initializer.raw_data) > 0
                and len(initializer.int32_data) == 0
            ):
                # Rebuild the tensor so values land in int32_data, then copy
                # it over the original initializer in place.
                np_array = onnx.numpy_helper.to_array(initializer)
                int32_values = np_array.astype(np.int32).flatten().tolist()

                new_tensor = onnx.helper.make_tensor(
                    initializer.name,
                    onnx.TensorProto.INT8,
                    list(initializer.dims),
                    int32_values,
                )
                initializer.CopyFrom(new_tensor)
                fixed_count += 1

        if fixed_count > 0:
            logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)")

    def _create_qdq_nodes(
        self,
        tensor_name: str,
        qdq_input: gs.Tensor,
        output_shape: tuple | None,
        output_dtype: np.dtype,
        quant_dtype: np.dtype,
        q_scale: float,
        q_zero_point: int,
    ) -> tuple[gs.Node, gs.Node]:
        """Create QuantizeLinear and DequantizeLinear node pair.

        Args:
            tensor_name: Name of the tensor being quantized
            qdq_input: Input tensor to the Q node
            output_shape: Shape for Q/DQ outputs (may be None)
            output_dtype: Dtype for DQ output (also used for scale dtype)
            quant_dtype: Dtype for quantized values
            q_scale: Quantization scale
            q_zero_point: Quantization zero point

        Returns:
            Tuple of (q_node, dq_node)
        """
        # Create unique names for Q/DQ nodes
        q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_")
        dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_")
        # Determine scale dtype from output_dtype (fp16/tf32/fp32)
        # Scale should match the precision of the original I/O tensor
        dtype_map = {"float16": np.float16, "float32": np.float32}
        if hasattr(np, "bfloat16"):
            dtype_map["bfloat16"] = np.bfloat16
        scale_dtype = dtype_map.get(np.dtype(output_dtype).name, np.float32)

        logger.debug(
            f"Creating Q/DQ pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})"
        )

        q_scale_values = np.array([q_scale], dtype=scale_dtype)
        q_zp_values = np.array([q_zero_point], dtype=quant_dtype)
        q_inputs = [
            qdq_input,
            gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values),
            gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values),
        ]
        q_node = gs.Node(
            op="QuantizeLinear",
            name=q_name,
            inputs=q_inputs,
            outputs=[
                gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape)
            ],
        )

        # DQ gets its own scale/zp constants so each node owns its initializers.
        dq_scale_values = np.array([q_scale], dtype=scale_dtype)
        dq_zp_values = np.array([q_zero_point], dtype=quant_dtype)
        dq_inputs = [
            q_node.outputs[0],
            gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values),
            gs.Constant(f"dq_zp_{tensor_name}", values=dq_zp_values),
        ]
        dq_node = gs.Node(
            op="DequantizeLinear",
            name=dq_name,
            inputs=dq_inputs,
            outputs=[
                gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape)
            ],
        )

        return q_node, dq_node

    def _insert_qdq_at_tensors(
        self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint]
    ) -> None:
        """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.

        This is the main entry point for Q/DQ insertion. It:
        1. Builds tensor map and tensor-to-users map for efficient lookup
        2. Processes each resolved insertion point to insert Q/DQ nodes
        3. Handles two insertion modes based on node_index

        Args:
            graph: Graph to modify in-place
            resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ
        """
        q_scale = self.config.default_q_scale
        q_zero_point = self.config.default_q_zero_point
        quant_type = self.config.default_quant_type
        quant_dtype = self._resolve_dtype(quant_type, np.int8)

        logger.debug(
            f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point={q_zero_point}"
        )

        resolved_insertion_points = merge_resolved_insertion_points(
            graph, resolved_insertion_points
        )

        tensor_map = self._build_tensor_map(graph)
        tensor_users_map = get_tensor_consumer_node_indices(graph)
        logger.debug(
            f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users"
        )

        for insertion_point in resolved_insertion_points:
            tensor_name = insertion_point.tensor_name
            node_index = insertion_point.node_index
            input_index = insertion_point.input_index

            original_tensor = tensor_map[tensor_name]
            if node_index is not None:
                # Per-consumer mode: rewire only one input of one node.
                assert node_index < len(graph.nodes), "Node index out of range"
                target_node = graph.nodes[node_index]
                assert input_index is not None, "Input index must be set when node index is set"
                assert input_index < len(target_node.inputs), (
                    f"Input index out of range for node {target_node.name}"
                )
                original_tensor = target_node.inputs[input_index]
                assert tensor_name == original_tensor.name, (
                    f"Tensor name mismatch for node {target_node.name} input {input_index}"
                )
            else:
                # Tensor mode: rewire every consumer of the tensor.
                assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map"
                assert input_index is None, "Input index must be None when node index is None"

            is_constant = isinstance(original_tensor, gs.Constant)
            output_shape, output_dtype = self._get_tensor_metadata(original_tensor, is_constant)

            # Suffix keeps Q/DQ names unique when the same tensor is quantized
            # separately per consumer input.
            unique_suffix = "qdq"
            if node_index is not None:
                unique_suffix = f"n{node_index}_i{input_index}"
            unique_tensor_name = f"{tensor_name}_{unique_suffix}"

            q_node, dq_node = self._create_qdq_nodes(
                unique_tensor_name,
                original_tensor,
                output_shape,
                output_dtype,
                quant_dtype,
                q_scale,
                q_zero_point,
            )

            graph.nodes.extend([q_node, dq_node])

            if node_index is not None:
                target_node.inputs[input_index] = dq_node.outputs[0]
                logger.debug(
                    f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} "
                    f"({target_node.name}) input #{input_index}"
                )
            else:
                users = tensor_users_map[tensor_name]
                for user_index in users:
                    user_node = graph.nodes[user_index]
                    for i, input_tensor in enumerate(user_node.inputs):
                        if hasattr(input_tensor, "name") and input_tensor.name == tensor_name:
                            user_node.inputs[i] = dq_node.outputs[0]
                            break
                logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users")

        logger.debug("Running graph cleanup and topological sort")
        try:
            graph.cleanup().toposort()
            logger.debug("Graph cleanup completed")
        except Exception as e:
            logger.error(f"Graph cleanup failed: {e}")
            raise RuntimeError(f"Graph cleanup failed after Q/DQ insertion: {e}") from e


class QDQAutotuner(QDQAutotunerBase):
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base implementation for pattern-based Q/DQ insertion optimization in ONNX models.

This module defines QDQAutotunerBase, which implements the core autotuning workflow:
region-aware scheme resolution, Q/DQ insertion point matching, scheme generation via
mutation, and export (delegating to export_utils for actual Q/DQ insertion and ONNX
serialization). Subclasses such as QDQAutotuner add region discovery (e.g., automatic
search around compute-intensive ops); this base does not populate regions itself and
expects them to be set by a subclass or caller before profiling and export.
"""

import copy
import dataclasses
import functools
import os
import random
from datetime import datetime, timezone

import onnx
import onnx_graphsurgeon as gs
import yaml

from modelopt.onnx.logging_config import logger
from modelopt.onnx.op_types import is_linear_op
from modelopt.onnx.quantization.autotune.common import (
    AutotunerNotInitializedError,
    Config,
    InsertionScheme,
    InvalidSchemeError,
    PatternCache,
    PatternSchemes,
    Region,
)
from modelopt.onnx.quantization.autotune.export_utils import export_qdq_onnx
from modelopt.onnx.quantization.autotune.insertion_points import ResolvedInsertionPoint
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern
from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices

# (scheme attribute name, human-readable label, de-duplication key extractor)
# used by the mutation helpers to treat the three insertion-point families
# uniformly.
_MUTATION_SPECS = [
    ("node_inputs", "node input points", lambda p: (p.node_index, p.input_index)),
    (
        "child_region_inputs",
        "region composite points",
        lambda p: (p.region_index, p.input_index),
    ),
    (
        "region_outputs",
        "region output points",
        lambda p: (p.region_index, p.node_index, p.output_index),
    ),
]


def _requires_init(method):
    """Decorator that raises AutotunerNotInitializedError if initialize() has not been called."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.initialized:
            raise AutotunerNotInitializedError(
                "QDQAutotunerBase not initialized. Call initialize() first."
            )
        return method(self, *args, **kwargs)

    return wrapper


class QDQAutotunerBase:
    """Base class for pattern-based Q/DQ node insertion optimization in ONNX models."""

    def __init__(self, model: onnx.ModelProto | gs.Graph):
        """Initialize the autotuner with an ONNX model.

        Creates a clean copy of the model graph and initializes internal state.
        After construction, call initialize() to configure the autotuner, then
        use a subclass strategy to populate regions (e.g., QDQAutotuner does this
        automatically during initialize()).

        Args:
            model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize.
                A clean copy is created internally, leaving the original unchanged.

        Raises:
            TypeError: If model is neither onnx.ModelProto nor gs.Graph
        """
        if isinstance(model, onnx.ModelProto):
            self.onnx_model = model
        elif isinstance(model, gs.Graph):
            # Normalize to ModelProto; exports below always start from onnx_model.
            self.onnx_model = gs.export_onnx(model)
        else:
            raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}")

        self.graph = self._copy_graph()
        self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph)
        # Regions to profile; populated by a subclass or caller, not by this base.
        self.regions: list[Region] = []
        self.current_profile_region: Region | None = None
        self.profiled_patterns: list[PatternSchemes] = []
        self.current_profile_pattern_schemes: PatternSchemes | None = None
        self.current_insertion_scheme_index: int | None = None
        self.config = Config()
        self.initialized = False
        self.baseline_latency_ms: float | None = None
        self.pattern_cache: PatternCache | None = None

        logger.debug(f"Initialized autotuner with model type: {type(model).__name__}")

    # Class-level alias so subclasses can decorate their own methods as
    # @QDQAutotunerBase.requires_init.
    requires_init = _requires_init

    def initialize(
        self, config: Config | None = None, pattern_cache: PatternCache | None = None
    ) -> None:
        """Initialize autotuning session with configuration and pattern cache.

        Prepares the autotuner for profiling by setting configuration parameters
        and optionally loading pattern cache data. This base method resets all profiling
        state and sets up the pattern cache storage.

        Args:
            config: Autotuning configuration parameters. If None, uses default Config().
                Controls Q/DQ parameters, performance thresholds, and scheme generation.
            pattern_cache: Optional PatternCache object for seeding with known-good schemes.
                If None, creates a new empty pattern cache for tracking best schemes.
                If provided, uses existing schemes to warm-start optimization.

        Raises:
            None (safe to call multiple times - will reset state each time)
        """
        if config is not None:
            self.config = config

        if pattern_cache is None:
            pattern_cache = PatternCache(
                minimum_distance=self.config.pattern_cache_minimum_distance,
                max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern,
            )
        self.pattern_cache = pattern_cache

        logger.debug(
            f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and "
            f"{pattern_cache.total_schemes} schemes"
        )

        # Reset all session state before flagging the instance as initialized.
        self.initialized = False
        self.baseline_latency_ms = None
        self.profiled_patterns.clear()
        self.regions.clear()
        self.current_profile_region = None
        self.current_profile_pattern_schemes = None
        self.current_insertion_scheme_index = None

        logger.info("Initializing autotuner")
        logger.debug(
            f"Configuration: q_scale={self.config.default_q_scale}, "
            f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}"
        )

        self.initialized = True

    def _commit_current_pattern(self, save: bool = True) -> None:
        """Save current pattern schemes to profiled_patterns (if save) and clear current state."""
        if save and self.current_profile_pattern_schemes is not None:
            num_schemes = len(self.current_profile_pattern_schemes.schemes)
            best_scheme = self.current_profile_pattern_schemes.best_scheme
            best_latency = best_scheme.latency_ms if best_scheme else float("inf")

            samples_before_best, time_to_best = self._compute_convergence_metrics(
                self.current_profile_pattern_schemes.schemes, best_scheme
            )

            logger.info(
                f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms"
            )
            logger.debug(
                f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}"
            )
            if samples_before_best is not None:
                logger.debug(f"Convergence: best found at sample {samples_before_best}")
            if time_to_best is not None:
                logger.debug(f"Time to best: {time_to_best:.2f}s")
            self.profiled_patterns.append(self.current_profile_pattern_schemes)

        # Always clear the in-flight profiling target, even when not saving.
        self.current_profile_region = None
        self.current_profile_pattern_schemes = None
        self.current_insertion_scheme_index = None

    def _seed_from_cache(self, pattern: RegionPattern) -> tuple[PatternSchemes | None, int]:
        """Seed PatternSchemes from pattern cache for the given pattern. Returns (schemes, num_seeded)."""
        if self.pattern_cache is None:
            return None, 0
        cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature)
        if cache_schemes is None or len(cache_schemes.schemes) == 0:
            logger.debug("No pattern cache entries for this region")
            return None, 0
        pattern_schemes = PatternSchemes()
        pattern_schemes.pattern = pattern
        num_seeded = 0
        for cached_scheme in cache_schemes.schemes:
            # Deep-copy and reset measurement state so cached schemes are
            # re-profiled on this machine rather than trusted blindly.
            scheme_copy = copy.deepcopy(cached_scheme)
            scheme_copy.latency_ms = float("inf")
            scheme_copy.error = False
            if hasattr(scheme_copy, "profile_timestamp"):
                scheme_copy.profile_timestamp = None
            pattern_schemes.schemes.append(scheme_copy)
            num_seeded += 1
        logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache")
        return pattern_schemes, num_seeded

    @_requires_init
    def set_profile_region(self, region: Region | None, commit: bool = True) -> None:
        """Set the target region for profiling and scheme generation.

        This method manages the profiling workflow:
        1. If commit=True: Saves current schemes to profiled_patterns
        2. Creates a RegionPattern from the new region's structure
        3. For pattern-based: tries to seed schemes from pattern cache if available
        4. Sets as current for generate() and submit() calls

        Pass region=None to clear the current profile target without setting a new one.

        Args:
            region: The region to profile next (None to clear current target)
            commit: If True, commit current schemes to profiled_patterns
                before switching. Set to False during initialization.

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
            ValueError: If region is not one of this autotuner's regions
        """
        if commit or region is None:
            self._commit_current_pattern(save=commit)
        if region is None:
            return

        if region not in self.regions:
            raise ValueError(f"Region {region.id} not found in regions")

        region_pattern = RegionPattern.from_region(region, self.graph)

        # Structurally identical regions share one pattern; skip duplicates.
        if self._is_region_profiled(region):
            logger.info(f"Skipping region {region.id} (pattern already profiled)")
            logger.debug(f"Pattern signature: {region_pattern.signature}")
            return

        pattern_schemes, num_seeded = self._seed_from_cache(region_pattern)
        if pattern_schemes is None:
            pattern_schemes = PatternSchemes()
            pattern_schemes.pattern = region_pattern
            logger.debug("Initialized with empty scheme collection")

        self.current_profile_region = region
        self.current_profile_pattern_schemes = pattern_schemes

        mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh"
        logger.info(
            f"Profiling region {region.id} [level {region.level}, size"
            f"{region.get_size_of_region_and_descendants()}, {mode_info}]"
        )
        logger.debug(f"Pattern signature: {region_pattern.signature}")

    @_requires_init
    def generate(self) -> int:
        """Generate a new Q/DQ insertion scheme for the current pattern or region.

        Creates a new InsertionScheme by mutating the top-performing schemes:
        1. Checks if there are any cached schemes (error=False, latency_ms=inf)
        2. If cached schemes exist, picks one to re-profile
        3. Otherwise, generates a new scheme by mutation
        4. Selects a random scheme from the top 10 performers
        5. Mutates it by adding/removing insertion points
        6. Ensures the new scheme is unique (different from existing schemes)
        7. Adds the scheme to current_profile_pattern_schemes

        Returns:
            Index of the scheme to profile next, or -1 if no unique scheme
            could be generated within the configured attempt budget.

        Raises:
            InvalidSchemeError: If no region has been selected via
                set_profile_region().
        """
        if self.current_profile_pattern_schemes is None:
            raise InvalidSchemeError("No region selected. Call set_profile_region() first.")

        pattern_schemes = self.current_profile_pattern_schemes
        # Cache-seeded schemes are unprofiled (latency inf); drain them first.
        cached_schemes = [
            (idx, scheme)
            for idx, scheme in enumerate(pattern_schemes.schemes)
            if not scheme.is_profiled
        ]

        if cached_schemes:
            scheme_index, cached_scheme_data = cached_schemes[0]
            num_node_points = len(cached_scheme_data.node_inputs)
            num_region_composite_points = len(cached_scheme_data.child_region_inputs)
            num_region_output_points = len(cached_scheme_data.region_outputs)
            total_points = num_node_points + num_region_composite_points + num_region_output_points

            logger.info(
                f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)"
            )
            logger.debug(
                f"Cached scheme breakdown: {num_node_points} node input, "
                f"{num_region_composite_points} region composite, "
                f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)"
            )

            self.current_insertion_scheme_index = scheme_index
            return self.current_insertion_scheme_index

        known_schemes = {scheme.hash for scheme in pattern_schemes.schemes}
        max_attempts = self.config.maximum_generation_attempts

        logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)")

        for attempts in range(max_attempts):
            new_scheme = self._generate_next_insertion_sample()
            # Only accept schemes not seen before (by content hash) and valid.
            if new_scheme.hash not in known_schemes and not new_scheme.error:
                pattern_schemes.schemes.append(new_scheme)
                scheme_index = len(pattern_schemes.schemes) - 1
                num_node_points = len(new_scheme.node_inputs)
                num_region_composite_points = len(new_scheme.child_region_inputs)
                num_region_output_points = len(new_scheme.region_outputs)
                total_points = (
                    num_node_points + num_region_composite_points + num_region_output_points
                )

                logger.info(
                    f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)"
                )
                logger.debug(
                    f"Scheme breakdown: {num_node_points} node input, "
                    f"{num_region_composite_points} region composite, "
                    f"{num_region_output_points} region output points "
                    f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})"
                )

                self.current_insertion_scheme_index = scheme_index
                return self.current_insertion_scheme_index

        logger.warning(f"Could not generate unique scheme after {max_attempts} attempts")
        return -1

    def _resolve_scheme_for_region(
        self, region: Region, best: bool
    ) -> tuple[InsertionScheme | None, RegionPattern]:
        """Resolve the insertion scheme to use for a region from profiled/current/cache.

        Resolution priority: best scheme of an already-profiled matching pattern,
        then the active profile pattern (best or currently selected scheme), then
        a single unprofiled imported scheme from the pattern cache.

        Args:
            region: The region to resolve the scheme for
            best: If True, return the best scheme for the region

        Returns:
            tuple[InsertionScheme | None, RegionPattern]: The scheme and pattern for the region
        """
        pattern = RegionPattern.from_region(region, self.graph)
        logger.debug(f"Region {region.id} (level {region.level})")
        logger.debug(f"  → Pattern signature: {pattern.signature}")

        matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None)
        current_scheme = matched.best_scheme if matched else None

        if matched:
            if current_scheme:
                logger.debug(
                    f"  → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)"
                )
            else:
                logger.debug("  → Matched profiled pattern but no valid schemes")

        if current_scheme is None:
            pattern_schemes = self.current_profile_pattern_schemes
            if pattern_schemes is None or pattern != pattern_schemes.pattern:
                pass
            elif best:
                current_scheme = pattern_schemes.best_scheme
            else:
                scheme_index = self.current_insertion_scheme_index
                if scheme_index is not None:
                    if scheme_index < 0 or scheme_index >= len(pattern_schemes.schemes):
                        raise IndexError(
                            f"Invalid scheme index: {scheme_index} "
                            f"(pattern has {len(pattern_schemes.schemes)} schemes)"
                        )
                    current_scheme = pattern_schemes.schemes[scheme_index]
                    logger.debug(f"  → Using current pattern scheme #{scheme_index}")

        if current_scheme is None and self.pattern_cache is not None:
            cache_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature)
            if cache_schemes is not None:
                schemes = cache_schemes.schemes
                # Only trust the cache when it is unambiguous: exactly one
                # imported (never locally profiled) scheme.
                if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled:
                    current_scheme = schemes[0]
                    logger.debug("  → Using imported pattern from cache")

        if current_scheme is None:
            logger.debug("  → No scheme available, skipping")

        return current_scheme, pattern

    def _exclude_overlapping_insertion_points(
        self,
        resolved_insertion_points: set[ResolvedInsertionPoint],
        region: Region,
        pattern: RegionPattern,
    ) -> None:
        """Remove this region's full insertion points from resolved set so they can be replaced."""
        full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph)
        if full_insertion_scheme is None:
            raise ValueError("get_full_insertion_scheme returned None")
        all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme)
        for ip in all_region_ips:
            node = self.graph.nodes[ip.node_index]
            # Conv/ConvTranspose/Gemm/MatMul inputs and weights must be excluded together
            if is_linear_op(node.op) and ip.input_index == 0 and len(node.inputs) >= 2:
                resolved_insertion_points.discard(ip)
                resolved_insertion_points.discard(
                    ResolvedInsertionPoint(
                        tensor_name=node.inputs[1].name,
                        node_index=ip.node_index,
                        input_index=1,
                    )
                )
        if not isinstance(all_region_ips, set):
            raise TypeError(
                f"pattern.matches must return a set, got {type(all_region_ips).__name__}"
            )
        resolved_insertion_points.difference_update(all_region_ips)
        if all_region_ips:
            logger.debug(f"  → Excluded {len(all_region_ips)} overlapping insertion points")

    @_requires_init
    def export_onnx(
        self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False
    ) -> bytes:
        """Export ONNX model with Q/DQ nodes inserted according to tested schemes.

        This method creates a modified version of the model by:
        1. For each region, finding the matching pattern
        2. Applying the best scheme for profiled patterns
        3. Applying the current scheme for the active profile pattern
        4. Resolving pattern-relative insertion points to actual tensor names
        5. Inserting Q/DQ pairs at the resolved locations
        6. Converting to FP8 if needed (always creates INT8 first, then converts)

        Args:
            output_path: Optional file path where the modified ONNX model will be saved.
                If None, the model is not saved to disk and only bytes are returned.
            insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model
                (useful for baseline measurements)
            best: If True, prefer each pattern's best scheme over the currently
                selected one when resolving schemes for the active pattern.

        Returns:
            bytes: Serialized ONNX model as bytes

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
        """
        output_desc = output_path if output_path is not None else ""
        resolved_insertion_points = set()

        logger.debug(
            f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, "
            f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})"
        )

        if insert_qdq:
            matched_regions = 0

            logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions")

            for region in self.regions:
                current_scheme, pattern = self._resolve_scheme_for_region(region, best)
                if current_scheme is None:
                    continue

                # Drop any previously-added points owned by this region so the
                # resolved scheme below fully replaces them.
                self._exclude_overlapping_insertion_points(
                    resolved_insertion_points, region, pattern
                )

                new_ips = pattern.matches(region, self.graph, current_scheme)
                if new_ips:
                    resolved_insertion_points.update(new_ips)
                    matched_regions += 1
                    logger.debug(f"  → Added {len(new_ips)} insertion points")

            logger.debug(
                f"Matched {matched_regions}/{len(self.regions)} regions, "
                f"total {len(resolved_insertion_points)} unique insertion points"
            )

        unique_tensors = len(resolved_insertion_points)

        logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph")

        original_quant_type = self.config.default_quant_type
        needs_fp8_conversion = insert_qdq and original_quant_type == "fp8"

        model = export_qdq_onnx(
            self.onnx_model,
            resolved_insertion_points,
            self.config,
            insert_qdq=insert_qdq and bool(resolved_insertion_points),
            needs_fp8_conversion=needs_fp8_conversion,
        )

        model_bytes = model.SerializeToString()
        quant_type_str = "baseline"
        output_dest = ""

        if insert_qdq:
            quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8"

        if output_path is not None:
            onnx.save(model, output_path)
            output_dest = f" → {output_path}"

        logger.info(
            f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs {output_dest}"
        )
        return model_bytes

    @_requires_init
    def submit(self, latency_ms: float, success: bool = True) -> None:
        """Submit performance measurement for the most recently generated scheme.

        This method records the measured latency and manages the optimization state:
        the first submission becomes the baseline; later ones are attached to the
        current scheme, schemes are re-sorted by latency, and the pattern cache
        is refreshed.

        Args:
            latency_ms: Measured latency in milliseconds (must be > 0)
            success: Whether the measurement succeeded. If False, sets scheme.error=True,
                logs a warning, and skips speedup calculation.

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
            InvalidSchemeError: If no pattern or region is set, or no schemes have been generated
        """
        # First measurement ever is treated as the unquantized baseline.
        if self.baseline_latency_ms is None:
            self.baseline_latency_ms = latency_ms
            logger.info(f"Baseline latency: {latency_ms:.3f} ms")
            return

        if self.current_profile_pattern_schemes is None:
            raise InvalidSchemeError(
                "No pattern or region selected. Call set_profile_region() first."
            )

        schemes_collection = self.current_profile_pattern_schemes
        if not schemes_collection.schemes:
            raise InvalidSchemeError("No schemes available. Call generate() first.")

        pattern_schemes = schemes_collection

        if self.current_insertion_scheme_index is not None:
            scheme_index = self.current_insertion_scheme_index
            if scheme_index >= len(pattern_schemes.schemes):
                raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}")
            scheme = pattern_schemes.schemes[scheme_index]
        else:
            # No explicit selection: attribute the result to the newest scheme.
            scheme = pattern_schemes.schemes[-1]
            scheme_index = len(pattern_schemes.schemes) - 1

        scheme.latency_ms = latency_ms
        scheme.error = not success
        scheme.profile_timestamp = datetime.now(timezone.utc).isoformat()
        display_index = scheme_index + 1

        if not success:
            logger.warning(
                f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)"
            )
            logger.debug("Marking scheme with error flag")
            return

        speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0

        logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)")
        logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms")

        # Keep schemes sorted best-first; non-positive latencies sort last.
        old_best = (
            pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf")
        )
        pattern_schemes.schemes.sort(
            key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf")
        )
        new_best = (
            pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf")
        )

        if new_best < old_best:
            new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0
            logger.info(f"  ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)")
            logger.debug(f"Previous best: {old_best:.3f} ms")

        if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None:
            self.pattern_cache.add_pattern_schemes(pattern_schemes)
            logger.debug(
                f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, "
                f"{self.pattern_cache.total_schemes} schemes"
            )

    def save_state(self, output_path: str) -> None:
        """Save complete autotuner state to a YAML file for later reuse.

        Serializes all optimization results including:
        - Baseline latency measurement
        - All profiled patterns with their signatures
        - All generated schemes with insertion points and latencies
        - Configuration parameters
        - Current profiling state

        Args:
            output_path: File path where the YAML state file will be written.
                Pattern cache will be saved to _pattern_cache.yaml
        """
        current_pattern_sig = None
        if self.current_profile_pattern_schemes is not None:
            current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature

        state = {
            "baseline_latency_ms": self.baseline_latency_ms,
            "current_profile_pattern_schemes_signature": current_pattern_sig,
            "config": dataclasses.asdict(self.config),
            "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns],
        }

        with open(output_path, "w") as f:
            yaml.dump(state, f, default_flow_style=False, sort_keys=False)

        num_patterns = len(self.profiled_patterns)
        total_schemes = sum(len(p.schemes) for p in self.profiled_patterns)

        logger.info(
            f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)"
        )
        # NOTE(review): this :.3f format raises TypeError if baseline_latency_ms
        # is still None (state was just serialized with None above) — confirm
        # save_state is only called after a baseline submit.
        logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms")

        if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0:
            base_path, ext = os.path.splitext(output_path)
            cache_path = f"{base_path}_pattern_cache{ext}"
            self.pattern_cache.save(cache_path)

            logger.info(f"Saved pattern cache → {cache_path}")
            logger.debug(
                f"Cache: {self.pattern_cache.num_patterns} patterns, "
                f"{self.pattern_cache.total_schemes} schemes"
            )

    @_requires_init
    def load_state(self, input_path: str) -> None:
        """Load autotuner state from a previously saved YAML file.

        Restores optimization results from a previous session:
        1. Matches saved patterns to current model's patterns by signature
        2. Loads all schemes with their insertion points and latencies (including unmeasured ones)
        3. Restores baseline latency and configuration

        Args:
            input_path: File path to the YAML state file to load

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
            FileNotFoundError: If the input_path doesn't exist
        """
        with open(input_path) as f:
            state = yaml.safe_load(f)

        if state.get("baseline_latency_ms") is not None:
            self.baseline_latency_ms = state["baseline_latency_ms"]
            logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms")

        if "config" in state:
            config_data = state["config"]
            if isinstance(config_data, dict):
                # Merge over defaults, ignoring unknown keys so older/newer
                # state files remain loadable.
                default_dict = dataclasses.asdict(Config())
                default_dict.update({k: v for k, v in config_data.items() if k in default_dict})
                self.config = Config(**default_dict)
                logger.debug(f"Config restored: quant_type={self.config.default_quant_type}")

        if "patterns" in state:
            num_loaded_patterns = 0
            num_loaded_schemes = 0

            for pattern_data in state["patterns"]:
                try:
                    pattern_schemes = PatternSchemes.from_dict(pattern_data)

                    if pattern_schemes.schemes:
                        self.profiled_patterns.append(pattern_schemes)
                        num_loaded_patterns += 1
                        num_loaded_schemes += len(pattern_schemes.schemes)
                    else:
                        logger.debug(
                            f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..."
                        )

                except (KeyError, TypeError, ValueError) as exc:  # noqa: PERF203
                    # Best-effort load: a malformed entry is skipped, not fatal.
                    logger.warning("Failed to load pattern: %s", exc)
                    continue

            logger.info(
                f"Loaded state from {input_path} ({num_loaded_patterns} patterns, "
                f"{num_loaded_schemes} schemes)"
            )

        # The pattern cache lives in a sibling file derived from input_path.
        base_path, ext = os.path.splitext(input_path)
        cache_path = f"{base_path}_pattern_cache{ext}"

        if os.path.exists(cache_path):
            try:
                loaded_cache = PatternCache.load(cache_path)

                if self.pattern_cache is not None:
                    # Merge into the existing cache rather than replacing it.
                    for pattern_schemes in loaded_cache.pattern_schemes:
                        self.pattern_cache.add_pattern_schemes(pattern_schemes)
                else:
                    self.pattern_cache = loaded_cache
                logger.info(
                    f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, "
                    f"{loaded_cache.total_schemes} schemes)"
                )
            except (OSError, yaml.YAMLError, KeyError, TypeError, ValueError) as exc:
                logger.warning("Failed to load pattern cache: %s", exc)
        else:
            logger.debug(f"No pattern cache file at {cache_path}")

    @_requires_init
    def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None:
        """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.

        Analyzes the current model's regions against the provided quantized tensors
        to extract Q/DQ insertion patterns. For each region, creates a pattern cache
        entry that captures which insertion points correspond to the quantized tensors.
        These cached patterns can then be used as seeds for future autotuning sessions.

        Args:
            quantized_tensors: Set or list of tensor names that are quantized
                (i.e., tensors that have Q/DQ nodes applied to them)

        Raises:
            AutotunerNotInitializedError: If initialize() hasn't been called
        """
        if isinstance(quantized_tensors, list):
            quantized_tensors = set(quantized_tensors)

        logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors")
        logger.debug(f"Processing {len(self.regions)} regions")

        if self.pattern_cache is None:
            logger.warning("Pattern cache not initialized, skipping import")
            return

        patterns_before = self.pattern_cache.num_patterns
        schemes_before = self.pattern_cache.total_schemes

        for region in self.regions:
            self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors)

        patterns_added = self.pattern_cache.num_patterns - patterns_before
        schemes_added = self.pattern_cache.total_schemes - schemes_before

        logger.info(
            f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache"
        )
        logger.debug(
            f"Total cache: {self.pattern_cache.num_patterns} patterns, "
            f"{self.pattern_cache.total_schemes} schemes"
        )

    def _compute_convergence_metrics(
        self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None
    ) -> tuple[int | None, float | None]:
        """Compute convergence metrics for a collection of schemes.

        Analyzes when the best scheme was discovered during the profiling process
        by sorting schemes by their profile timestamps and finding the position
        of the best scheme.

        Args:
            schemes: List of insertion schemes with profile timestamps
            best_scheme: The best performing scheme (lowest latency)

        Returns:
            Tuple of (samples_before_best, time_to_best) where:
            - samples_before_best: Number of samples tested before finding best (0-based index)
            - time_to_best: Time in seconds from first sample to best sample
            Both values are None if metrics cannot be computed (e.g., missing timestamps)
        """
        samples_before_best = None
        time_to_best = None

        if not best_scheme or not best_scheme.profile_timestamp:
            return samples_before_best, time_to_best

        schemes_with_time = [s for s in schemes if s.profile_timestamp is not None]

        if not schemes_with_time:
            return samples_before_best, time_to_best

        # ISO-8601 timestamps sort lexicographically in chronological order.
        schemes_with_time.sort(key=lambda s: s.profile_timestamp or "")

        try:
            best_position = next(
                i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash
            )
            samples_before_best = best_position

            first_ts = schemes_with_time[0].profile_timestamp
            best_ts = best_scheme.profile_timestamp
            if first_ts is not None and best_ts is not None:
                first_timestamp = datetime.fromisoformat(first_ts)
                best_timestamp = datetime.fromisoformat(best_ts)
                time_to_best = (best_timestamp - first_timestamp).total_seconds()
        except (StopIteration, ValueError):
            # Best scheme not found among timestamped schemes, or unparsable
            # timestamp: report whatever metrics were computed so far.
            pass

        return samples_before_best, time_to_best

    def _is_region_profiled(self, region: Region) -> bool:
        """Check if a region's pattern has already been fully profiled."""
        return any(
            p.pattern is not None
            and p.pattern.matches(region, self.graph)
            and all(s.is_profiled for s in p.schemes)
            for p in self.profiled_patterns
        )

    def _mutate_insertion_points(
        self, base_points, all_points, point_type: str, max_mutations: int
    ) -> list:
        """Mutate a set of insertion points by adding, removing, or both."""
        key_fn = {
            "node input points": lambda p: (p.node_index, p.input_index),
            "region composite points": lambda p: (p.region_index, p.input_index),
+ "region output points": lambda p: (p.region_index, p.node_index, p.output_index), + }.get(point_type) + + if not key_fn: + return [] + + current_points = set(base_points) + initial_count = len(current_points) + mutation_type = random.choice(["add", "remove", "both"]) + + if mutation_type in ["add", "both"] and len(current_points) < len(all_points): + all_keys = {key_fn(p) for p in all_points} + available_keys = all_keys - current_points + if available_keys: + max_add = min(max_mutations, len(available_keys)) + num_to_add = random.randint(1, max_add) + to_add = random.sample(list(available_keys), num_to_add) + current_points.update(to_add) + + if mutation_type in ["remove", "both"] and current_points: + max_remove = min(max_mutations, len(current_points)) + num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1 + num_to_remove = min(num_to_remove, len(current_points)) + to_remove = random.sample(list(current_points), num_to_remove) + for p in to_remove: + current_points.discard(p) + + logger.debug( + f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})" + ) + + return [p for p in all_points if key_fn(p) in current_points] + + def _generate_next_insertion_sample(self) -> InsertionScheme: + """Generate a new insertion scheme by mutating top performers. + + This is the core scheme generation algorithm: + 1. Identifies top schemes by latency + 2. Randomly selects one as the base + 3. Mutates node input insertion points (add, remove, or both) + 4. Mutates region composite insertion points (child boundaries) + 5. Mutates region output insertion points + 6. 
Returns new unique scheme + + **Mutation Strategy:** + - Node input points: Add/remove 1-3 insertion points + - Region composite points: Add/remove 1-3 boundary points + - Region output points: Add/remove 1-3 output points + - Mutation type chosen randomly: 'add', 'remove', or 'both' + + **Baseline Case:** + If no schemes exist yet, returns an empty baseline scheme. + + Returns: + New InsertionScheme with mutated insertion points. + Returns empty scheme if no region is set or no candidates exist. + """ + if self.current_profile_region is None: + return InsertionScheme() + + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + return InsertionScheme() + + region = self.current_profile_region + pattern_schemes = schemes_collection + + if not isinstance(schemes_collection, PatternSchemes) or schemes_collection.pattern is None: + return InsertionScheme() + pattern = schemes_collection.pattern + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + + logger.debug( + f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, " + f"{len(full_insertion_scheme.child_region_inputs)} region composite, " + f"{len(full_insertion_scheme.region_outputs)} region output" + ) + + top_percent = self.config.top_percent_to_mutate + minimum_schemes = self.config.minimum_schemes_to_mutate + + measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error] + measured_schemes.sort(key=lambda s: s.latency_ms) + + num_top_schemes = max( + int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes)) + ) + top_schemes = measured_schemes[:num_top_schemes] + + if len(top_schemes) == 0: + logger.debug("No measured schemes yet, generating baseline (empty) scheme") + return InsertionScheme() + + base_scheme = random.choice(top_schemes) + total_base_points = ( + len(base_scheme.node_inputs) + + 
len(base_scheme.child_region_inputs) + + len(base_scheme.region_outputs) + ) + logger.debug( + f"Mutating from top {len(top_schemes)} schemes: " + f"selected base with {total_base_points} points (latency={base_scheme.latency_ms:.3f} ms)" + ) + + max_mutations = self.config.maximum_mutations + scheme = InsertionScheme() + + for attr, point_type, key_fn in _MUTATION_SPECS: + base_points = {key_fn(p) for p in getattr(base_scheme, attr)} + setattr( + scheme, + attr, + self._mutate_insertion_points( + base_points, + getattr(full_insertion_scheme, attr), + point_type, + max_mutations, + ), + ) + + return scheme + + def _copy_graph(self) -> gs.Graph: + """Create an independent copy of the computation graph.""" + new_graph = gs.import_onnx(self.onnx_model) + new_graph.toposort() + return new_graph diff --git a/modelopt/onnx/quantization/autotune/export_utils.py b/modelopt/onnx/quantization/autotune/export_utils.py new file mode 100644 index 000000000..b3eb1bbba --- /dev/null +++ b/modelopt/onnx/quantization/autotune/export_utils.py @@ -0,0 +1,360 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for Q/DQ model export and insertion in ONNX autotune.""" + +import dataclasses + +import numpy as np +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Config +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + merge_resolved_insertion_points, +) +from modelopt.onnx.quantization.fp8 import int8_to_fp8 +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +__all__ = [ + "build_tensor_map", + "create_qdq_nodes", + "export_qdq_onnx", + "fix_zero_point_initializers", + "get_tensor_metadata", + "get_zero_point_for_quant_type", + "insert_qdq_at_tensors", + "resolve_dtype", +] + +_DTYPE_MAP = { + "int8": np.int8, + "uint8": np.uint8, + "float16": np.float16, + "float32": np.float32, +} + + +def resolve_dtype( + dtype_str: str, default: np.dtype = np.int8, dtype_map: dict | None = None +) -> np.dtype: + """Resolve a dtype string (quant or DQ output) to a numpy dtype.""" + if dtype_map is None: + dtype_map = _DTYPE_MAP + if dtype_str == "fp8": + try: + return np.dtype(np.float8_e4m3fn) + except (AttributeError, TypeError): + logger.warning( + "FP8 dtype not available (requires numpy >= 2.0), " + "using uint8 as placeholder. Note: This may not produce " + "correct results without proper FP8 support." + ) + return np.uint8 + if hasattr(np, "bfloat16") and dtype_str == "bfloat16": + return np.bfloat16 + if dtype_str in dtype_map: + return dtype_map[dtype_str] + logger.warning(f"Unknown dtype '{dtype_str}', using default {default}") + return default + + +def get_zero_point_for_quant_type(quant_type: str, quant_dtype: np.dtype) -> int: + """Return default zero point for quant type and validate it is in-range for the dtype. + + int8 uses 0 (signed); uint8 uses 128 (unsigned midpoint); fp8/other use 0. + Raises ValueError if the default zero point is not in the valid range for quant_dtype. 
+ """ + default_zp = 128 if quant_type == "uint8" else 0 + if quant_dtype == np.int8: + low, high = -128, 127 + if not (low <= default_zp <= high): + raise ValueError( + f"Zero point {default_zp} out of range for int8 (must be in [{low}, {high}])" + ) + elif quant_dtype == np.uint8: + low, high = 0, 255 + if not (low <= default_zp <= high): + raise ValueError( + f"Zero point {default_zp} out of range for uint8 (must be in [{low}, {high}])" + ) + return default_zp + + +def build_tensor_map(graph: gs.Graph) -> dict[str, gs.Tensor]: + """Build mapping from tensor names to tensor objects.""" + tensor_map = {t.name: t for t in graph.inputs if hasattr(t, "name") and t.name} + for node in graph.nodes: + for t in node.inputs: + if hasattr(t, "name") and t.name: + tensor_map[t.name] = t + for t in node.outputs: + if isinstance(t, gs.Constant) and hasattr(t, "name") and t.name: + tensor_map[t.name] = t + return tensor_map + + +def get_tensor_metadata( + tensor: gs.Tensor, is_constant: bool, default_dtype: np.dtype +) -> tuple[tuple | None, np.dtype]: + """Extract shape and dtype metadata from a tensor.""" + if is_constant and hasattr(tensor, "values") and tensor.values is not None: + return tensor.values.shape, tensor.values.dtype + if hasattr(tensor, "shape"): + dtype = ( + tensor.dtype if hasattr(tensor, "dtype") and tensor.dtype is not None else default_dtype + ) + return tensor.shape, dtype + return None, default_dtype + + +def fix_zero_point_initializers(model: onnx.ModelProto) -> None: + """Fix INT8 zero_point initializers to use int32_data instead of raw_data.""" + fixed_count = 0 + for initializer in model.graph.initializer: + if ( + "_zp_" in initializer.name + and initializer.data_type == onnx.TensorProto.INT8 + and len(initializer.raw_data) > 0 + and len(initializer.int32_data) == 0 + ): + np_array = onnx.numpy_helper.to_array(initializer) + int32_values = np_array.astype(np.int32).flatten().tolist() + new_tensor = onnx.helper.make_tensor( + initializer.name, + 
onnx.TensorProto.INT8, + list(initializer.dims), + int32_values, + ) + initializer.CopyFrom(new_tensor) + fixed_count += 1 + if fixed_count > 0: + logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)") + + +def create_qdq_nodes( + tensor_name: str, + qdq_input: gs.Tensor, + output_shape: tuple | None, + output_dtype: np.dtype, + quant_dtype: np.dtype, + q_scale: float, + q_zero_point: int, +) -> tuple[gs.Node, gs.Node]: + """Create QuantizeLinear and DequantizeLinear node pair.""" + q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") + dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") + dtype_map = {"float16": np.float16, "float32": np.float32} + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + scale_dtype = dtype_map.get(np.dtype(output_dtype).name, np.float32) + + logger.debug( + f"Creating Q/DQ pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})" + ) + + q_scale_values = np.array([q_scale], dtype=scale_dtype) + q_zp_values = np.array([q_zero_point], dtype=quant_dtype) + q_inputs = [ + qdq_input, + gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), + gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values), + ] + q_node = gs.Node( + op="QuantizeLinear", + name=q_name, + inputs=q_inputs, + outputs=[gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape)], + ) + + dq_scale_values = np.array([q_scale], dtype=scale_dtype) + dq_zp_values = np.array([q_zero_point], dtype=quant_dtype) + dq_inputs = [ + q_node.outputs[0], + gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), + gs.Constant(f"dq_zp_{tensor_name}", values=dq_zp_values), + ] + dq_node = gs.Node( + op="DequantizeLinear", + name=dq_name, + inputs=dq_inputs, + outputs=[gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape)], + ) + return q_node, dq_node + + +def insert_qdq_at_tensors( + graph: gs.Graph, + resolved_insertion_points: 
set[ResolvedInsertionPoint], + config: Config, + *, + tensor_users_map: dict[str, list[int]] | None = None, +) -> None: + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations. + + Modifies the graph in-place. Builds tensor map and tensor-to-users map, + processes each resolved insertion point, and runs graph cleanup/toposort. + + Args: + graph: Graph to modify in-place. + resolved_insertion_points: Set of ResolvedInsertionPoint specifying where to insert Q/DQ. + config: Config with default_q_scale, default_q_zero_point, default_quant_type, default_dq_dtype. + tensor_users_map: Optional precomputed tensor name -> list of node indices. If None, computed. + """ + q_scale = config.default_q_scale + q_zero_point = config.default_q_zero_point + quant_type = config.default_quant_type + quant_dtype = resolve_dtype(quant_type, np.int8, _DTYPE_MAP) + + logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point={q_zero_point}") + + resolved_insertion_points = merge_resolved_insertion_points(graph, resolved_insertion_points) + + tensor_map = build_tensor_map(graph) + if tensor_users_map is None: + tensor_users_map = get_tensor_consumer_node_indices(graph) + logger.debug( + f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users" + ) + + default_dq_dtype = resolve_dtype(config.default_dq_dtype, np.float32, _DTYPE_MAP) + + for insertion_point in resolved_insertion_points: + tensor_name = insertion_point.tensor_name + node_index = insertion_point.node_index + input_index = insertion_point.input_index + + original_tensor = tensor_map[tensor_name] + if node_index is not None: + if node_index < 0 or node_index >= len(graph.nodes): + raise IndexError( + f"Node index out of range: {node_index} (graph has {len(graph.nodes)} nodes)" + ) + target_node = graph.nodes[node_index] + if input_index is None: + raise ValueError("Input index must be set when node index is set") + if input_index < 0 or input_index >= 
len(target_node.inputs): + raise IndexError( + f"Input index out of range for node {target_node.name}: " + f"{input_index} (node has {len(target_node.inputs)} inputs)" + ) + original_tensor = target_node.inputs[input_index] + if tensor_name != original_tensor.name: + raise ValueError( + f"Tensor name mismatch for node {target_node.name} input {input_index}: " + f"expected {tensor_name!r}, got {original_tensor.name!r}" + ) + else: + if tensor_name not in tensor_map: + raise KeyError(f"Tensor {tensor_name!r} not found in tensor map") + if input_index is not None: + raise ValueError("Input index must be None when node index is None") + + is_constant = isinstance(original_tensor, gs.Constant) + output_shape, output_dtype = get_tensor_metadata( + original_tensor, is_constant, default_dtype=default_dq_dtype + ) + + unique_suffix = "qdq" + if node_index is not None: + unique_suffix = f"n{node_index}_i{input_index}" + unique_tensor_name = f"{tensor_name}_{unique_suffix}" + + q_node, dq_node = create_qdq_nodes( + unique_tensor_name, + original_tensor, + output_shape, + output_dtype, + quant_dtype, + q_scale, + q_zero_point, + ) + + graph.nodes.extend([q_node, dq_node]) + + if node_index is not None: + target_node.inputs[input_index] = dq_node.outputs[0] + logger.debug( + f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} " + f"({target_node.name}) input #{input_index}" + ) + else: + users = tensor_users_map[tensor_name] + for user_index in users: + user_node = graph.nodes[user_index] + for i, input_tensor in enumerate(user_node.inputs): + if hasattr(input_tensor, "name") and input_tensor.name == tensor_name: + user_node.inputs[i] = dq_node.outputs[0] + break + logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users") + + logger.debug("Running graph cleanup and topological sort") + try: + graph.cleanup().toposort() + logger.debug("Graph cleanup completed") + except (ValueError, RuntimeError) as exc: + logger.error("Graph cleanup failed: %s", 
exc) + raise RuntimeError(f"Graph cleanup failed after Q/DQ insertion: {exc}") from exc + + +def export_qdq_onnx( + source: onnx.ModelProto | gs.Graph, + resolved_insertion_points: set[ResolvedInsertionPoint], + config: Config, + *, + insert_qdq: bool = True, + needs_fp8_conversion: bool = False, +) -> onnx.ModelProto: + """Export ONNX model with optional Q/DQ insertion and optional INT8→FP8 conversion. + + Does not modify the source; works on a copy of the graph. + + Args: + source: ONNX model or GraphSurgeon graph to export from. + resolved_insertion_points: Set of insertion points (used when insert_qdq is True). + config: Config for Q/DQ parameters and dtypes. + insert_qdq: If True, insert Q/DQ at resolved points before exporting. + needs_fp8_conversion: If True, build as INT8 then convert to FP8 (e.g. when config.default_quant_type is fp8). + + Returns: + Exported ONNX ModelProto (with Q/DQ and/or FP8 as requested). + """ + if isinstance(source, onnx.ModelProto): + graph_copy = gs.import_onnx(source) + else: + graph_copy = gs.import_onnx(gs.export_onnx(source)) + graph_copy.toposort() + + if insert_qdq and resolved_insertion_points: + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + config_int8 = dataclasses.replace(config, default_quant_type="int8") + insert_qdq_at_tensors(graph_copy, resolved_insertion_points, config_int8) + else: + insert_qdq_at_tensors(graph_copy, resolved_insertion_points, config) + + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) + + if insert_qdq and resolved_insertion_points: + fix_zero_point_initializers(model) + + if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + + return model From b8235ad7cc44f6874ece347323b8d59ba70263c6 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Thu, 26 Feb 2026 07:14:00 +0000 Subject: [PATCH 13/14] fix pre-commit failure Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/__init__.py 
| 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py index a722cabbb..7f14bb360 100644 --- a/modelopt/onnx/quantization/autotune/__init__.py +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -21,8 +21,8 @@ """ # Core data structures -from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark from .autotuner import QDQAutotuner +from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark from .common import ( AutotunerError, AutotunerNotInitializedError, From c1f93b269992cbb99a6a754218713f8f2d2e1e48 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Thu, 26 Feb 2026 17:56:06 +0000 Subject: [PATCH 14/14] fix tox doc build failures Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/common.py | 56 ++++++------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index 922ab09eb..01fa4aaf4 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -795,44 +795,24 @@ class Config: Controls the autotuning process including performance requirements, quantization parameters, region building, scheme generation, and finetuning behavior. - Attributes: - # Logging - verbose: Enable detailed logging of autotuning progress (default: False) - - # Performance Requirements - performance_threshold: Minimum speedup ratio to accept a scheme. - 1.0 = no improvement required, 1.02 = 2% improvement (default: 1.02) - - # Quantization Parameters - default_q_scale: Default scale parameter for Q/DQ nodes. Controls quantization - granularity. Typical range: 0.01-0.1 (default: 0.1) - default_q_zero_point: Default zero-point for Q/DQ nodes. Use 0 for signed int8, - 128 for unsigned uint8 (default: 0) - default_quant_type: Quantization type for Q/DQ nodes. 
Options: "int8" (default), "fp8" - - # Region Builder Settings - maximum_sequence_region_size: Maximum number of nodes in a sequence region during - top-down refinement. Prevents overly large merged regions (default: 10) - minimum_topdown_search_size: Minimum number of nodes in a region to trigger - top-down search during region building (default: 10) - - # Scheme Generation Settings - top_percent_to_mutate: Top percentage of best schemes to use as mutation seeds - during scheme generation. Range: 0.0-1.0 (default: 0.1 = top 10%) - minimum_schemes_to_mutate: Minimum number of schemes to keep as mutation seeds, - even if top_percent_to_mutate results in fewer (default: 10) - maximum_mutations: Maximum number of mutations to apply to a single scheme - during generation (default: 3) - maximum_generation_attempts: Maximum attempts to generate a unique new scheme - before giving up (default: 100) - - # Pattern Cache Settings - pattern_cache_minimum_distance: Minimum edit distance required between schemes in cache. - When adding schemes, if a scheme is too similar (distance < minimum_distance) - to an existing scheme, only the better-performing one is kept (default: 4) - pattern_cache_max_entries_per_pattern: Maximum number of schemes to keep per pattern - in pattern cache. Only the top N best-performing schemes are kept for each pattern. - Use 0 to keep all schemes (default: 32) + Attributes are documented below as a list to avoid duplicate index entries with + autodoc-generated attribute docs. Key fields: + + - verbose: Enable detailed logging of autotuning progress (default: False). + - performance_threshold: Minimum speedup ratio to accept a scheme; + 1.0 = no improvement required, 1.02 = 2% improvement (default: 1.02). + - default_q_scale: Default scale for Q/DQ nodes; typical range 0.01-0.1 (default: 0.1). + - default_q_zero_point: Zero-point for Q/DQ; 0 for int8, 128 for uint8 (default: 0). + - default_quant_type: Quantization type; "int8" (default) or "fp8". 
+ - default_dq_dtype: Dtype for DequantizeLinear output; "float32" (default) or "float16". + - maximum_sequence_region_size: Max nodes in a sequence region (default: 10). + - minimum_topdown_search_size: Min nodes to trigger top-down search (default: 10). + - top_percent_to_mutate: Top fraction of schemes used as mutation seeds (default: 0.1). + - minimum_schemes_to_mutate: Min schemes to keep as mutation seeds (default: 10). + - maximum_mutations: Max mutations per scheme during generation (default: 3). + - maximum_generation_attempts: Max attempts to generate a unique scheme (default: 100). + - pattern_cache_minimum_distance: Min edit distance between cached schemes (default: 4). + - pattern_cache_max_entries_per_pattern: Max schemes per pattern in cache (default: 32). """ # Logging