[JAX] Add debug validation mode for runtime group size alignment#2867
jberchtold-nvidia wants to merge 73 commits into NVIDIA:main
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Greptile Summary
This PR adds a debug-only opt-in (
Confidence Score: 3/5
Not safe to merge: the primary feature (raising errors on misalignment) does not propagate through the XLA FFI boundary in unit-test contexts, and the alignment check produces false positives for V1 MXFP8.
Two P1 defects:
(1) NVTE_CHECK throws a C++ exception instead of returning XLA_FFI_Error*, which the PR author acknowledges causes silent failures in isolated tests;
(2) hardcoded 128-byte alignment in grouped_dense incorrectly rejects valid V1 MXFP8 configurations.
Non-validation changes appear sound.
The two defects are in transformer_engine/jax/csrc/extensions/validation.cpp (error return pattern) and transformer_engine/jax/dense.py (recipe-aware alignment selection).
Important Files Changed
Sequence Diagram
sequenceDiagram
    participant User as Python caller
    participant GD as grouped_dense()
    participant VGS as validate_group_sizes (JAX FFI)
    participant VGSC as ValidateGroupSizesFFI (C++)
    participant GDense as _grouped_dense()
    User->>GD: grouped_dense(x, kernel, group_sizes, ...)
    alt NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT=1
        GD->>VGS: validate_group_sizes(group_sizes, align=128)
        VGS->>VGSC: cudaMemcpyAsync D2H + cudaStreamSynchronize
        alt all group_sizes[i] % 128 == 0
            VGSC-->>VGS: return aliased buffer (no-op)
            VGS-->>GD: group_sizes (pass-through)
        else misaligned group detected
            VGSC->>VGSC: NVTE_CHECK throws C++ exception
            Note over VGSC,VGS: Exception may not propagate to Python in all XLA contexts
        end
    end
    GD->>GDense: _grouped_dense(x, kernel, group_sizes, ...)
    GDense-->>User: output tensor
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
    """
    if _should_validate_group_size_alignment():
Hardcoded 128-alignment rejects valid V1 MXFP8 and BF16 inputs
validate_group_sizes(..., 128) is applied unconditionally regardless of quantizer_set, so any call with V1 MXFP8 (where only 32-alignment is needed) and group sizes that are multiples of 32 but not 128 will raise a false-positive error. For BF16 / non-quantized paths there is no alignment requirement at all, so the check would reject all non-128-aligned groups there too. The debug mode effectively becomes unusable for those paths until the TODO is resolved.
Suggested change:
    """
    if _should_validate_group_size_alignment():
        # Determine required alignment from the active recipe.
        # V2 MXFP8 (Blackwell) requires 128; V1 MXFP8 requires 32; BF16 has no requirement.
        _scaling_mode = quantizer_set.x.scaling_mode if quantizer_set is not noop_quantizer_set else None
        if _scaling_mode is not None:
            _align = 128 if is_fp8_gemm_with_all_layouts_supported() else 32
            group_sizes = tex.validate_group_sizes(group_sizes, _align)
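The selection rule in this comment can be sketched as a small pure-Python helper. The name `required_group_alignment`, the string label used for the scaling mode, and the `all_layouts_supported` flag are hypothetical stand-ins: the real code would read `quantizer_set.x.scaling_mode` and call `is_fp8_gemm_with_all_layouts_supported()`. The 128 / 32 / no-requirement values are taken from the comment above, not verified independently.

```python
from typing import Optional


def required_group_alignment(scaling_mode: Optional[str],
                             all_layouts_supported: bool) -> Optional[int]:
    """Return the group-size alignment to validate, or None to skip the check.

    Hypothetical helper: `scaling_mode` is None for the non-quantized (BF16)
    path and any non-None label for an MXFP8 recipe.
    """
    if scaling_mode is None:
        # BF16 / non-quantized: no alignment requirement, so nothing to validate.
        return None
    # Per the review comment: V2 MXFP8 needs 128, V1 MXFP8 needs 32.
    return 128 if all_layouts_supported else 32
```

Returning `None` for the non-quantized path lets the caller skip `validate_group_sizes` entirely instead of passing a dummy alignment of 1.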
    for (int64_t i = 0; i < num_experts; ++i) {
      NVTE_CHECK(group_sizes_host[i] % align_size == 0,
                 "group_sizes alignment check failed: group_sizes[", i, "] = ", group_sizes_host[i],
                 " is not divisible by align_size = ", align_size);
    }
NVTE_CHECK throws through the XLA FFI boundary; error propagation is unreliable
XLA FFI handlers are expected to signal errors by returning an XLA_FFI_Error* value; they must not throw C++ exceptions across the ABI boundary. NVTE_CHECK likely throws a std::runtime_error, which is undefined behaviour when it crosses the FFI boundary. The PR description acknowledges this: "isolated unit tests detect the failure with printf but for some reason do not raise the error properly thru XLA in isolated tests." The correct pattern is to return an explicit error token instead of asserting:
    for (int64_t i = 0; i < num_experts; ++i) {
      if (group_sizes_host[i] % align_size != 0) {
        return XLA_FFI_Error_Create(
            XLA_FFI_Error_Code_INVALID_ARGUMENT,
            ("group_sizes alignment check failed: group_sizes[" + std::to_string(i) +
             "] = " + std::to_string(group_sizes_host[i]) +
             " is not divisible by align_size = " + std::to_string(align_size))
                .c_str());
      }
    }
Until this is fixed the debug mode cannot reliably surface misalignment errors in standard unit-test contexts.
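While the FFI path does not report errors reliably, a host-side reference check could serve as a cross-check in unit tests. The sketch below mirrors the loop above in pure Python. `find_misaligned_group` is a hypothetical name, not part of Transformer Engine, and it assumes the group sizes are already available on the host as plain integers. Like the suggested handler, it reports failure through its return value rather than by raising.

```python
from typing import Optional, Sequence


def find_misaligned_group(group_sizes: Sequence[int], align_size: int) -> Optional[str]:
    """Return an error message for the first misaligned group, or None if all pass."""
    if align_size <= 0:
        raise ValueError("align_size must be positive")
    for i, size in enumerate(group_sizes):
        if size % align_size != 0:
            # Same wording as the C++ message, so test output is comparable.
            return (f"group_sizes alignment check failed: group_sizes[{i}] = {size} "
                    f"is not divisible by align_size = {align_size}")
    return None
```

A test could assert that this function and the FFI-backed check agree on the same inputs, which would expose the cases where the C++ error is silently dropped.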
    @cache
    def _should_validate_group_size_alignment() -> bool:
@cache silences dynamic env-var changes after first call
Because _should_validate_group_size_alignment is memoised, any process that calls grouped_dense once with the env var unset will permanently disable the check for its lifetime, even if NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT is set later (e.g., in a test fixture that modifies os.environ mid-run). A comment should call this out explicitly so developers aren't surprised.
Suggested change:
    @cache
    def _should_validate_group_size_alignment() -> bool:
        # Evaluated once per process at first call. Set NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT=1
        # before the first call to grouped_dense; changing it afterwards has no effect.
        return os.getenv("NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT", "0") == "1"
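The memoisation behaviour described here is easy to reproduce in isolation. The snippet below is a minimal sketch using a stand-in function, `should_validate`, in place of `_should_validate_group_size_alignment`; `demo` is likewise illustrative. It relies only on `functools.cache` (Python 3.9+) and its `cache_clear()` method.

```python
import os
from functools import cache

ENV_VAR = "NVTE_JAX_VALIDATE_GROUP_SIZE_ALIGNMENT"


@cache
def should_validate() -> bool:
    # Stand-in for _should_validate_group_size_alignment(): read once, then memoised.
    return os.getenv(ENV_VAR, "0") == "1"


def demo() -> tuple:
    os.environ.pop(ENV_VAR, None)
    should_validate.cache_clear()
    first = should_validate()        # env var unset, result is cached
    os.environ[ENV_VAR] = "1"
    stale = should_validate()        # still the cached value, env change ignored
    should_validate.cache_clear()    # what a test fixture would need to do
    fresh = should_validate()        # re-reads the environment
    # Leave the process environment and cache as we found them.
    os.environ.pop(ENV_VAR, None)
    should_validate.cache_clear()
    return first, stale, fresh
```

A test fixture that toggles the variable mid-run would therefore need to call `cache_clear()` on the cached function, or set the variable before the first call to `grouped_dense`.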
Description
Works in MaxText E2E and catches incorrect group size alignment. However, isolated unit tests detect the failure with printf but for some reason do not raise the error properly through XLA the way MaxText does. Still investigating.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: