
[JAX] Add debug validation mode for runtime group size alignment#2867

Draft
jberchtold-nvidia wants to merge 73 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/group-size-align-validation

Conversation

@jberchtold-nvidia
Collaborator

Description

Works in MaxText E2E and catches incorrect group size alignment. However, in isolated unit tests the check detects the failure (visible via printf) but, for a reason still under investigation, the error is not raised properly through XLA the way it is in MaxText.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 30 commits March 9, 2026 15:42
jberchtold-nvidia and others added 26 commits April 6, 2026 13:53
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft April 11, 2026 16:08
@greptile-apps
Contributor

greptile-apps bot commented Apr 11, 2026

Greptile Summary

This PR adds a debug-only opt-in (NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT=1) that copies group_sizes to host and asserts divisibility before each grouped GEMM, along with MXFP8 scale un-swizzle logic in the dequantizer and V1/V2 MXFP8 test parametrization. Two blocking issues remain before this is ready to merge:

  • The C++ FFI handler signals errors via NVTE_CHECK (which throws a C++ exception) rather than returning an XLA_FFI_Error*; the PR description explicitly notes that the check fires with printf in isolated tests but the error is not surfaced to Python — the core feature does not work reliably.
  • The alignment threshold is hardcoded to 128 for all recipes; V1 MXFP8 requires only 32-alignment, so the validator produces false-positive failures for valid V1 inputs when the env var is set.

Confidence Score: 3/5

Not safe to merge: the primary feature (raising errors on misalignment) does not propagate through the XLA FFI boundary in unit-test contexts, and the alignment check produces false positives for V1 MXFP8.

Two P1 defects: (1) NVTE_CHECK throws a C++ exception instead of returning XLA_FFI_Error*, which the PR author acknowledges causes silent failures in isolated tests; (2) hardcoded 128-byte alignment in grouped_dense incorrectly rejects valid V1 MXFP8 configurations. Non-validation changes appear sound.

transformer_engine/jax/csrc/extensions/validation.cpp (error return pattern) and transformer_engine/jax/dense.py (recipe-aware alignment selection)

Important Files Changed

Filename Overview
transformer_engine/jax/dense.py Wires in the new validation primitive; hardcoded 128-alignment causes false positives for V1 MXFP8 (needs 32) and BF16 (no requirement).
transformer_engine/jax/csrc/extensions/validation.cpp New FFI handler; uses NVTE_CHECK (throws) instead of returning XLA_FFI_Error*, causing unreliable error propagation acknowledged by the PR author.
transformer_engine/jax/cpp_extensions/validation.py New JAX FFI primitive wrapping the C++ validate handler; correct pass-through aliasing and abstract shape rules.
transformer_engine/jax/quantize/dequantizer.py Adds MXFP8 scale un-swizzle for grouped dequantization, removes misplaced module-level @staticmethod, fixes int() cast for group_sizes indexing.
transformer_engine/jax/flax/module.py Allows MXFP8BlockScaling for grouped GEMM and threads quantization_checkpoint_name through wrap_function_in_te_state_module.
tests/jax/test_custom_call_compute.py Adds group_size_multiplier parametrization to exercise V1/V2 MXFP8 paths, with correct pytest.skip guard for V2+non-128-aligned combinations.
transformer_engine/jax/cpp_extensions/gemm.py Refactors grouped GEMM helper functions for clarity; no functional regressions apparent.

Sequence Diagram

sequenceDiagram
    participant User as Python caller
    participant GD as grouped_dense()
    participant VGS as validate_group_sizes (JAX FFI)
    participant VGSC as ValidateGroupSizesFFI (C++)
    participant GDense as _grouped_dense()

    User->>GD: grouped_dense(x, kernel, group_sizes, ...)
    alt NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT=1
        GD->>VGS: validate_group_sizes(group_sizes, align=128)
        VGS->>VGSC: cudaMemcpyAsync D2H + cudaStreamSynchronize
        alt all group_sizes[i] % 128 == 0
            VGSC-->>VGS: return aliased buffer (no-op)
            VGS-->>GD: group_sizes (pass-through)
        else misaligned group detected
            VGSC->>VGSC: NVTE_CHECK throws C++ exception
            Note over VGSC,VGS: Exception may not propagate to Python in all XLA contexts
        end
    end
    GD->>GDense: _grouped_dense(x, kernel, group_sizes, ...)
    GDense-->>User: output tensor

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines 357 to +358
"""
if _should_validate_group_size_alignment():
Contributor

P1 Hardcoded 128-alignment rejects valid V1 MXFP8 and BF16 inputs

validate_group_sizes(..., 128) is applied unconditionally regardless of quantizer_set, so any call with V1 MXFP8 (where only 32-alignment is needed) and group sizes that are multiples of 32 but not 128 will raise a false-positive error. For BF16 / non-quantized paths there is no alignment requirement at all, so the check would reject all non-128-aligned groups there too. The debug mode effectively becomes unusable for those paths until the TODO is resolved.

Suggested change

    if _should_validate_group_size_alignment():
        # Determine required alignment from the active recipe.
        # V2 MXFP8 (Blackwell) requires 128; V1 MXFP8 requires 32; BF16 has no requirement.
        _scaling_mode = quantizer_set.x.scaling_mode if quantizer_set is not noop_quantizer_set else None
        if _scaling_mode is not None:
            _align = 128 if is_fp8_gemm_with_all_layouts_supported() else 32
            group_sizes = tex.validate_group_sizes(group_sizes, _align)

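The recipe-aware check the suggestion describes can be sketched host-side with NumPy. This is an illustrative stand-in, not the TE API: `required_alignment`, `validate_group_sizes_host`, and the scaling-mode strings are all hypothetical names; the real primitive runs as an XLA custom call.

```python
# Hypothetical sketch of a recipe-aware host-side alignment check.
# All names here are illustrative, not Transformer Engine API.
import numpy as np

def required_alignment(scaling_mode):
    # Assumed mapping: V2 MXFP8 needs 128, V1 MXFP8 needs 32,
    # BF16 / non-quantized paths have no requirement (alignment 1).
    return {"mxfp8_v2": 128, "mxfp8_v1": 32}.get(scaling_mode, 1)

def validate_group_sizes_host(group_sizes, scaling_mode=None):
    align = required_alignment(scaling_mode)
    sizes = np.asarray(group_sizes)
    bad = np.nonzero(sizes % align != 0)[0]
    if bad.size:
        i = int(bad[0])
        raise ValueError(
            f"group_sizes[{i}] = {int(sizes[i])} is not divisible by {align}"
        )
    return group_sizes  # pass-through, mirroring the aliased FFI buffer
```

With a table like this, group sizes that are multiples of 32 but not 128 pass under V1 MXFP8 and fail only under V2, avoiding the false positives described above.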
Comment on lines +32 to +36
for (int64_t i = 0; i < num_experts; ++i) {
NVTE_CHECK(group_sizes_host[i] % align_size == 0,
"group_sizes alignment check failed: group_sizes[", i, "] = ", group_sizes_host[i],
" is not divisible by align_size = ", align_size);
}
Contributor

P1 NVTE_CHECK throws through the XLA FFI boundary; error propagation is unreliable

XLA FFI handlers are expected to signal errors by returning an XLA_FFI_Error* value; they must not throw C++ exceptions across the ABI boundary. NVTE_CHECK likely throws a std::runtime_error, which is undefined behaviour when it crosses the FFI boundary. The PR description acknowledges this: in isolated unit tests the check fires (visible via printf) but the error is not raised properly through XLA. The correct pattern is to return an explicit error token instead of asserting:

  for (int64_t i = 0; i < num_experts; ++i) {
    if (group_sizes_host[i] % align_size != 0) {
      return XLA_FFI_Error_Create(
          XLA_FFI_Error_Code_INVALID_ARGUMENT,
          ("group_sizes alignment check failed: group_sizes[" + std::to_string(i) +
           "] = " + std::to_string(group_sizes_host[i]) +
           " is not divisible by align_size = " + std::to_string(align_size))
              .c_str());
    }
  }

Until this is fixed the debug mode cannot reliably surface misalignment errors in standard unit-test contexts.

Comment on lines 318 to +320

@cache
def _should_validate_group_size_alignment() -> bool:
Contributor

P2 @cache silences dynamic env-var changes after first call

Because _should_validate_group_size_alignment is memoised, any process that calls grouped_dense once with the env var unset will permanently disable the check for its lifetime, even if NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT is set later (e.g., in a test fixture that modifies os.environ mid-run). A comment should call this out explicitly so developers aren't surprised.

Suggested change

    @cache
    def _should_validate_group_size_alignment() -> bool:
        # Evaluated once per process at first call. Set NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT=1
        # before the first call to grouped_dense; changing it afterwards has no effect.
        return os.getenv("NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT", "0") == "1"
