
[JAX] MXFP8 Grouped Quant+GEMM #2763

Open

jberchtold-nvidia wants to merge 73 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-mxfp8

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Mar 14, 2026

Description

TE/JAX integration of the V2 MXFP8 grouped quantization kernel and the V2 MXFP8 grouped GEMM, both of which are CUDA-graph-safe.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new primitive and FFI for V2 grouped quantize that currently only supports MXFP8
  • Extend V2 grouped GEMM to support MXFP8
  • For both V1 and V2, move swizzling from grouped GEMM FFI to grouped quantize FFI. This is required because currently V2 can only do swizzling when fused with quantization; an independent swizzle kernel that supports ragged groups is not available.
    • This entails updating the tests and dequantization logic for Q->DQ tests to support preswizzled scales.
  • Add small kernels to TE common to handle int32 -> int64 conversion and offset calculations, working around JAX's int32 dtype limitation (see the sketch after this list)
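As a rough illustration of what these offset helpers compute, here is a host-side sketch (the name grouped_tensor_offsets is hypothetical; the real work is done by device-side CUDA kernels so the offsets can be derived from group_sizes without a host sync):

```python
import numpy as np

def grouped_tensor_offsets(group_first_dims: np.ndarray, last_dim: int) -> np.ndarray:
    """Exclusive prefix sum of per-group element counts, widened to int64.

    Mirrors what the new TE common helper kernels do: JAX hands the FFI int32
    group sizes, but element offsets into large grouped tensors need int64.
    """
    first_dims = group_first_dims.astype(np.int64)        # int32 -> int64 widening
    elems_per_group = first_dims * np.int64(last_dim)     # rows per group * row length
    # offsets[0] = 0, offsets[i] = sum(elems_per_group[:i]); n_groups + 1 values.
    return np.concatenate([[np.int64(0)], np.cumsum(elems_per_group)])

# Example: three groups of 128, 256, and 64 rows with a 512-wide last dimension.
print(grouped_tensor_offsets(np.array([128, 256, 64], dtype=np.int32), 512))
# [     0  65536 196608 229376]
```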

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 24 commits March 9, 2026 15:42
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 14, 2026 17:25
@greptile-apps
Contributor

greptile-apps bot commented Mar 14, 2026

Greptile Summary

This PR adds a V2 MXFP8 grouped quantize FFI (nvte_group_quantize) and extends the V2 grouped GEMM to support MXFP8. The key architectural change is moving MXFP8 scale swizzling from the GEMM FFI into the quantize FFI (both V1 and V2), making it transparent to the GEMM. Both the V1 and V2 paths now produce pre-swizzled scale_inv tensors, and the dequantizer was updated accordingly to apply the inverse swizzle. The previously reported NameError and AttributeError bugs have been fixed in this version.

Confidence Score: 5/5

Safe to merge; all remaining findings are style/API-design suggestions that do not affect correctness of the MXFP8 GEMM paths.

Previously reported P1 issues (NameError on lhs_first_dims/lhs_last_dims, AttributeError on None.size) are fixed. No new P0/P1 findings. Three P2 issues remain: a stale comment, a tuple-returning public API that looks boolean, and a missing guard in a helper when shapes are None — none of which affect current call sites.

Files flagged for attention: transformer_engine/jax/cpp_extensions/gemm.py (is_v2_grouped_gemm_supported API signature) and transformer_engine/jax/quantize/dequantizer.py (stale pre_swizzled comment).

Important Files Changed

Filename | Overview
transformer_engine/jax/cpp_extensions/gemm.py | Large refactor: factored out helper functions and added full MXFP8 shape checks to _is_v2_grouped_gemm_supported; is_v2_grouped_gemm_supported now returns tuple[bool, str] (newly public).
transformer_engine/jax/cpp_extensions/quantization.py | Adds GroupedQuantizePrimitive V2 path with _use_v2_kernel selector; V1 now also emits pre-swizzled MXFP8 scales via set_with_gemm_swizzled_scales(true).
transformer_engine/jax/csrc/extensions/quantization.cpp | Adds GroupedQuantizeV2FFI and handler with FFI_CudaGraph_Traits; int64 workspace for device-side prefix-sum offsets.
transformer_engine/jax/csrc/extensions/gemm.cpp | V2 GEMM now supports MXFP8: new make_grouped_tensor overload wires pre-swizzled scale_inv; swizzle pass removed from V1 GEMM loop.
transformer_engine/jax/quantize/dequantizer.py | Adds _unswizzle_mxfp8_grouped_scale and uses it unconditionally for MXFP8 in _grouped_dequantize; contains one stale comment about pre_swizzled ownership.
transformer_engine/jax/quantize/tensor.py | Adds pre_swizzled field to GroupedScaledTensor1x and threads it through ScaledTensorFactory.
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | Adds three small helper kernels: int32→int64 with multiplier, exclusive prefix sum for tensor offsets, and their nvte_ wrappers.
transformer_engine/jax/flax/module.py | make_grouped_dense_cls now supports MXFP8BlockScaling recipe; adds quantization_checkpoint_name parameter.
tests/jax/test_custom_call_compute.py | Extends TestGroupedQuantize and TestGroupedDense with group_size_multiplier parameter and V2-eligible shape tuples.

Sequence Diagram

sequenceDiagram
    participant PY as grouped_gemm (Python)
    participant QV1 as GroupedQuantizeV1 FFI
    participant QV2 as GroupedQuantizeV2 FFI
    participant GV1 as GroupedGemmV1 FFI
    participant GV2 as GroupedGemmV2 FFI

    PY->>PY: _use_v2_kernel(scaling_mode, shape, flatten_axis)
    alt V2 eligible (SM100+, dims 128-aligned)
        PY->>QV2: nvte_group_quantize (fused swizzle)
        QV2-->>PY: GroupedScaledTensor1x (pre_swizzled=True)
    else V1 fallback
        PY->>QV1: nvte_quantize_grouped + set_with_gemm_swizzled_scales(true)
        QV1-->>PY: GroupedScaledTensor1x (pre_swizzled=True)
    end

    PY->>PY: is_v2_grouped_gemm_supported(scaling_mode, shapes, boundaries)
    alt V2 GEMM eligible (SM100+, dims 128-aligned)
        PY->>GV2: nvte_grouped_gemm (pre-swizzled scale_inv consumed directly)
        GV2-->>PY: output
    else V1 GEMM fallback
        PY->>GV1: nvte_multi_tensor_gemm (no re-swizzle, scales already swizzled)
        GV1-->>PY: output
    end

    note over PY: Dequantize path (tests only)
    PY->>PY: _unswizzle_mxfp8_grouped_scale(scale_inv_flat, padded_2d)
    PY->>PY: BlockScaleDequantizer._dequantize_func(data, unswizzled_scale)

Reviews (7): Last reviewed commit: "Merge branch 'main' into jberchtold/gmm-..."

Comment on lines +1028 to +1031
assert False, (
"V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
" scaling_mode {}".format(scaling_mode)
)
Contributor

assert False makes fallback unreachable

The assert False statements at lines 1028, 1036, and 1045 will always raise AssertionError before the return False on the next line, making those returns dead code. More critically, if Python is run with optimizations enabled (-O flag, which disables asserts), the assert False becomes a no-op and execution falls through — the function would silently skip the validation and continue to later checks or return True, potentially routing data to the V2 kernel under unsupported conditions.

These should be changed to raise an explicit exception or simply return False (if fallback to V1 is the intended behavior) without using assert:

Suggested change
assert False, (
"V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
" scaling_mode {}".format(scaling_mode)
)
return False

This same pattern repeats at lines 1036-1039 and 1044-1048.
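For illustration only, a minimal sketch of the two alternatives mentioned above (the functions and enum here are simplified stand-ins, not the actual _use_v2_kernel helper):

```python
from enum import Enum, auto

class ScalingMode(Enum):          # stand-in for TE/JAX's ScalingMode enum
    MXFP8_1D_SCALING = auto()
    DELAYED_TENSOR_SCALING = auto()

def use_v2_grouped_quantize(scaling_mode: ScalingMode) -> bool:
    """Option A: unsupported modes simply fall back to the V1 kernel."""
    if scaling_mode != ScalingMode.MXFP8_1D_SCALING:
        return False
    return True

def require_v2_grouped_quantize(scaling_mode: ScalingMode) -> None:
    """Option B: fail loudly even under `python -O`, where asserts are stripped."""
    if scaling_mode != ScalingMode.MXFP8_1D_SCALING:
        raise ValueError(
            "V2 grouped quantize kernel currently only supports MXFP8 1D scaling"
            f" mode, but got scaling_mode {scaling_mode}"
        )
```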

Comment on lines +1078 to +1085
cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
// size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
// if (!is_rhs_ragged) {
// NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
Contributor

Commented-out group_sizes sum validation

The validation that sum(group_sizes) matches m (or k for wgrad) has been commented out entirely. While the new *_first_dims/*_last_dims interface changes how dimensions are communicated, removing this runtime sanity check eliminates a useful guard against dimension mismatches that could lead to silent data corruption or out-of-bounds memory access. Consider either adapting this validation to work with the new interface or adding an equivalent check.
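If the check is ever reinstated on the Python side instead, it could only run where group_sizes values are concrete (eager/debug paths); a minimal sketch, assuming a hypothetical helper name:

```python
import numpy as np

def check_group_sizes(group_sizes, m: int, is_rhs_ragged: bool = False) -> None:
    """Debug-only sanity check: sum(group_sizes) must cover the ragged dimension.

    Only usable where group_sizes values are concrete; inside a cuda-graph-safe
    jitted path they live on device, which is why the original FFI check needed
    a cudaMemcpyAsync + cudaStreamSynchronize.
    """
    total = int(np.sum(np.asarray(group_sizes, dtype=np.int64)))
    if not is_rhs_ragged and total != m:
        raise ValueError(f"Unexpected group_sizes: expected sum {m}, got {total}")
```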

jberchtold-nvidia and others added 3 commits April 7, 2026 15:27
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia changed the title [JAX] MXFP8 Grouped GEMM [JAX] MXFP8 Grouped Quant+GEMM Apr 8, 2026
@tdophung tdophung marked this pull request as ready for review April 8, 2026 21:58
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]

is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100
v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)"
Collaborator

maybe we should wrap this into utils somewhere, and reuse it to guard all calls to V2 grouped GEMM, not just the ones from test_custom_call_compute

Collaborator

probably would also make the def grouped_gemm in gemm.py shorter?

Collaborator Author

Good idea, I've decided to simplify the test code to make it less coupled to V1/V2. I still have some comments to indicate which test cases should trigger V1/V2, but there is less V1/V2 logic in the tests themselves and it is left as more of an internal implementation detail.

Collaborator Author

Separately, I've also simplified the grouped_gemm function as I agree that function body was too complex. It is now refactored into several helper functions. It could be cleaned up further, but it's at least better than it was previously. Thanks!


# *32 so that the input shape works for MXFP8
input_shape = (m * 32, n)
# Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned)
Collaborator

make it clearer that the 128 alignment is a cuBLASLt requirement, while the 32 multiplier is there because MXFP8 scaling applies to chunks of 32 elements

Collaborator Author

This isn't solely due to cuBLASLt; the grouped quantize kernel also has these alignment requirements. I've refactored this test code to be less coupled to the internal V1/V2 logic and instead selected a handful of test cases that cover both V1 and V2. Whether V1 or V2 is selected is now treated as an implementation detail rather than something visible at the test level, aside from small notes next to the configs showing that both V1 and V2 are covered.
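As a sketch of how the two multipliers could be kept distinct in test shape construction (the constants and helper are illustrative, not the actual test code):

```python
MXFP8_BLOCK_SIZE = 32    # one MXFP8 scale per 32 contiguous elements
V2_DIM_ALIGNMENT = 128   # grouped quantize + V2 grouped GEMM setup kernels
                         # currently assume 128-aligned group dimensions

def make_grouped_input_shape(m: int, n: int, v2_eligible: bool) -> tuple[int, int]:
    """Scale the first dim so the shape is valid for the intended quantize/GEMM path."""
    multiplier = V2_DIM_ALIGNMENT if v2_eligible else MXFP8_BLOCK_SIZE
    return (m * multiplier, n)
```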

return False
# V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both
# operands is a multiple of 128. The V2 GEMM setup kernel computes per-group
# scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``.
Collaborator

this took me a bit to understand; consider clarifying what K_blocks is, since it is not defined in this file. If after a second read it still feels pretty trivial then feel free to SR

Collaborator Author

I've reworded this so it's clearer. I believe we could support cases where this dim is not divisible by 128; afaik there is no inherent limitation in the GEMM. But for simplicity, the grouped quantize and grouped GEMM setup kernels currently only handle these offsets correctly when this dim is divisible by 128.

# [n_groups int64 group_sizes | n_groups+1 int64 offsets]
# = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8.
n_groups = group_sizes_aval.size
fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8)
Collaborator

fifth output seems like a bad name for this. Maybe group_sizes_and_offsets?

Collaborator

oh I see that it is updated_amax for V1. Not sure what the best name would be here, given that it serves different purposes in the two versions

Collaborator Author

Good point, this is a bad name. Instead of this overloaded fifth output, I've made both FFIs use six outputs and left the workspace empty on V1 for consistency. For V2, if we ever want to support delayed scaling, we would need this updated amax output anyway.
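For reference, a host-side sketch of how the packed uint8 workspace described above could be interpreted (illustrative only; the buffer is produced and consumed on device by the FFI):

```python
import numpy as np

def unpack_group_workspace(buf: np.ndarray, n_groups: int):
    """View the uint8 workspace as [n_groups group_sizes | n_groups+1 offsets], all int64."""
    assert buf.dtype == np.uint8 and buf.size == (2 * n_groups + 1) * 8
    as_i64 = buf.view(np.int64)
    group_sizes = as_i64[:n_groups]
    offsets = as_i64[n_groups:]      # exclusive prefix sums, n_groups + 1 entries
    return group_sizes, offsets
```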

if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING:
return False
# Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM.
if get_min_device_compute_capability() < 100:
Collaborator

in gemm.py you check get_device_compute_capability, but here it is get_min_device_compute_capability. These would be equivalent if all GPUs on the system have the same compute capability (which covers most of our products, maybe minus Galaxy ones, I don't remember clearly). But for consistency, please use the same one.

Collaborator

Same for the test file too

Collaborator Author

Good catch, thanks! I've updated the changes in this PR to use the min device compute capability.
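A hedged sketch of the shared guard discussed above, using the minimum compute capability across devices (the helper name and import path are assumptions; get_min_device_compute_capability is the existing TE/JAX utility referenced in the diff):

```python
# Import path is an assumption; the PR's quantization.py calls this helper directly.
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability

def is_v2_grouped_path_supported() -> tuple[bool, str]:
    """Single place for the SM100+ requirement shared by V2 grouped quantize/GEMM and tests."""
    # Use the minimum across visible devices so a single older GPU disables V2 everywhere.
    if get_min_device_compute_capability() < 100:
        return False, "V2 grouped GEMM requires SM100+ (Blackwell or newer)"
    return True, ""
```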


// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small.
__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
Collaborator

I have an idea for this in case n_groups ever gets large: have 32 threads each compute a cumulative sum over a block of elements, then use warp shuffles to reduce the per-thread local sums into one sum.

Collaborator Author

Sounds good! Currently the kernel runtime is pretty small relative to our other kernels, and our n_groups per device is fairly small with EP, but it's a good idea for the future if n_groups per device gets bigger.

@jberchtold-nvidia
Collaborator Author

/te-ci

jberchtold-nvidia and others added 2 commits April 11, 2026 10:47
@jberchtold-nvidia
Collaborator Author

/te-ci
