
[NVBug 5702186] Fix awq model export for Gemma3 #793

Merged
meenchen merged 1 commit into main from weimingc/fix_int4awq_gemma on Jan 18, 2026

Conversation

meenchen (Contributor) commented Jan 16, 2026

What does this PR do?

Type of change: Bug fix

Overview: For norm layers in Gemma that use (1 + weight) in the forward pass, we fold pre_quant_scale into the effective weight. That is, we find the folded weight w' such that `1 + w' = (1 + w) * s`, which gives `w' = (1 + w) * s - 1`.
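
A quick numeric check of the folding identity (a sketch with hypothetical tensors, not the actual export code):

```python
import torch

# Hypothetical per-channel tensors: `w` stands in for the stored norm
# weight and `s` for the AWQ pre_quant_scale.
w = torch.randn(8)
s = torch.rand(8) + 0.5

# Gemma-style norms apply (1 + w), so folding s into the norm means
# finding w' with 1 + w' = (1 + w) * s:
w_folded = (1 + w) * s - 1

# The folded norm (1 + w') reproduces the original norm scaled by s.
assert torch.allclose(1 + w_folded, (1 + w) * s)
```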

Usage

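A hedged sketch of the intended flow, assuming the standard ModelOpt entry points (`mtq.quantize` with `INT4_AWQ_CFG`, then `export_hf_checkpoint`, which is where this PR's folding takes effect); the one-sample calibration loop is for illustration only:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(model):
    # Calibration pass; a real run would iterate over a calibration set.
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    model(**inputs)

# INT4 AWQ quantization; the pre_quant_scale folding fixed by this PR
# happens at export time.
model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop)
export_hf_checkpoint(model, export_dir="gemma3-1b-it-int4-awq")
```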

Testing

./scripts/huggingface_example.sh --model google/gemma-3-1b-it --quant int4_awq

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • Improvements
    • Enhanced quantization utilities to better handle various LayerNorm variants and normalization patterns, including support for weight-offset variants and zero-centered gamma configurations.
    • Optimized pre-quantization layer normalization fusion to apply conditional weight scaling strategies based on normalization type.


Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
@meenchen meenchen requested a review from a team as a code owner January 16, 2026 22:51
@meenchen meenchen requested a review from Edwardf0t1 January 16, 2026 22:51
coderabbitai bot (Contributor) commented Jan 16, 2026

📝 Walkthrough

A single file in the quantization export module is refactored to introduce conditional weight-folding logic for LayerNorm fusion. A new helper function detects LayerNorm variants using weight-plus-one patterns, and the pre-quantization scale folding is now conditionally applied either as direct weight multiplication or via the detected weight-plus-one mechanism.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| LayerNorm quantization fusion logic: `modelopt/torch/export/quant_utils.py` | Added `_layernorm_uses_weight_plus_one()` helper to detect LayerNorm variants (e.g., LayerNorm1P, Gemma RMSNorm) with zero-centered gamma. Refactored `fuse_prequant_layernorm()` to conditionally fold `pre_quant_scale`: the (weight + 1) pattern for detected variants, simple multiplication otherwise. Bias folding updated to use `pre_quant_scale` directly. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title clearly identifies the specific bug being fixed (awq model export for Gemma3) and references the NVBug ticket, directly matching the core change in the PR.


@meenchen meenchen requested a review from cjluo-nv January 16, 2026 22:51
codecov bot commented Jan 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.22%. Comparing base (db76b1e) to head (707140c).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #793      +/-   ##
==========================================
- Coverage   74.23%   74.22%   -0.01%     
==========================================
  Files         192      192              
  Lines       19033    19035       +2     
==========================================
  Hits        14129    14129              
- Misses       4904     4906       +2     


Edwardf0t1 (Contributor) left a comment


LGTM

def _layernorm_uses_weight_plus_one(module: torch.nn.Module) -> bool:
    if any(
        name in type(module).__name__
        for name in ["LayerNorm1P", "GemmaRMSNorm", "Gemma2RMSNorm", "Gemma3RMSNorm"]

Does LayerNorm1P appear in Gemma3 only?

@meenchen meenchen self-assigned this Jan 16, 2026
@meenchen meenchen merged commit 38fb120 into main Jan 18, 2026
45 of 49 checks passed
@meenchen meenchen deleted the weimingc/fix_int4awq_gemma branch January 18, 2026 04:48
kevalmorabia97 pushed a commit that referenced this pull request Jan 19, 2026
danielkorzekwa pushed a commit that referenced this pull request Feb 17, 2026