Skip to content

Adds GEMM Profiling Guide to TE#2863

Open
jomitchellnv wants to merge 2 commits intoNVIDIA:mainfrom
jomitchellnv:jm/gemm-blog
Open

Adds GEMM Profiling Guide to TE#2863
jomitchellnv wants to merge 2 commits intoNVIDIA:mainfrom
jomitchellnv:jm/gemm-blog

Conversation

@jomitchellnv
Copy link
Copy Markdown
Contributor

Description

Adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool. The guide
explains how to derive all 12 per-layer GEMM shapes (Fprop, Dgrad, Wgrad) from transformer model
hyperparameters, benchmark them across precisions (BF16, FP8 Block, MXFP8, NVFP4), and interpret the resulting
speedup estimates.

The benchmark tool supports two modes: model config mode (derives shapes automatically from hidden_size,
intermediate_size, etc.) and manual shape mode (explicit MxKxN triplets). It measures both autocast performance
(realistic end-to-end with quantization overhead) and pre-quantized kernel-only throughput, using CUDA events
or torch.profiler timing backends.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add benchmarks/gemm/benchmark_gemm.py — standalone GEMM benchmark tool supporting BF16, FP8 Block, MXFP8, and
    NVFP4 precisions with autocast and pre-quantized modes, CUDA event and torch.profiler timing, Nsight Systems
    integration, and bar-chart output

  • Add docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst — documentation covering GEMM
    shape derivation from model configs, forward/backward pass shape conventions, precision mapping per GEMM pass,
    speedup calculation methodology, and a worked example on B300

  • Add benchmark result plots (img/model_config_speedup.png, img/model_config_speedup_prequant.png)

  • Update docs/features/low_precision_training/index.rst toctree to include the new guide
    Please list the changes introduced in this PR:

  • Change A

  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv changed the title adds blog post Adds GEMM Profiling Guide to TE Apr 9, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR adds a standalone GEMM benchmark tool (benchmarks/gemm/benchmark_gemm.py) and an accompanying RST guide covering shape derivation, precision comparison (BF16, FP8 Block, MXFP8, NVFP4), and speedup interpretation. Three issues were flagged in prior review threads (FP8Block silently omitted in shape mode, dead if i == 0 or True condition, docstring "cublas"/"gemm" mismatch); two new P2 items remain: the model-config summary omits the FP8Block vs BF16 speedup ratio, and compute_gemm_shapes hard-codes MHA (N_qkv = 3 * H) so GQA/MQA model configs produce incorrect QKV shapes silently.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/completeness suggestions that do not affect correctness of the primary benchmark paths.

Three previously flagged issues (FP8Block shape-mode omission, dead or True condition, docstring mismatch) are the most substantive concerns. The two new items — missing FP8Block speedup line and MHA hard-coding — are P2 informational gaps that don't break anything. No P0/P1 issues remain.

benchmarks/gemm/benchmark_gemm.py — speedup summary and compute_gemm_shapes warrant a follow-up pass.

Vulnerabilities

No security concerns identified. The tool is a local benchmarking script with no network I/O, no user-supplied data fed into privileged operations, and no secrets handling.

Important Files Changed

Filename Overview
benchmarks/gemm/benchmark_gemm.py New 1609-line GEMM benchmark tool; three pre-existing issues flagged in prior review threads (FP8Block silently omitted in shape mode, dead if i == 0 or True condition, docstring "cublas"/"gemm" mismatch). Two new P2 issues: FP8Block vs BF16 speedup omitted from model-config summary, and QKV shape derivation hard-codes MHA (breaks GQA/MQA models silently).
docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst New RST guide covering GEMM shape derivation, precision modes, and benchmark interpretation; well-structured with worked examples and figures. No doc-level issues found.
docs/features/low_precision_training/index.rst Minor toctree update to include the new gemm_profiling guide; correct and minimal.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args] --> B{Mode?}
    B -- "--hidden_size etc." --> C[Model Config Mode]
    B -- "--shapes / default squares" --> D[Shape Mode]
    B -- "--profile" --> E[Nsight Profile Mode]

    C --> F[compute_gemm_shapes\nFprop + Dgrad + Wgrad]
    F --> G[run_model_config_benchmarks]
    G --> H[_benchmark_single_shape\nBF16 / FP8Block / MXFP8 / NVFP4]
    H --> I[Per-layer + full-model\nspeedup summary]
    I --> J[create_model_config_plot\nstacked bar chart]

    D --> K[parse_shapes_arg\nor get_default_shapes]
    K --> L[run_benchmarks\nBF16 / MXFP8 / NVFP4 only]
    L --> M[create_plot\ngrouped bar chart]

    E --> N[run_benchmarks\nsingle shape with\ncudaProfilerStart/Stop + NVTX]

    H --> O{Timing backend?}
    O -- cuda-events --> P[_time_with_cuda_events\nleading-kernel trick]
    O -- profiler --> Q[_time_with_profiler\nCUPTI kernel extraction]
Loading

Reviews (2): Last reviewed commit: "fixes newline issue" | Re-trigger Greptile

Comment on lines +794 to +799
results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}
time_results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}

has_blackwell = is_blackwell_available()
run_fp8 = include_fp8 and TE_AVAILABLE
run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 FP8Block silently omitted in shape mode

run_benchmarks() (used for both default square-shape benchmarks and explicit --shapes invocations) never calls benchmark_fp8_block / benchmark_fp8_block_prequantized. The results dict is initialized with only "BF16", "MXFP8", and "NVFP4", and the function has no include_fp8_block parameter — so the --no-fp8-block flag parsed in main() is only forwarded to run_model_config_benchmarks (line 1579) and has no effect here.

Users who run the tool in shape mode (no model-config flags) will silently receive BF16/MXFP8/NVFP4 data only, even though the module docstring advertises "BF16, FP8 Block, MXFP8, and NVFP4 precisions."

To fix, add include_fp8_block: bool = True to run_benchmarks, initialise results["FP8Block"] = [], select fp8_block_fn the same way model-config mode does, and forward the flag from main().

Comment on lines +1355 to +1367
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead condition if i == 0 or True always evaluates to True

Both label= expressions use if i == 0 or True, which unconditionally takes the True branch. This is dead code — or True makes the condition tautological. The intent was likely either True (always label, which is fine for a stacked bar chart) or if i == 0 (label only the first series). Clean it up to express intent clearly:

Suggested change
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
label=f"{op_label} (Fprop+Dgrad)",

and

Suggested change
color=op_color,
alpha=0.9,
label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
x,
wgrad_ms,
bar_width,
bottom=all_fprop_total + total_wgrad_bottom,
color=op_color,
alpha=0.5,
label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
label=f"{op_label} (Wgrad)",

Comment on lines +18 to +21
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
are summed, giving a kernel-only measurement.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Docstring lists "cublas" but the pattern tuple uses "gemm" instead

The module docstring (line 19) lists the matched kernel patterns as (nvjet, xmma, cutlass, cublas), but GEMM_KERNEL_PATTERNS at line 70 is ("gemm", "nvjet", "xmma", "cutlass")"cublas" is absent and "gemm" was added in its place. In practice "gemm" does catch cuBLAS kernels (their names contain gemm), so the behaviour is correct, but the docstring is inaccurate and may confuse users auditing kernel coverage.

Suggested change
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
are summed, giving a kernel-only measurement.
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
Only the matched GEMM compute kernels (gemm, nvjet, xmma, cutlass)
are summed, giving a kernel-only measurement.

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant