
Optimizations for MXFP8/NVFP4 dequantize kernels #2865

Draft
YigongQin wants to merge 4 commits into NVIDIA:main from YigongQin:yigongq/bwd-dequantize-optim

Conversation


@YigongQin commented Apr 10, 2026

Description

  • Handle empty tensors in dequantize for CUDA graph compatibility
  • Add swizzled scale support to the NVFP4 dequantize kernel, reusing the existing MXFP8 swizzle index computation
  • Add C++ unit tests for both NVFP4 and MXFP8 dequantization (including swizzled scale variants)
  • Fix to_cpu() and set_scale() in test infrastructure to correctly sync amax/scale for NVTE_NVFP4_1D_SCALING mode

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Handle empty tensors in dequantize for CUDA graph compatibility — Early return when input has zero elements, avoiding kernel launches on empty tensors.
  • Add GEMM-swizzled scale support to NVFP4 dequantize kernel — Template the kernel with WITH_GEMM_SWIZZLED_SCALES to support reading scales from swizzled layout, reusing the MXFP8 swizzle index computation.
  • Add GEMM-swizzled scale support to MXFP8 dequantize kernel — Extend the MXFP8 dequantize kernel to handle swizzled scale inputs.
  • Add C++ unit tests for NVFP4 dequantization — 21 tests for compact scales + 21 tests for swizzled scales, covering multiple sizes and output dtypes (fp32, bf16, fp16).
  • Add C++ unit tests for MXFP8 dequantization with swizzled scales — New swizzled test suite for MXFP8.
  • Fix to_cpu() to sync amax/scale for NVFP4 tensors — Previously only synced for NVTE_DELAYED_TENSOR_SCALING, causing the CPU reference to use stale amax=0.
  • Fix set_scale() to work for NVFP4 tensors — Same condition fix, enabling the scale to be properly uploaded to GPU before quantization.
  • Fix swizzled test ordering — Move from_cpu() before the FP4 data copy to prevent from_cpu() from overwriting the copied data with zeros.
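The empty-tensor handling in the first bullet can be illustrated with a minimal sketch. The names here (`Tensor`, `numel`, `dequantize`) are hypothetical placeholders, not Transformer Engine's actual API; the point is only the guard itself, which avoids launching a kernel with a zero-sized grid (something that can trip up CUDA graph capture):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a 2D tensor descriptor.
struct Tensor {
  size_t rows, cols;
  size_t numel() const { return rows * cols; }
};

// Returns true if a dequantize kernel would be launched, false if the
// empty-tensor early return was taken.
bool dequantize(const Tensor &input) {
  // Early return: an empty input means there is nothing to dequantize,
  // so skip the kernel launch entirely rather than launching with a
  // degenerate grid.
  if (input.numel() == 0) return false;
  // ... launch the dequantize kernel here ...
  return true;
}
```

In the real kernel the guard would sit before the launch configuration is computed, so no GPU work is issued at all for empty inputs.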
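For the swizzled-scale bullets, the index computation being shared between the MXFP8 and NVFP4 paths can be sketched on the host. This assumes the cuBLAS block-scaling layout, where scale factors are grouped into 128-row by 4-column tiles of 512 entries with an inner [32][4][4] interleave; this is a sketch of that layout, not the PR's exact code, and `swizzled_scale_index` is an illustrative name:

```cpp
#include <cassert>
#include <cstddef>

// Map a (row, col) position in the compact (row-major) scale matrix to its
// offset in the GEMM-swizzled layout. padded_cols is the column count of the
// scale matrix padded up to a multiple of 4.
inline size_t swizzled_scale_index(size_t row, size_t col, size_t padded_cols) {
  const size_t num_col_tiles = padded_cols / 4;
  // Which 128x4 tile the scale lives in (tiles are laid out row-major).
  const size_t tile = (row / 128) * num_col_tiles + (col / 4);
  // Position inside the tile: inner [32][4][4] interleave.
  const size_t in_tile = (row % 32) * 16 + ((row % 128) / 32) * 4 + (col % 4);
  return tile * 512 + in_tile;
}
```

Templating the kernel on a flag such as `WITH_GEMM_SWIZZLED_SCALES` then amounts to selecting between this index function and a plain row-major index at compile time, so the non-swizzled path pays no extra cost.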

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from f5e7375 to 39c0fb1 on April 10, 2026 at 22:04