Adds GEMM Profiling Guide to TE #2863
Greptile Summary
This PR adds a standalone GEMM benchmarking tool (

Confidence Score: 4/5
Safe to merge with one fix: the MXFP8 autocast benchmark will crash on Hopper if the user omits --no-fp8.

The MXFP8 autocast function (benchmark_fp8) has no hardware guard and no exception handler, unlike every other non-BF16 benchmark in the file. On Hopper (SM90), invoking the tool without --no-fp8 will attempt MXFP8BlockScaling and, if TE rejects it, crash without any diagnostic message. The pre-quantized MXFP8 path and the NVFP4 paths are both protected; this one gap is the only blocking issue in an otherwise clean addition.

benchmarks/gemm/benchmark_gemm.py: specifically the benchmark_fp8 function and the run_fp8 flag computation in both orchestrators.

Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[CLI args] --> B{Model config args?}
B -- Yes --> C[run_model_config_benchmarks]
B -- No --> D[run_benchmarks\nshape mode]
C --> E[compute_gemm_shapes\nFprop / Dgrad / Wgrad]
E --> F[_benchmark_single_shape\nfor each shape]
D --> F2[shape loop\nfor each MxKxN]
F --> G{Precision enabled?}
F2 --> G
G --> H[benchmark_bf16]
G --> I[benchmark_fp8_current]
G --> J[benchmark_fp8_delayed]
G --> K[benchmark_fp8_block]
G --> L[benchmark_fp8 ⚠️\nno hardware guard]
G --> M{is_blackwell_available?}
M -- Yes --> N[benchmark_fp4]
M -- No --> O[skip NVFP4]
L -->|No guard on non-Blackwell| P[⚠️ May crash on Hopper]
H & I & J & K & N --> Q[GEMMResult]
Q --> R[Print table + speedup summary]
R --> S[create_model_config_plot or create_plot]
Reviews (12): Last reviewed commit: "Apply suggestion from @pggPL"

Hi @jomitchellnv, I see that this PR is open, but the "Documentation" job is failing. If you fix it, please ping me and I'll review it.

@pggPL they should be fixed now I hope

/te-ci L1 pytorch

64c353d to 88a9e0b
pggPL left a comment
I'm super happy about this change, I think we really need that.
I added some comments.

One more thought: we should consider adding MoE (grouped GEMM).

pggPL left a comment
I'm OK with the tutorial, but the speedups part needs polishing.

I need to grab an H200 and B300 node to regenerate the plots, then it should be OK.

f8ebbd2 to 36cc7fa
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Benchmark tool:
- Always benchmark Dgrad separately (remove --verify-dgrad flag)
- Pass measured Dgrad data to plot instead of 2x Fprop approximation
- Add FP8 CurrentScaling and DelayedScaling benchmark support
- Add FP8Block to shape mode (was missing, only in model-config mode)
- Add --no-fp8-current and --no-fp8-delayed CLI flags

Documentation:
- Restructure: concise speedups.rst in features/, full tutorial in examples/
- Add device-specific precision recipes (Hopper vs Blackwell)
- Add Hopper (H200) benchmark results alongside Blackwell (B300)
- Remove misleading FP8 Block vs MXFP8 comparison (different target devices)
- Rename "How Shapes Are Derived" to appendix, promote key sections
- Convert benchmark tool references to GitHub links
- Refresh all benchmark numbers with FP8 Current/Delayed columns

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-1979.ipp2a2.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
- Define autocast vs pre-quantized modes upfront before the figures
- Remove the --pre-quantize flag reference and the standalone note
- Replace unclear quantization-overhead jargon with plain language
- Condense the verbose "Speedup Is Shape-Dependent" section
- Reword "Fprop vs Dgrad comparisons" to per-operation breakdowns
- Fix benchmark_gemm.py: skip FP8 DelayedScaling in pre-quantized mode (it has no pre-quantized variant and silently fell back to the autocast path, producing a misleading bar in the pre-quantized plots)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Re-ran the model-config benchmark on B300 (SM100) and H200 (SM90) with the pre-quantized DelayedScaling fix applied, and synced the numbers in speedups.rst:
- B300 autocast: now includes FP8Block (1.30x); FP8Current 1.41x, FP8Delayed 1.61x, MXFP8 1.44x, NVFP4 2.03x
- B300 pre-quantized: FP8Delayed bar removed, FP8Block (1.82x) added; NVFP4 3.55x
- H200 autocast: FP8Current 1.57x, FP8Delayed 1.69x, FP8Block 1.41x
- H200 pre-quantized: FP8Delayed removed; FP8Block dropped (no Hopper prequant support); raw FP8 1.92x

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
66575ac to 1a16e4a
def benchmark_fp8(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100,
    timing: str = "cuda-events",
    verbose: bool = False,
) -> Optional[GEMMResult]:
    """MXFP8 GEMM via te.Linear autocast."""
    if not TE_AVAILABLE:
        return None

    device = torch.device("cuda")
    flops = compute_gemm_flops(M, K, N)

    linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device)
    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)

    with te.autocast(enabled=True, recipe=recipe):
        for _ in range(num_warmup):
            linear(x)
        torch.cuda.synchronize()

        def _run():
            linear(x)

        if timing == "profiler":
            tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose)
        else:
            lin_lg = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device)
            x_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
            tflops, avg_ms = _time_with_cuda_events(
                _run, num_iters, flops, leading_fn=lambda: lin_lg(x_lg)
            )
            del lin_lg, x_lg

    return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="MXFP8")
MXFP8 autocast path has no hardware guard or exception handler
benchmark_fp8 (MXFP8 autocast) contains no is_blackwell_available() early-return and no try/except block. In the default mode (no --pre-quantize), both run_benchmarks and run_model_config_benchmarks select this function via fp8_fn = ... if pre_quantize else benchmark_fp8. If a Hopper user forgets --no-fp8, the call to te.autocast(enabled=True, recipe=MXFP8BlockScaling(...)) will likely raise, propagating uncaught through the loop and crashing the entire tool.
The documentation explicitly states MXFP8 requires Blackwell, and only NVFP4 has a corresponding guard — benchmark_fp4 returns None early when not is_blackwell_available(), and the orchestrators print a skip notice. The pre-quantized MXFP8 path (benchmark_fp8_prequantized) also has a try/except, so only the autocast variant is unprotected. Adding either an is_blackwell_available() guard at the top of this function or a try/except with a return None fallback would make it consistent with the rest of the codebase.
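One way to combine both suggested protections is a small wrapper that checks hardware support first and converts any runtime failure into a skip. The sketch below is illustrative only: `guarded_benchmark`, `_fake_is_blackwell_available`, and `benchmark_mxfp8_stub` are hypothetical names that do not appear in the PR, and the stubs stand in for the real TE calls so the snippet runs without a GPU.

```python
import functools
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def guarded_benchmark(
    name: str,
    is_supported: Callable[[], bool],
) -> Callable[[Callable[..., T]], Callable[..., Optional[T]]]:
    """Hypothetical helper: skip on unsupported hardware, survive runtime errors."""

    def decorator(fn: Callable[..., T]) -> Callable[..., Optional[T]]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs) -> Optional[T]:
            # Hardware guard: return early instead of attempting the recipe.
            if not is_supported():
                print(f"[skip] {name}: not supported on this device")
                return None
            # Exception handler: one failing precision should not abort the sweep.
            try:
                return fn(*args, **kwargs)
            except Exception as exc:  # deliberately broad for a benchmark loop
                print(f"[skip] {name}: {type(exc).__name__}: {exc}")
                return None

        return wrapper

    return decorator


def _fake_is_blackwell_available() -> bool:
    """Stand-in for the tool's real capability check (assumed to return a bool)."""
    return False


@guarded_benchmark("MXFP8", _fake_is_blackwell_available)
def benchmark_mxfp8_stub(M: int, K: int, N: int) -> float:
    """Stub that mimics a recipe being rejected on non-Blackwell hardware."""
    raise RuntimeError("MXFP8 requires Blackwell")


print(benchmark_mxfp8_stub(4096, 4096, 4096))
```

With this shape, the orchestrators only need to treat a `None` result as "skipped", which is how the review describes the existing NVFP4 path behaving.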
Description
Adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool. The guide
explains how to derive all 12 per-layer GEMM shapes (Fprop, Dgrad, Wgrad) from transformer model
hyperparameters, benchmark them across precisions (BF16, FP8 Block, MXFP8, NVFP4), and interpret the resulting
speedup estimates.
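As a rough illustration of where 12 shapes can come from, the sketch below assumes four linear layers per transformer block (a fused QKV projection, the attention output projection, and a non-gated two-layer MLP), each contributing an Fprop, Dgrad, and Wgrad GEMM. The layer list, the `derive_gemm_shapes` and `gemm_flops` names, and the `tokens = batch x sequence length` convention are assumptions made for this example; the PR's own `compute_gemm_shapes` may differ (for instance for grouped-query attention or gated MLPs).

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]  # (M, K, N): an (M x K) @ (K x N) matrix product


def derive_gemm_shapes(
    tokens: int, hidden_size: int, intermediate_size: int
) -> Dict[str, Dict[str, Shape]]:
    """Return Fprop/Dgrad/Wgrad shapes for an assumed four-linear-layer block."""
    # (in_features, out_features) per linear layer; this layout is an assumption.
    layers = {
        "qkv_proj": (hidden_size, 3 * hidden_size),
        "attn_out": (hidden_size, hidden_size),
        "mlp_up": (hidden_size, intermediate_size),
        "mlp_down": (intermediate_size, hidden_size),
    }
    shapes: Dict[str, Dict[str, Shape]] = {}
    for name, (k, n) in layers.items():
        shapes[name] = {
            "fprop": (tokens, k, n),  # Y  = X    @ W^T : (tokens x k) @ (k x n)
            "dgrad": (tokens, n, k),  # dX = dY   @ W   : (tokens x n) @ (n x k)
            "wgrad": (n, tokens, k),  # dW = dY^T @ X   : (n x tokens) @ (tokens x k)
        }
    return shapes


def gemm_flops(shape: Shape) -> int:
    """Multiply-accumulate count as 2*M*K*N floating-point operations."""
    m, k, n = shape
    return 2 * m * k * n


if __name__ == "__main__":
    for layer, passes in derive_gemm_shapes(4096, 1024, 4096).items():
        for pass_name, shape in passes.items():
            print(f"{layer:9s} {pass_name:6s} {shape} {gemm_flops(shape):.3e} FLOPs")
```

All three passes of a layer perform the same number of FLOPs, but the operand shapes differ, which is why each pass can show a different measured throughput.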
The benchmark tool supports two modes: model config mode (derives shapes automatically from hidden_size,
intermediate_size, etc.) and manual shape mode (explicit MxKxN triplets). It measures both autocast performance
(realistic end-to-end with quantization overhead) and pre-quantized kernel-only throughput, using CUDA events
or torch.profiler timing backends.
Type of change
Changes
- Add benchmarks/gemm/benchmark_gemm.py: standalone GEMM benchmark tool supporting BF16, FP8 Block, MXFP8, and NVFP4 precisions with autocast and pre-quantized modes, CUDA event and torch.profiler timing, Nsight Systems integration, and bar-chart output
- Add docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst: documentation covering GEMM shape derivation from model configs, forward/backward pass shape conventions, precision mapping per GEMM pass, speedup calculation methodology, and a worked example on B300
- Add benchmark result plots (img/model_config_speedup.png, img/model_config_speedup_prequant.png)
- Update docs/features/low_precision_training/index.rst toctree to include the new guide
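For readers unfamiliar with how per-GEMM timings turn into a single speedup figure, one common approach is to compare total time across all GEMMs rather than averaging per-GEMM ratios. The sketch below shows that time-weighted form; it is an assumption about the general technique, not a reproduction of the guide's exact methodology, and the `estimate_speedup` name and the timings in the example are made up.

```python
from typing import Dict


def estimate_speedup(
    baseline_ms: Dict[str, float], candidate_ms: Dict[str, float]
) -> float:
    """Time-weighted speedup: total baseline time divided by total candidate time."""
    missing = set(baseline_ms) - set(candidate_ms)
    if missing:
        raise ValueError(f"no candidate timing for: {sorted(missing)}")
    total_baseline = sum(baseline_ms.values())
    total_candidate = sum(candidate_ms[key] for key in baseline_ms)
    return total_baseline / total_candidate


if __name__ == "__main__":
    # Hypothetical per-pass timings in milliseconds, for illustration only.
    bf16 = {"fprop": 2.0, "dgrad": 2.0, "wgrad": 2.0}
    low_precision = {"fprop": 1.0, "dgrad": 1.0, "wgrad": 2.0}
    print(f"{estimate_speedup(bf16, low_precision):.2f}x")
```

Summing times first weights each GEMM by how long it actually runs, so a large GEMM that stays in higher precision pulls the estimate down more than a plain mean of per-GEMM ratios would suggest.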
Checklist: