Adds GEMM Profiling Guide to TE #2863

Open
jomitchellnv wants to merge 8 commits into NVIDIA:main from jomitchellnv:jm/gemm-blog

Conversation

@jomitchellnv
Contributor

Description

Adds a GEMM profiling guide to the Transformer Engine documentation and a companion benchmark tool. The guide
explains how to derive all 12 per-layer GEMM shapes (Fprop, Dgrad, Wgrad) from transformer model
hyperparameters, benchmark them across precisions (BF16, FP8 Block, MXFP8, NVFP4), and interpret the resulting
speedup estimates.
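
The shape derivation described above can be sketched as follows. This is a minimal illustration, not the code from benchmark_gemm.py: it assumes the 12 shapes come from four linear projections per layer (fused QKV, attention output, MLP up, MLP down) times three passes, with plain multi-head attention and a non-gated MLP. The function name `derive_gemm_shapes` and the projection names are hypothetical.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]  # (M, K, N): an M x K matrix times a K x N matrix


def derive_gemm_shapes(tokens: int, hidden_size: int, intermediate_size: int) -> Dict[str, Shape]:
    """Return Fprop/Dgrad/Wgrad shapes for four assumed projections per layer."""
    # ASSUMPTION: (in_features, out_features) per projection; the real tool may
    # differ (for example grouped-query attention or a gated MLP).
    projections = {
        "qkv": (hidden_size, 3 * hidden_size),
        "attn_out": (hidden_size, hidden_size),
        "mlp_up": (hidden_size, intermediate_size),
        "mlp_down": (intermediate_size, hidden_size),
    }
    shapes: Dict[str, Shape] = {}
    for name, (k_in, n_out) in projections.items():
        shapes[f"{name}/fprop"] = (tokens, k_in, n_out)  # Y  = X @ W^T
        shapes[f"{name}/dgrad"] = (tokens, n_out, k_in)  # dX = dY @ W
        shapes[f"{name}/wgrad"] = (n_out, tokens, k_in)  # dW = dY^T @ X
    return shapes


if __name__ == "__main__":
    for label, mkn in derive_gemm_shapes(8192, 4096, 16384).items():
        print(f"{label:16s} {mkn[0]}x{mkn[1]}x{mkn[2]}")
```

Note that Wgrad puts the token count in the reduction dimension K, which is one reason the three passes of the same layer can show different speedups.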

The benchmark tool supports two modes: model config mode (derives shapes automatically from hidden_size,
intermediate_size, etc.) and manual shape mode (explicit MxKxN triplets). It measures both autocast performance
(realistic end-to-end with quantization overhead) and pre-quantized kernel-only throughput, using CUDA events
or torch.profiler timing backends.
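
As a rough sketch of how a measured average time turns into throughput and speedup numbers: the snippet below uses the standard 2*M*K*N FLOP count for a GEMM and defines speedup as baseline time divided by candidate time. The helper names are hypothetical, the aggregation over shapes (ratio of summed times) is an assumption about the methodology, and the timings in the usage example are made-up placeholders, not measurements from this PR.

```python
from typing import Dict


def gemm_flops(m: int, k: int, n: int) -> int:
    """Multiply-add count for an (M x K) @ (K x N) GEMM, counted as 2 FLOPs each."""
    return 2 * m * k * n


def tflops(m: int, k: int, n: int, avg_ms: float) -> float:
    """Achieved TFLOPS given the average time per GEMM in milliseconds."""
    return gemm_flops(m, k, n) / (avg_ms * 1e-3) / 1e12


def total_speedup(baseline_ms: Dict[str, float], candidate_ms: Dict[str, float]) -> float:
    """ASSUMPTION: aggregate speedup = summed baseline time / summed candidate time."""
    common = baseline_ms.keys() & candidate_ms.keys()
    return sum(baseline_ms[s] for s in common) / sum(candidate_ms[s] for s in common)


if __name__ == "__main__":
    # Placeholder timings for illustration only.
    bf16 = {"fprop": 2.0, "dgrad": 2.0, "wgrad": 4.0}
    low_precision = {"fprop": 1.0, "dgrad": 1.0, "wgrad": 2.0}
    print(f"{tflops(4096, 4096, 4096, 1.0):.1f} TFLOPS")
    print(f"{total_speedup(bf16, low_precision):.2f}x")
```

Summing times before dividing weights each shape by how long it actually runs, so a large speedup on a tiny GEMM does not dominate the estimate.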

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add benchmarks/gemm/benchmark_gemm.py — standalone GEMM benchmark tool supporting BF16, FP8 Block, MXFP8, and
    NVFP4 precisions with autocast and pre-quantized modes, CUDA event and torch.profiler timing, Nsight Systems
    integration, and bar-chart output

  • Add docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst — documentation covering GEMM
    shape derivation from model configs, forward/backward pass shape conventions, precision mapping per GEMM pass,
    speedup calculation methodology, and a worked example on B300

  • Add benchmark result plots (img/model_config_speedup.png, img/model_config_speedup_prequant.png)

  • Update docs/features/low_precision_training/index.rst toctree to include the new guide

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jomitchellnv jomitchellnv changed the title from "adds blog post" to "Adds GEMM Profiling Guide to TE" Apr 9, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 9, 2026

Greptile Summary

This PR adds a standalone GEMM benchmarking tool (benchmarks/gemm/benchmark_gemm.py) and accompanying documentation — a tutorial RST, a speedups overview page, and B300/H200 result plots — to help users measure and interpret per-layer GEMM performance across BF16, FP8, MXFP8, and NVFP4 precisions.

  • benchmark_gemm.py: 1,883-line benchmark supporting model-config mode (derives all 12 GEMM shapes from hyperparameters) and manual shape mode, with CUDA-events and torch.profiler timing backends, Nsight Systems integration, and bar-chart output.
  • docs/examples/gemm_profiling/gemm_profiling.rst: End-to-end tutorial covering shape derivation, forward/backward conventions, worked B300 and H200 examples, and guidance on interpreting speedups.
  • docs/features/low_precision_training/speedups.rst + index update: Adds a high-level speedups overview page to the low-precision training section, cross-linked to the full tutorial.

Confidence Score: 4/5

Safe to merge with one fix: the MXFP8 autocast benchmark will crash on Hopper if the user omits --no-fp8.

The MXFP8 autocast function (benchmark_fp8) has no hardware guard and no exception handler, unlike every other non-BF16 benchmark in the file. On Hopper (SM90), invoking the tool without --no-fp8 will attempt MXFP8BlockScaling and, if TE rejects it, crash without any diagnostic message. The pre-quantized MXFP8 path and the NVFP4 paths are both protected; this one gap is the only blocking issue in an otherwise clean addition.

benchmarks/gemm/benchmark_gemm.py — specifically the benchmark_fp8 function and the run_fp8 flag computation in both orchestrators.

Important Files Changed

| Filename | Overview |
| --- | --- |
| benchmarks/gemm/benchmark_gemm.py | New 1,883-line GEMM benchmark supporting BF16/FP8/MXFP8/NVFP4; FP8Block shape-mode fix confirmed, speedup loop fixed, but benchmark_fp8 (MXFP8 autocast) lacks Blackwell guard and try/except, crashing on Hopper without --no-fp8. |
| docs/examples/gemm_profiling/gemm_profiling.rst | New tutorial RST (589 lines); H200 speedup block lists precisions in FP8Delayed-first order while the code emits FP8Current first; B300 example output still carries stale speedup lines (previously flagged). |
| docs/features/low_precision_training/speedups.rst | New 114-line overview page with B300/H200 benchmark tabs; correctly linked in the low_precision_training toctree and uses existing sphinx-tabs extension. |
| docs/features/low_precision_training/index.rst | Adds speedups.rst to the toctree and fixes a missing trailing newline; straightforward change. |
| docs/index.rst | Adds examples/gemm_profiling/gemm_profiling.rst to the root toctree, making the tutorial reachable from the Sphinx navigation. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[CLI args] --> B{Model config args?}
    B -- Yes --> C[run_model_config_benchmarks]
    B -- No --> D[run_benchmarks\nshape mode]
    C --> E[compute_gemm_shapes\nFprop / Dgrad / Wgrad]
    E --> F[_benchmark_single_shape\nfor each shape]
    D --> F2[shape loop\nfor each MxKxN]
    F --> G{Precision enabled?}
    F2 --> G
    G --> H[benchmark_bf16]
    G --> I[benchmark_fp8_current]
    G --> J[benchmark_fp8_delayed]
    G --> K[benchmark_fp8_block]
    G --> L[benchmark_fp8 ⚠️\nno hardware guard]
    G --> M{is_blackwell_available?}
    M -- Yes --> N[benchmark_fp4]
    M -- No --> O[skip NVFP4]
    L -->|No guard on non-Blackwell| P[⚠️ May crash on Hopper]
    H & I & J & K & N --> Q[GEMMResult]
    Q --> R[Print table + speedup summary]
    R --> S[create_model_config_plot or create_plot]

Reviews (12): Last reviewed commit: "Apply suggestion from @pggPL"

Comment thread benchmarks/gemm/benchmark_gemm.py Outdated
Comment thread benchmarks/gemm/benchmark_gemm.py
Comment thread benchmarks/gemm/benchmark_gemm.py
@pggPL pggPL self-requested a review April 10, 2026 14:00
@pggPL
Collaborator

pggPL commented Apr 13, 2026

Hi @jomitchellnv, I see that this PR is open, but "Documentation" job is failing. If you fix it, please ping me and I'll review it.

@jomitchellnv
Contributor Author

@pggPL they should be fixed now I hope

@jomitchellnv
Contributor Author

/te-ci L1 pytorch

@jomitchellnv jomitchellnv force-pushed the jm/gemm-blog branch 2 times, most recently from 64c353d to 88a9e0b Compare May 1, 2026 18:02
Comment thread benchmarks/gemm/benchmark_gemm.py
Collaborator

@pggPL pggPL left a comment

I'm super happy about this change, I think we really need that.

I added some comments.

Comment thread docs/features/low_precision_training/index.rst Outdated
Comment thread docs/examples/gemm_profiling/gemm_profiling.rst
Comment thread benchmarks/gemm/benchmark_gemm.py
Comment thread docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst Outdated
Comment thread docs/examples/gemm_profiling/gemm_profiling.rst
Comment thread docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst Outdated
Comment thread docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst Outdated
Comment thread benchmarks/gemm/benchmark_gemm.py Outdated
Comment thread docs/features/low_precision_training/gemm_profiling/img/model_config_speedup.png Outdated
Comment thread docs/features/low_precision_training/gemm_profiling/gemm_profiling.rst Outdated
@pggPL
Collaborator

pggPL commented May 4, 2026

One more thought - we should consider adding MoE - grouped gemm.

Comment thread benchmarks/gemm/benchmark_gemm.py
Collaborator

@pggPL pggPL left a comment

I left some comments

Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/examples/gemm_profiling/gemm_profiling.rst Outdated
Comment thread docs/examples/gemm_profiling/gemm_profiling.rst Outdated
Comment thread benchmarks/gemm/benchmark_gemm.py
Comment thread docs/examples/gemm_profiling/gemm_profiling.rst
Collaborator

@pggPL pggPL left a comment

I'm ok with the tutorial, but the speedups part needs polishing

Comment thread docs/features/low_precision_training/speedups.rst
Comment thread docs/features/low_precision_training/speedups.rst
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Collaborator

@pggPL pggPL left a comment

I've added some comments

@jomitchellnv
Contributor Author

i need to grab an H200 and B300 node to regenerate the plots then it should be ok

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
Comment thread docs/features/low_precision_training/speedups.rst Outdated
Collaborator

@pggPL pggPL left a comment

LGTM

Jonathan Mitchell and others added 5 commits May 29, 2026 11:09
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
  Benchmark tool:
  - Always benchmark Dgrad separately (remove --verify-dgrad flag)
  - Pass measured Dgrad data to plot instead of 2x Fprop approximation
  - Add FP8 CurrentScaling and DelayedScaling benchmark support
  - Add FP8Block to shape mode (was missing, only in model-config mode)
  - Add --no-fp8-current and --no-fp8-delayed CLI flags

  Documentation:
  - Restructure: concise speedups.rst in features/, full tutorial in examples/
  - Add device-specific precision recipes (Hopper vs Blackwell)
  - Add Hopper (H200) benchmark results alongside Blackwell (B300)
  - Remove misleading FP8 Block vs MXFP8 comparison (different target devices)
  - Rename "How Shapes Are Derived" to appendix, promote key sections
  - Convert benchmark tool references to GitHub links
  - Refresh all benchmark numbers with FP8 Current/Delayed columns

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-1979.ipp2a2.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@dl325g11-0771.ipp4a1.colossus.nvidia.com>
jomitchellnv and others added 3 commits May 29, 2026 11:09
- Define autocast vs pre-quantized modes upfront before the figures
- Remove the --pre-quantize flag reference and the standalone note
- Replace unclear quantization-overhead jargon with plain language
- Condense the verbose "Speedup Is Shape-Dependent" section
- Reword "Fprop vs Dgrad comparisons" to per-operation breakdowns
- Fix benchmark_gemm.py: skip FP8 DelayedScaling in pre-quantized mode
  (it has no pre-quantized variant and silently fell back to the
  autocast path, producing a misleading bar in the pre-quantized plots)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Re-ran the model-config benchmark on B300 (SM100) and H200 (SM90) with the
pre-quantized DelayedScaling fix applied, and synced the numbers in speedups.rst:

- B300 autocast: now includes FP8Block (1.30x); FP8Current 1.41x, FP8Delayed
  1.61x, MXFP8 1.44x, NVFP4 2.03x
- B300 pre-quantized: FP8Delayed bar removed, FP8Block (1.82x) added; NVFP4 3.55x
- H200 autocast: FP8Current 1.57x, FP8Delayed 1.69x, FP8Block 1.41x
- H200 pre-quantized: FP8Delayed removed; FP8Block dropped (no Hopper prequant
  support); raw FP8 1.92x

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Comment on lines +430 to +468
def benchmark_fp8(
    M: int,
    K: int,
    N: int,
    num_warmup: int = 10,
    num_iters: int = 100,
    timing: str = "cuda-events",
    verbose: bool = False,
) -> Optional[GEMMResult]:
    """MXFP8 GEMM via te.Linear autocast."""
    if not TE_AVAILABLE:
        return None

    device = torch.device("cuda")
    flops = compute_gemm_flops(M, K, N)

    linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device)
    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)

    with te.autocast(enabled=True, recipe=recipe):
        for _ in range(num_warmup):
            linear(x)
        torch.cuda.synchronize()

        def _run():
            linear(x)

        if timing == "profiler":
            tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose)
        else:
            lin_lg = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device)
            x_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)
            tflops, avg_ms = _time_with_cuda_events(
                _run, num_iters, flops, leading_fn=lambda: lin_lg(x_lg)
            )
            del lin_lg, x_lg

    return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="MXFP8")
Contributor

P1 MXFP8 autocast path has no hardware guard or exception handler

benchmark_fp8 (MXFP8 autocast) contains no is_blackwell_available() early-return and no try/except block. In the default mode (no --pre-quantize), both run_benchmarks and run_model_config_benchmarks select this function via fp8_fn = ... if pre_quantize else benchmark_fp8. If a Hopper user forgets --no-fp8, the call to te.autocast(enabled=True, recipe=MXFP8BlockScaling(...)) will likely raise, propagating uncaught through the loop and crashing the entire tool.

The documentation explicitly states MXFP8 requires Blackwell, and only NVFP4 has a corresponding guard — benchmark_fp4 returns None early when not is_blackwell_available(), and the orchestrators print a skip notice. The pre-quantized MXFP8 path (benchmark_fp8_prequantized) also has a try/except, so only the autocast variant is unprotected. Adding either an is_blackwell_available() guard at the top of this function or a try/except with a return None fallback would make it consistent with the rest of the codebase.
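
One possible shape for the suggested fix, sketched without any GPU or Transformer Engine dependency: a small wrapper that combines the hardware check and the exception fallback, so an unsupported precision is skipped with a message instead of aborting the whole run. The wrapper name `guarded_benchmark` is hypothetical, and the `supported` boolean stands in for a check such as the `is_blackwell_available()` call mentioned above; the actual patch may simply add the guard inside benchmark_fp8.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def guarded_benchmark(name: str, supported: bool, fn: Callable[[], T]) -> Optional[T]:
    """Run fn() and return its result, or return None if the precision cannot run.

    `supported` is a stand-in for a hardware capability check; any exception
    raised by fn() is reported and converted into a skip.
    """
    if not supported:
        print(f"Skipping {name}: not supported on this GPU")
        return None
    try:
        return fn()
    except Exception as exc:  # broad on purpose: one failing recipe should not end the sweep
        print(f"Skipping {name}: {type(exc).__name__}: {exc}")
        return None


if __name__ == "__main__":
    # Hypothetical call site; benchmark_fp8 and the shape come from the real tool.
    result = guarded_benchmark("MXFP8", False, lambda: 123.4)
    print(result)
```

Returning None matches how the review describes benchmark_fp4 behaving on unsupported hardware, so callers that already handle a missing NVFP4 result would handle a missing MXFP8 result the same way.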

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

2 participants