Skip to content

[JAX] Grouped quant+GEMM custom partitioning rules#3058

Draft
jberchtold-nvidia wants to merge 4 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-custom-partition-rules
Draft

[JAX] Grouped quant+GEMM custom partitioning rules#3058
jberchtold-nvidia wants to merge 4 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/gmm-custom-partition-rules

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft May 28, 2026 20:59
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR adds JAX custom partitioning rules (partition and shardy_sharding_rule) to GroupedGemmPrimitive and GroupedQuantizePrimitive, enabling Expert Parallelism (EP) + FSDP sharding for grouped quant+GEMM operations outside of shard_map. It also removes the "FSDP not supported" assertion from _grouped_dense_bwd_rule and wires up the actual allgather/reduce-scatter logic for FSDP in the backward pass.

  • gemm.py / quantization.py: Partition spec helpers propagate EP/FSDP axes and compute local sizes; GroupedGemmPrimitive.partition handles FSDP all-gather on the RHS weight and optional psum on the output.
  • dense.py: Lifts the FSDP assertion, adds a _is_manual_mesh_axis guard, and adds a with_sharding_constraint on wgrad to propagate the correct EP+FSDP spec.
  • module.py / sharding.py: Contracting dims switched to (-1,) for 3-D EP-leading input support; ep_resource added to MeshResource.

Confidence Score: 3/5

The core partitioning logic is new and complex; the backward FSDP path in dense.py has a silent EP-axis overwrite that could produce an incorrect sharding constraint in an edge case, and the 1-D size heuristic in gemm.py assumes a specific sharding orientation without enforcing it.

The EP-axis overwrite in _grouped_dense_bwd_rule silently produces a wrong sharding constraint when kernel_fsdp_axis_idx == 0, and the _local_2d_sizes_from_spec heuristic could compute incorrect local sizes if sharding falls on the right dimension. Both issues are in newly-written paths lacking targeted test coverage.

transformer_engine/jax/dense.py (backward FSDP wgrad sharding constraint) and transformer_engine/jax/cpp_extensions/gemm.py (_local_2d_sizes_from_spec fallback logic) deserve a close second look.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/gemm.py Adds spec-manipulation helpers and GroupedGemmPrimitive.partition / shardy_sharding_rule; helper functions are duplicated with quantization.py and the _local_2d_sizes_from_spec fallback may divide the wrong 2D size.
transformer_engine/jax/cpp_extensions/quantization.py Adds GroupedQuantizePrimitive.partition and shardy_sharding_rule; logic looks correct for the supported MXFP8/tensor-scaling paths.
transformer_engine/jax/dense.py Removes the FSDP assertion and implements backward FSDP allgather/reduce-scatter; has a silent EP-axis overwrite when kernel_fsdp_axis_idx == 0.
transformer_engine/jax/flax/module.py Contracting dims switched from ((1,), (1,)) to ((-1,), (1,)) to support the new 3-D EP reshape path.
transformer_engine/jax/sharding.py Adds ep_resource field to MeshResource dataclass; trivial, safe change.
tests/jax/test_grouped_gemm_partitioning.py New unit tests for partition/shardy spec correctness and an end-to-end MXFP8 EP+FSDP test on a 1x1 mesh.
tests/jax/test_multi_process_distributed_grouped_gemm.py Switches to MXFP8_1D_SCALING, fixes the pre-existing jnp.allclose no-assert bug, adds a new helper that references a module-level mesh global.

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Compiler
    participant QP as GroupedQuantizePrimitive.partition
    participant GP as GroupedGemmPrimitive.partition
    participant SI as sharded_impl (per device)

    JAX->>QP: (x, scale, group_sizes) with EP+FSDP specs
    QP->>QP: _parse_partition_specs() derive flat/group output specs
    QP-->>JAX: arg_shardings, out_shardings, sharded_impl
    JAX->>SI: local x shard
    SI->>SI: GroupedQuantizePrimitive.impl()
    SI->>SI: _pad_or_slice_to_shape(scale_inv, local_shape)
    SI->>SI: pmax amax over dp/fsdp
    SI-->>JAX: rowwise_out, colwise_out, scale_invs, amax

    JAX->>GP: lhs, rhs_weight with EP+FSDP specs
    GP->>GP: inject EP into rhs dim-0 if missing
    GP->>GP: strip FSDP from rhs gather_rhs_fsdp
    GP-->>JAX: arg_shardings, out_sharding, sharded_impl
    JAX->>SI: local lhs/rhs shards
    SI->>SI: GroupedGemmPrimitive.impl() local GEMM
    SI->>SI: psum(out, reduce_axis) if needed
    SI-->>JAX: out

    Note over JAX,SI: Backward pass inside shard_map only
    JAX->>SI: _grouped_dense_bwd_rule
    SI->>SI: _is_manual_mesh_axis check
    alt FSDP axis is manual
        SI->>SI: allgather or psum dgrad
        SI->>SI: psum_scatter or psum wgrad
    end
    SI->>SI: with_sharding_constraint wgrad EP+FSDP spec
Loading

Comments Outside Diff (2)

  1. tests/jax/test_multi_process_distributed_grouped_gemm.py, line 427-429 (link)

    P2 Function implicitly depends on module-level mesh variable

    run_grouped_dense_mxfp8_ep_fsdp_outside_shard_map references mesh as a free variable. It is only defined in the if __name__ == "__main__": block (line 243), which assigns it to the module global scope. If this function is ever renamed to start with test_ (or called from any other context), it will raise NameError: name 'mesh' is not defined. Consider accepting mesh as a parameter to make the dependency explicit.

  2. transformer_engine/jax/cpp_extensions/gemm.py, line 601-621 (link)

    P2 Ambiguous heuristic when both 2D sizes are divisible by the shard count

    When the tensor is 1-D, _local_2d_sizes_from_spec picks left_size // spec_size whenever left_size % spec_size == 0, even if the actual sharding is along the right/hidden dimension. A mismatch would silently produce incorrect local sizes passed to the GEMM primitive. A comment documenting the assumption (EP always shards the left/token dimension) would at least make the contract explicit.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +579 to +582
if len(wgrad_spec) > 0:
wgrad_spec[0] = ep_resource
if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 EP axis silently dropped when kernel_fsdp_axis_idx == 0

wgrad_spec[0] is first set to ep_resource and then immediately overwritten by kernel_fsdp_mesh_axis when kernel_fsdp_axis_idx == 0. The sequential assignments silently discard the EP constraint in that edge case. While kernel_fsdp_axis_idx = 0 is uncommon, a ValueError would be safer than silent data loss.

Suggested change
if len(wgrad_spec) > 0:
wgrad_spec[0] = ep_resource
if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis
if len(wgrad_spec) > 0:
wgrad_spec[0] = ep_resource
if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
if kernel_fsdp_axis_idx == 0 and ep_resource is not None and ep_resource != kernel_fsdp_mesh_axis:
raise ValueError(
f"kernel_fsdp_axis_idx=0 conflicts with ep_resource={ep_resource!r}; "
f"FSDP and EP cannot share the same kernel dimension."
)
wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis

Comment on lines +277 to +291
def _local_shape_from_spec(global_shape, spec, mesh):
local_shape = []
for dim, axis_spec in zip(global_shape, spec):
axis_size = _axis_spec_size(axis_spec, mesh)
local_shape.append(dim // axis_size)
return tuple(local_shape)


def _axis_spec_size(axis_spec, mesh):
axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,)
axis_size = 1
for axis in axis_tuple:
if axis is not None:
axis_size *= mesh.shape[axis]
return axis_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Utility functions duplicated between gemm.py and quantization.py

_axis_spec_size and _local_shape_from_spec are defined identically in both files. These should live in a shared module and be imported by both callers to avoid silent drift.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant