diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9d7500e58..bbbe6ab9e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow. - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. - Add support for image-text data calibration in PTQ for Nemotron VL models. +- Add PTQ support for Nemotron Parse. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 93687a8d0..71755a02f 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -31,6 +31,7 @@ from safetensors.torch import load_file from transformers import ( AutoConfig, + AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, @@ -75,19 +76,18 @@ def run_nemotron_vl_preview( "eos_token_id": tokenizer.eos_token_id, } - # Try text-only generation + # Try text-only generation (may fail for encoder-decoder models like Nemotron-Parse) text_response = run_text_only_generation( full_model, tokenizer, question, generation_config, pyt_ckpt_path ) + generated_ids = None if text_response is not None: print(f"✅ Text-only generation successful: {text_response[:100]}...") generated_ids = text_response elif allow_fallback: print("Text-only generation failed, falling back to standard generate...") generated_ids = full_model.generate(input_ids, max_new_tokens=100) - else: - generated_ids = None # Run additional VL test with images print(f"Running additional VL test with images ({stage_name})...") @@ -106,6 +106,10 @@ def _is_multimodal_config(config): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or getattr(config, "is_encoder_decoder", False) # Encoder-decoder VL models + or any( # Architecture-based detection for custom VL models (e.g., Nemotron-Parse) + "conditionalgeneration" in arch.lower() for arch in getattr(config, "architectures", []) + ) ) @@ -158,9 +162,20 @@ def calibrate_loop(_model): ) allowed_keys = set(forward_params.keys()) + # Check if model is encoder-decoder (needs decoder_input_ids instead of input_ids) + is_enc_dec = getattr(full_model.config, "is_encoder_decoder", False) + full_model.eval() with torch.no_grad(): for batch in calib_dataloader: + # For encoder-decoder models, rename input_ids → decoder_input_ids + # and disable KV caching to avoid tuple index errors in decoder layers + if is_enc_dec and "input_ids" in batch and "pixel_values" in batch: + batch["decoder_input_ids"] = batch.pop("input_ids") + if "attention_mask" in batch: + batch["decoder_attention_mask"] = batch.pop("attention_mask") + batch["use_cache"] = False + # Filter batch to only include parameters the model accepts if accepts_kwargs: call_kwargs = batch @@ -172,10 +187,8 @@ def calibrate_loop(_model): # Use safe_nemotron_vl_forward for Nemotron Nano VL (embedding-injection style) # For other VLMs (like Nemotron-Parse), use standard forward if hasattr(full_model, "img_context_token_id"): - # Nemotron Nano VL style safe_nemotron_vl_forward(full_model, call_kwargs) else: - # Standard encoder-decoder or other VLM architectures full_model(**call_kwargs) return calibrate_loop @@ -312,8 +325,15 @@ def get_processor( ) return MllamaImageProcessor(processor, device) - - return None + else: + # Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse) + try: + processor = AutoProcessor.from_pretrained(ckpt_path, **model_kwargs) + print(f"Loaded AutoProcessor for model type: {model_type}") + return processor + except Exception as e: + print(f"Could not load processor for {model_type}: {e}") + return None def load_mtp_weights( @@ -447,6 +467,7 @@ def get_model( # Load config once and handle VL model detection try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " @@ -466,8 +487,6 @@ def get_model( model_kwargs.setdefault("torch_dtype", "auto") if "vila" in ckpt_path.lower(): - from transformers import AutoModel - hf_vila = AutoModel.from_pretrained( ckpt_path, device_map=device_map, @@ -510,13 +529,17 @@ def get_model( if not hasattr(transformers, architecture): warnings.warn( f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM." + "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)." ) assert trust_remote_code, ( "Please set trust_remote_code to True if you want to use this architecture" ) - auto_model_module = AutoModelForCausalLM + # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models + if getattr(hf_config, "is_encoder_decoder", False): + auto_model_module = AutoModel + else: + auto_model_module = AutoModelForCausalLM from_config = auto_model_module.from_config else: auto_model_module = getattr(transformers, architecture) @@ -527,7 +550,7 @@ def get_model( # unless specified by the hf_config. torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model_kwargs2 = model_kwargs.copy() - if auto_model_module != AutoModelForCausalLM: + if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) model_kwargs2["torch_dtype"] = torch_dtype model_kwargs2.pop("max_memory", None) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d9a6ca893..de434e1cf 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -361,6 +361,12 @@ def load_model(args: argparse.Namespace): default_pad_token = None is_nemotron_vl_model = is_nemotron_vl(full_model) + + # Default to image-text calibration for VLM models + if is_nemotron_vl_model and not args.calib_with_images: + print("Nemotron VL model detected. Enabling image-text calibration by default.") + args.calib_with_images = True + if model_type == "mllama": processor = get_processor( args.pyt_ckpt_path, @@ -499,9 +505,12 @@ def mono_quantize( print("Disabling quantization for vision components in Nemotron VL model") quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - # Also disable radio model components specifically + # Also disable radio model components specifically (for Nemotron-Parse) quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder + quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific + print("Quantization will only be applied to the decoder (text generation) component") if not model_is_already_quantized or calibration_only: if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": @@ -686,7 +695,7 @@ def pre_quantize( preview_input_ids, args.pyt_ckpt_path, "before quantization", - allow_fallback=True, + allow_fallback=False, ) else: # Standard generation for non-Nemotron VL models @@ -800,36 +809,42 @@ def quantize_main( device: torch.device, ): if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = language_model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( - language_model.device - ) - * 100 - ) + # For VL models with image-text calibration, skip automatic batch size detection + # since get_max_batch_size can't handle multimodal inputs + if args.calib_with_images: + print("Image-text calibration enabled. Using default batch_size=1 for calibration.") + args.batch_size = 1 else: - sample_input_single_batch = None + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. + sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None - args.batch_size = get_max_batch_size( - language_model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) print(f"Use calib batch_size {args.batch_size}") diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py index 6c9d921b8..9919e405b 100644 --- a/examples/llm_ptq/vlm_utils.py +++ b/examples/llm_ptq/vlm_utils.py @@ -105,27 +105,31 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): else: processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - messages = [ - {"role": "system", "content": "/no_think"}, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "", - }, - { - "type": "text", - "text": question, - }, - ], - }, - ] - - # Apply chat template - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + # Use chat template if available, otherwise fall back to default task prompt + if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: + messages = [ + {"role": "system", "content": "/no_think"}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "", + }, + { + "type": "text", + "text": question, + }, + ], + }, + ] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + else: + # For models without chat templates (e.g., encoder-decoder VL models), + # use the tokenizer's bos/eos tokens as a minimal prompt + prompt = (tokenizer.bos_token or "") + question # Process inputs using the processor with single image inputs = processor( @@ -139,6 +143,12 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): inputs = inputs.to(model_device) print(f" Moved inputs to {model_device}") + # Verify we have pixel_values for the vision encoder + if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None: + raise ValueError( + "Processor did not generate pixel_values. Check processor configuration." + ) + # Generate response using model.generate generated_ids = model.generate( pixel_values=inputs.pixel_values, @@ -148,12 +158,23 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): ) # Decode the response (trim input tokens like in the working example) + if generated_ids is None: + raise ValueError("Model generate returned None") + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + # Use processor.batch_decode if available, otherwise fall back to tokenizer + decoder = processor if hasattr(processor, "batch_decode") else tokenizer + output_text = decoder.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, ) + + if output_text is None or len(output_text) == 0: + raise ValueError("Decoding returned empty output") + response = output_text[0] print(f"✅ VL generation {stage_name} successful!") diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad..6cb5be9a5 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -85,6 +85,7 @@ def is_multimodal_model(model): - Vision LoRA configurations - Audio processing capabilities - Image embedding layers + - Nemotron-Parse conditional generation models Args: model: The HuggingFace model instance to check @@ -103,6 +104,10 @@ def is_multimodal_model(model): """ config = model.config + # Check for Nemotron-Parse encoder-decoder architecture + architectures = getattr(config, "architectures", []) + is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures) + return ( hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA) @@ -112,6 +117,7 @@ def is_multimodal_model(model): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or is_nemotron_parse # Nemotron-Parse conditional generation model ) @@ -141,5 +147,11 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None: if hasattr(model, "language_model"): return [model, model.language_model] - # Pattern 3: No language_model found + # Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model. + # Only match if the model is detected as multimodal to avoid matching non-VLM encoder-decoder + # models like T5, Bart, Whisper which also have .decoder. + if hasattr(model, "decoder") and is_multimodal_model(model): + return [model, model.decoder] + + # Pattern 4: No language_model found return None diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..b6b92f6ff 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -316,27 +316,27 @@ def llm_dummy_forward(): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part + if is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, run optimization on just the language model/decoder. + # This avoids needing pixel_values for the vision encoder. language_model_lineage = get_language_model_from_vl(model) if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use language_model = language_model_lineage[-1] print( f"Running optimization on language model with fake_input shape: {fake_input.shape}" ) - language_model(fake_input) + # Pass use_cache=False to avoid KV cache issues in encoder-decoder models + language_model(fake_input, use_cache=False) else: raise ValueError( f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " "This is required for requantization/resmoothing optimization. " "Please ensure the model architecture is supported or file an issue." ) + elif getattr(model.config, "is_encoder_decoder", False): + # For other encoder-decoder models (non-VL), pass both encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) else: model(fake_input)