Conversation
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Greptile Summary: This PR adds a standalone GEMM benchmark tool (benchmarks/gemm/benchmark_gemm.py) plus documentation. Confidence Score: 5/5. Safe to merge; all remaining findings are P2 style/completeness suggestions that do not affect correctness of the primary benchmark paths. Three previously flagged issues in benchmarks/gemm/benchmark_gemm.py (FP8Block shape-mode omission, dead `if i == 0 or True` condition, docstring "cublas"/"gemm" mismatch) plus the new speedup-summary and compute_gemm_shapes findings warrant a follow-up pass.
| Filename | Overview |
|---|---|
| benchmarks/gemm/benchmark_gemm.py | New 1609-line GEMM benchmark tool; three pre-existing issues flagged in prior review threads (FP8Block silently omitted in shape mode, dead if i == 0 or True condition, docstring "cublas"/"gemm" mismatch). Two new P2 issues: FP8Block vs BF16 speedup omitted from model-config summary, and QKV shape derivation hard-codes MHA (breaks GQA/MQA models silently). |
| docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst | New RST guide covering GEMM shape derivation, precision modes, and benchmark interpretation; well-structured with worked examples and figures. No doc-level issues found. |
| docs/features/low_precision_training/index.rst | Minor toctree update to include the new gemm_profiling guide; correct and minimal. |
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args] --> B{Mode?}
    B -- "--hidden_size etc." --> C[Model Config Mode]
    B -- "--shapes / default squares" --> D[Shape Mode]
    B -- "--profile" --> E[Nsight Profile Mode]
    C --> F[compute_gemm_shapes\nFprop + Dgrad + Wgrad]
    F --> G[run_model_config_benchmarks]
    G --> H[_benchmark_single_shape\nBF16 / FP8Block / MXFP8 / NVFP4]
    H --> I[Per-layer + full-model\nspeedup summary]
    I --> J[create_model_config_plot\nstacked bar chart]
    D --> K[parse_shapes_arg\nor get_default_shapes]
    K --> L[run_benchmarks\nBF16 / MXFP8 / NVFP4 only]
    L --> M[create_plot\ngrouped bar chart]
    E --> N[run_benchmarks\nsingle shape with\ncudaProfilerStart/Stop + NVTX]
    H --> O{Timing backend?}
    O -- cuda-events --> P[_time_with_cuda_events\nleading-kernel trick]
    O -- profiler --> Q[_time_with_profiler\nCUPTI kernel extraction]
```
Reviews (2): Last reviewed commit: "fixes newline issue"
```python
results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}
time_results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}

has_blackwell = is_blackwell_available()
run_fp8 = include_fp8 and TE_AVAILABLE
run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell
```
FP8Block silently omitted in shape mode
run_benchmarks() (used for both default square-shape benchmarks and explicit --shapes invocations) never calls benchmark_fp8_block / benchmark_fp8_block_prequantized. The results dict is initialized with only "BF16", "MXFP8", and "NVFP4", and the function has no include_fp8_block parameter — so the --no-fp8-block flag parsed in main() is only forwarded to run_model_config_benchmarks (line 1579) and has no effect here.
Users who run the tool in shape mode (no model-config flags) will silently receive BF16/MXFP8/NVFP4 data only, even though the module docstring advertises "BF16, FP8 Block, MXFP8, and NVFP4 precisions."
To fix, add include_fp8_block: bool = True to run_benchmarks, initialise results["FP8Block"] = [], select fp8_block_fn the same way model-config mode does, and forward the flag from main().
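The suggested fix can be sketched as follows. This is a hypothetical outline, not the tool's actual code: the per-precision benchmark calls are stubbed out, and only the flag plumbing the review asks for is shown.

```python
def bench_bf16(m: int, k: int, n: int) -> float:
    return 0.0  # stand-in for the real BF16 benchmark


def bench_fp8_block(m: int, k: int, n: int) -> float:
    return 0.0  # stand-in for benchmark_fp8_block


def run_benchmarks(
    shapes: list[tuple[int, int, int]],
    include_fp8_block: bool = True,  # new flag, forwarded from main()
) -> dict[str, list[float]]:
    results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []}
    if include_fp8_block:
        # Register the series up front so plotting/summary code sees it.
        results["FP8Block"] = []
    for m, k, n in shapes:
        results["BF16"].append(bench_bf16(m, k, n))
        if include_fp8_block:
            results["FP8Block"].append(bench_fp8_block(m, k, n))
        # MXFP8 / NVFP4 benchmarks unchanged ...
    return results
```

With this shape, `--no-fp8-block` simply passes `include_fp8_block=False` and the series is absent from the results, matching how model-config mode already behaves.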
```python
    color=op_color,
    alpha=0.9,
    label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
)
ax.bar(
    x,
    wgrad_ms,
    bar_width,
    bottom=all_fprop_total + total_wgrad_bottom,
    color=op_color,
    alpha=0.5,
    label=f"{op_label} (Wgrad)" if i == 0 or True else "",
)
```
Dead condition
if i == 0 or True always evaluates to True
Both label= expressions use if i == 0 or True, which unconditionally takes the True branch. This is dead code — or True makes the condition tautological. The intent was likely either True (always label, which is fine for a stacked bar chart) or if i == 0 (label only the first series). Clean it up to express intent clearly:
For the first `ax.bar` call:

```diff
-    label=f"{op_label} (Fprop+Dgrad)" if i == 0 or True else "",
+    label=f"{op_label} (Fprop+Dgrad)",
```

and for the second:

```diff
-    label=f"{op_label} (Wgrad)" if i == 0 or True else "",
+    label=f"{op_label} (Wgrad)",
```
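If the original intent was instead to label only the first series, that choice is easier to audit when factored into a small helper. A sketch (the tool's actual plotting code inlines the f-strings; this helper is hypothetical):

```python
def series_label(op_label: str, kind: str, i: int, label_once: bool = False) -> str:
    """Legend label for the i-th bar series of one op.

    With label_once=True, only the first series (i == 0) is labelled,
    which avoids duplicate legend entries; the default labels every
    series, matching what `or True` effectively did.
    """
    if label_once and i != 0:
        return ""  # empty label keeps the bar out of the legend
    return f"{op_label} ({kind})"
```

Either behaviour is then a deliberate, named choice rather than a tautological condition.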
```rst
* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
  Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
  are summed, giving a kernel-only measurement.
```
Docstring lists "cublas" but the pattern tuple uses "gemm" instead
The module docstring (line 19) lists the matched kernel patterns as (nvjet, xmma, cutlass, cublas), but GEMM_KERNEL_PATTERNS at line 70 is ("gemm", "nvjet", "xmma", "cutlass") — "cublas" is absent and "gemm" was added in its place. In practice "gemm" does catch cuBLAS kernels (their names contain gemm), so the behaviour is correct, but the docstring is inaccurate and may confuse users auditing kernel coverage.
```diff
 * **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps.
-  Only the matched GEMM compute kernels (nvjet, xmma, cutlass, cublas)
+  Only the matched GEMM compute kernels (gemm, nvjet, xmma, cutlass)
   are summed, giving a kernel-only measurement.
```
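The review's observation that the `gemm` pattern catches cuBLAS kernels follows directly from the substring match. A sketch of that matcher, with illustrative kernel names (the exact names vary by architecture and cuBLAS version):

```python
# Pattern tuple as reported at line 70 of the benchmark tool.
GEMM_KERNEL_PATTERNS = ("gemm", "nvjet", "xmma", "cutlass")


def is_gemm_kernel(kernel_name: str) -> bool:
    """Keep only GEMM compute kernels via case-insensitive substring match."""
    name = kernel_name.lower()
    return any(pat in name for pat in GEMM_KERNEL_PATTERNS)
```

cuBLAS kernel names typically contain "gemm" (e.g. something like `ampere_bf16_s16816gemm_bf16_128x128_nn`), so they match even though "cublas" itself is not in the tuple, which is why the behaviour is correct and only the docstring needs updating.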
Description
Adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool. The guide
explains how to derive all 12 per-layer GEMM shapes (Fprop, Dgrad, Wgrad) from transformer model
hyperparameters, benchmark them across precisions (BF16, FP8 Block, MXFP8, NVFP4), and interpret the resulting
speedup estimates.
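The shape derivation the guide describes can be sketched as below. This is a hypothetical outline, not the tool's `compute_gemm_shapes`: it assumes the conventional four GEMMs per transformer layer (fused QKV, attention output projection, FC1, FC2) and the MHA-only QKV sizing noted in the review, giving 4 layers x 3 passes = 12 shapes.

```python
def gemm_shapes_for_linear(m_tokens: int, k_in: int, n_out: int) -> dict[str, tuple[int, int, int]]:
    """(M, K, N) triplets for the three GEMMs of one linear layer.

    Fprop:  Y  = X  @ W.T  -> (M, K) x (K, N)
    Dgrad:  dX = dY @ W    -> (M, N) x (N, K)
    Wgrad:  dW = dY.T @ X  -> (N, M) x (M, K)
    """
    return {
        "fprop": (m_tokens, k_in, n_out),
        "dgrad": (m_tokens, n_out, k_in),
        "wgrad": (n_out, m_tokens, k_in),
    }


def compute_layer_shapes(seq_len: int, batch: int, hidden: int, intermediate: int):
    """Twelve shapes: four per-layer GEMMs x three passes (MHA assumed)."""
    m = seq_len * batch  # token dimension shared by all layer GEMMs
    layers = {
        "qkv": (hidden, 3 * hidden),  # fused QKV projection (hard-codes MHA)
        "proj": (hidden, hidden),     # attention output projection
        "fc1": (hidden, intermediate),
        "fc2": (intermediate, hidden),
    }
    return {name: gemm_shapes_for_linear(m, k, n) for name, (k, n) in layers.items()}
```

Under GQA/MQA the K/V heads shrink, so the `3 * hidden` QKV output dimension above overestimates the true shape, which is exactly the silent-breakage risk flagged in the file-overview table.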
The benchmark tool supports two modes: model config mode (derives shapes automatically from hidden_size,
intermediate_size, etc.) and manual shape mode (explicit MxKxN triplets). It measures both autocast performance
(realistic end-to-end with quantization overhead) and pre-quantized kernel-only throughput, using CUDA events
or torch.profiler timing backends.
Changes

- Add `benchmarks/gemm/benchmark_gemm.py`: standalone GEMM benchmark tool supporting BF16, FP8 Block, MXFP8, and NVFP4 precisions with autocast and pre-quantized modes, CUDA event and `torch.profiler` timing, Nsight Systems integration, and bar-chart output
- Add `docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst`: documentation covering GEMM shape derivation from model configs, forward/backward pass shape conventions, precision mapping per GEMM pass, speedup calculation methodology, and a worked example on B300
- Add benchmark result plots (`img/model_config_speedup.png`, `img/model_config_speedup_prequant.png`)
- Update `docs/features/low_precision_training/index.rst` toctree to include the new guide