Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions modelopt/torch/export/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""Code that export quantized Hugging Face models for deployment."""

import json
import warnings
from collections.abc import Callable
from contextlib import contextmanager
Expand All @@ -23,6 +24,7 @@

import torch
import torch.nn as nn
from safetensors.torch import load_file, safe_open

from .layer_utils import is_quantlinear

Expand Down Expand Up @@ -656,3 +658,146 @@ def infer_dtype_from_model(model: nn.Module) -> torch.dtype:
for param in model.parameters():
return param.dtype
return torch.float16


def _merge_ltx2(
    diffusion_transformer_state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge LTX-2 transformer weights with non-transformer components.

    Non-transformer components (VAE, vocoder, text encoders) and embeddings
    connectors are taken from the base checkpoint. Transformer keys are
    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.

    Args:
        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model safetensors file containing
            all components (transformer, VAE, vocoder, etc.).

    Returns:
        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
        safetensors metadata from the base checkpoint.
    """
    base_state = load_file(merged_base_safetensor_path)

    # Prefixes of keys in the base checkpoint that are NOT part of the
    # diffusion transformer; these are copied through unchanged.
    non_transformer_prefixes = [
        "vae.",
        "audio_vae.",
        "vocoder.",
        "text_embedding_projection.",
        "text_encoders.",
        "first_stage_model.",
        "cond_stage_model.",
        "conditioner.",
    ]
    # ComfyUI expects transformer weights under this prefix.
    correct_prefix = "model.diffusion_model."
    # Known wrapper prefixes that various export paths may leave on
    # transformer keys; at most one is stripped before re-prefixing.
    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]

    base_non_transformer = {
        k: v
        for k, v in base_state.items()
        if any(k.startswith(p) for p in non_transformer_prefixes)
    }
    # Embeddings connectors live inside the transformer namespace in the base
    # checkpoint but are not produced by the export, so keep the base copies.
    base_connectors = {
        k: v
        for k, v in base_state.items()
        if "embeddings_connector" in k and k.startswith(correct_prefix)
    }

    prefixed = {}
    for k, v in diffusion_transformer_state_dict.items():
        clean_k = k
        for prefix in strip_prefixes:
            if clean_k.startswith(prefix):
                clean_k = clean_k[len(prefix) :]
                break
        prefixed[f"{correct_prefix}{clean_k}"] = v

    # Exported transformer weights take precedence over any overlapping keys.
    merged = dict(base_non_transformer)
    merged.update(base_connectors)
    merged.update(prefixed)
    # Re-open the file header only to recover the original safetensors metadata.
    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
        base_metadata = f.metadata() or {}

    del base_state
    return merged, base_metadata


# Registry of model-type keys -> checkpoint merge functions. Keys are the
# strings returned by ``get_diffusion_model_type``; each function takes a
# transformer state dict and a base safetensors path and returns
# (merged_state_dict, base_metadata). Add new model types here.
DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable] = {
    "ltx2": _merge_ltx2,
}


def merge_diffusion_checkpoint(
    state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
    model_type: str,
    hf_quant_config: dict | None = None,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge transformer weights with a base checkpoint and build ComfyUI metadata.

    Dispatches to the model-specific merge function in ``DIFFUSION_MERGE_FUNCTIONS``
    and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` and
    per-layer ``_quantization_metadata`` in the safetensors metadata for ComfyUI.

    Args:
        state_dict: The transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model ``.safetensors`` file
            containing all components (transformer, VAE, vocoder, etc.),
            e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
        hf_quant_config: If provided, embed quantization config and per-layer
            ``_quantization_metadata`` in the returned metadata dict.

    Returns:
        Tuple of (merged_state_dict, metadata) where *metadata* is the base checkpoint's
        original metadata augmented with any quantization entries.
    """
    merged_state_dict, metadata = DIFFUSION_MERGE_FUNCTIONS[model_type](
        state_dict, merged_base_safetensor_path
    )

    if hf_quant_config is None:
        return merged_state_dict, metadata

    metadata["quantization_config"] = json.dumps(hf_quant_config)

    algo = hf_quant_config.get("quant_algo", "unknown").lower()
    layers: dict[str, dict[str, str]] = {}
    for key in merged_state_dict:
        if not key.endswith((".weight_scale", ".weight_scale_2")):
            continue
        # Derive the layer name by dropping the scale suffix (and a trailing
        # ".weight" component for names like "<layer>.weight.weight_scale").
        name = key.rsplit(".", 1)[0]
        if name.endswith(".weight"):
            name = name.rsplit(".", 1)[0]
        layers.setdefault(name, {"format": algo})

    metadata["_quantization_metadata"] = json.dumps(
        {
            "format_version": "1.0",
            "layers": layers,
        }
    )
    return merged_state_dict, metadata


def get_diffusion_model_type(pipe: Any) -> str:
    """Detect the diffusion model type for merge function dispatch.

    To add a new model type, add a detection clause here and a corresponding
    merge function in ``DIFFUSION_MERGE_FUNCTIONS``.

    Args:
        pipe: The pipeline or component being exported.

    Returns:
        A string key into ``DIFFUSION_MERGE_FUNCTIONS``.

    Raises:
        ValueError: If the model type is not supported.
    """
    # The pipeline class is optional (import may have failed), so guard the
    # isinstance check against it being None.
    is_ltx2 = TI2VidTwoStagesPipeline is not None and isinstance(
        pipe, TI2VidTwoStagesPipeline
    )
    if is_ltx2:
        return "ltx2"

    raise ValueError(
        f"No merge function for model type '{type(pipe).__name__}'. "
        "Add an entry to DIFFUSION_MERGE_FUNCTIONS in diffusers_utils.py."
    )
94 changes: 77 additions & 17 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
from .diffusers_utils import (
generate_diffusion_dummy_forward_fn,
get_diffusion_components,
get_diffusion_model_type,
get_qkv_group_key,
hide_quantizers_from_state_dict,
infer_dtype_from_model,
is_diffusers_object,
is_qkv_projection,
merge_diffusion_checkpoint,
)

HAS_DIFFUSERS = True
Expand Down Expand Up @@ -116,20 +118,49 @@ def _is_enabled_quantizer(quantizer):


def _save_component_state_dict_safetensors(
    component: nn.Module,
    component_export_dir: Path,
    merged_base_safetensor_path: str | None = None,
    hf_quant_config: dict | None = None,
    model_type: str | None = None,
) -> None:
    """Save component state dict as safetensors with optional base checkpoint merge.

    Args:
        component: The nn.Module to save.
        component_export_dir: Directory to save model.safetensors and config.json.
        merged_base_safetensor_path: If provided, merge the exported transformer weights
            with non-transformer components (VAE, vocoder, text encoders, etc.) from this
            base safetensors file and add quantization metadata to produce a single-file
            checkpoint compatible with ComfyUI. This should be the path to a full base
            model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
        hf_quant_config: If provided, embed quantization config in safetensors metadata
            and per-layer _quantization_metadata for ComfyUI.
        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
            Required when ``merged_base_safetensor_path`` is not None.
    """
    # Detach and move everything to CPU so save_file never touches device memory.
    cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}

    metadata: dict[str, str] = {}
    metadata_full: dict[str, str] = {}

    if merged_base_safetensor_path is not None and model_type is not None:
        cpu_state_dict, metadata_full = merge_diffusion_checkpoint(
            cpu_state_dict, merged_base_safetensor_path, model_type, hf_quant_config
        )

    metadata["_export_format"] = "safetensors_state_dict"
    metadata["_class_name"] = type(component).__name__
    # Export-format entries take precedence over any colliding base-checkpoint keys.
    metadata_full.update(metadata)

    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        # Always attach metadata so _export_format/_class_name are preserved
        # for non-merge exports too (previously dropped via metadata=None).
        metadata=metadata_full,
    )

    with open(component_export_dir / "config.json", "w") as f:
        json.dump(metadata, f, indent=4)


def _collect_shared_input_modules(
model: nn.Module,
Expand Down Expand Up @@ -822,6 +853,7 @@ def _export_diffusers_checkpoint(
dtype: torch.dtype | None,
export_dir: Path,
components: list[str] | None,
merged_base_safetensor_path: str | None = None,
max_shard_size: int | str = "10GB",
) -> None:
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
Expand All @@ -836,6 +868,11 @@ def _export_diffusers_checkpoint(
export_dir: The directory to save the exported checkpoint.
components: Optional list of component names to export. Only used for pipelines.
If None, all components are exported.
merged_base_safetensor_path: If provided, merge the exported transformer weights
with non-transformer components (VAE, vocoder, text encoders, etc.) from this
base safetensors file and add quantization metadata to produce a single-file
checkpoint compatible with ComfyUI. This should be the path to a full base
model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
it will be sharded into multiple files and a .safetensors.index.json will be
created. Use smaller values like "5GB" or "2GB" to force sharding.
Expand All @@ -849,6 +886,9 @@ def _export_diffusers_checkpoint(
warnings.warn("No exportable components found in the model.")
return

# Resolve model type once (only needed when merging with a base checkpoint)
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None

Comment on lines +889 to +891
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

get_diffusion_model_type is called unconditionally when merged_base_safetensor_path is truthy — will raise ValueError for non-LTX-2 diffusers pipelines.

If a user passes merged_base_safetensor_path for a standard diffusers pipeline (e.g., StableDiffusion), get_diffusion_model_type(pipe) will raise ValueError with a somewhat opaque message. Consider validating earlier or documenting this limitation more prominently in export_hf_checkpoint's docstring.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 885 - 887, The code
unconditionally calls get_diffusion_model_type(pipe) when
merged_base_safetensor_path is set, which will raise a ValueError for non-LTX-2
diffusers; update export_hf_checkpoint to first check the pipeline type (or a
predicate like is_ltx2_pipeline(pipe)) before calling get_diffusion_model_type,
and if merged_base_safetensor_path is provided for an unsupported pipeline
either raise a clearer, descriptive error mentioning export_hf_checkpoint and
merged_base_safetensor_path or document this constraint in the function
docstring so users aren’t met with an opaque ValueError from
get_diffusion_model_type.

# Separate nn.Module components for quantization-aware export
module_components = {
name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module)
Expand Down Expand Up @@ -894,6 +934,7 @@ def _export_diffusers_checkpoint(

# Step 5: Build quantization config
quant_config = get_quant_config(component, is_modelopt_qlora=False)
hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None

# Step 6: Save the component
# - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
Expand All @@ -903,12 +944,15 @@ def _export_diffusers_checkpoint(
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
with hide_quantizers_from_state_dict(component):
_save_component_state_dict_safetensors(component, component_export_dir)

_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
hf_quant_config,
model_type,
)
# Step 7: Update config.json with quantization info
if quant_config is not None:
hf_quant_config = convert_hf_quant_config_format(quant_config)

if hf_quant_config is not None:
config_path = component_export_dir / "config.json"
if config_path.exists():
with open(config_path) as file:
Expand All @@ -920,7 +964,12 @@ def _export_diffusers_checkpoint(
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(component, component_export_dir)
_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
model_type=model_type,
)
Comment on lines 964 to +972
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Non-quantized components also receive merged_base_safetensor_path — unintentional merge?

When a non-quantized component falls through to _save_component_state_dict_safetensors (lines 963-968), it receives merged_base_safetensor_path and model_type. This means the merge function will run on the non-quantized component's state dict too, adding all non-transformer base weights (VAE, vocoder, etc.) into it.

In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components.

Consider guarding by only passing merged_base_safetensor_path for quantized components, or adding a comment clarifying the assumption:

Proposed safeguard
         else:
             _save_component_state_dict_safetensors(
                 component,
                 component_export_dir,
-                merged_base_safetensor_path,
-                model_type=model_type,
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(component, component_export_dir)
_save_component_state_dict_safetensors(
component,
component_export_dir,
merged_base_safetensor_path,
model_type=model_type,
)
elif hasattr(component, "save_pretrained"):
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
else:
_save_component_state_dict_safetensors(
component,
component_export_dir,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 960 - 968, The
current path calls _save_component_state_dict_safetensors(component,
component_export_dir, merged_base_safetensor_path, model_type=...) for any
component that doesn't implement save_pretrained, which unintentionally applies
the merged_base_safetensor_path merge to non-quantized components; update the
logic so merged_base_safetensor_path is only passed when the component is
quantized (e.g., detect quantization via a flag or type check before calling
_save_component_state_dict_safetensors) or call
_save_component_state_dict_safetensors without merged_base_safetensor_path for
non-quantized components, ensuring references to
_save_component_state_dict_safetensors, merged_base_safetensor_path,
component.save_pretrained and model_type are used to locate and change the code.


print(f" Saved to: {component_export_dir}")

Expand Down Expand Up @@ -1044,6 +1093,7 @@ def export_hf_checkpoint(
save_modelopt_state: bool = False,
components: list[str] | None = None,
extra_state_dict: dict[str, torch.Tensor] | None = None,
**kwargs,
):
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).

Expand All @@ -1061,15 +1111,25 @@ def export_hf_checkpoint(
components: Only used for diffusers pipelines. Optional list of component names
to export. If None, all quantized components are exported.
extra_state_dict: Extra state dictionary to add to the exported model.
**kwargs: Internal-only keyword arguments. Supported key: merged_base_safetensor_path
(str, optional). When provided, merges the exported diffusion transformer
weights with non-transformer components (VAE, vocoder, text encoders, etc.)
from this base safetensors file to produce a single-file checkpoint
compatible with ComfyUI. Value should be the path to a full base model
``.safetensors`` file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
Only used for diffusion model exports.
"""
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)

is_diffusers_obj = False
if HAS_DIFFUSERS:
is_diffusers_obj = is_diffusers_object(model)
if is_diffusers_obj:
_export_diffusers_checkpoint(model, dtype, export_dir, components)
_export_diffusers_checkpoint(
model, dtype, export_dir, components, merged_base_safetensor_path
)
return

# Transformers model export
Expand Down