Add shard exp on fsdp custom mesh #3988

Open
NuojCheng wants to merge 1 commit into main from chengnuojin-fsdp-exp

Conversation

@NuojCheng
Collaborator

@NuojCheng NuojCheng commented May 27, 2026

Description

This PR introduces a new custom mesh and rule that enables FSDP sharding of the exp dimension for weights in the MoE component. When a non-elementwise optimizer such as Muon is enabled, this custom mesh shows a larger benefit.

Compared with the previous implementation of shard_exp_on_fsdp, this custom mesh supports a mixture of FSDP and EP.

Use custom_mesh_and_rule=shard-exp-on-fsdp to enable.

Tests

Performance Regression

On Tpu7x-8 with DSv2-16b, the losses over 10 steps match exactly and performance is similar; see https://diff.googleplex.com/#key=r2vyDl7870kN.

Support on FSDP + EP

Support on explicit sharding

Performance improvement on Muon optimizer

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented May 27, 2026

Codecov Report

❌ Patch coverage is 70.00000% with 9 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/sharding.py 62.50% 3 Missing and 3 partials ⚠️
src/maxtext/layers/moe.py 76.92% 1 Missing and 2 partials ⚠️


@NuojCheng NuojCheng force-pushed the chengnuojin-fsdp-exp branch 6 times, most recently from 0331b14 to 11f85f7, May 27, 2026 21:32
Comment thread src/maxtext/layers/moe.py
return lhs_quantize_dtype, rhs_quantize_dtype

- def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
+ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes=None):
Collaborator

can we do weight_gather_axes = []

Collaborator Author

I tried that first but got a pylint warning.
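
The pylint warning mentioned here is most likely `dangerous-default-value` (W0102), which fires on mutable defaults such as `[]`; that is an assumption, since the thread does not name the warning. A minimal sketch of why `None` is the usual alternative follows. `collect_axes_bad` and `collect_axes` are hypothetical names for illustration, not MaxText functions:

```python
def collect_axes_bad(axis, weight_gather_axes=[]):  # pylint: disable=dangerous-default-value
  # The default list is created once, at definition time, and is shared
  # by every call that omits the argument.
  weight_gather_axes.append(axis)
  return weight_gather_axes


def collect_axes(axis, weight_gather_axes=None):
  # Normalize None to a fresh list on each call, so calls stay independent.
  axes = [] if weight_gather_axes is None else list(weight_gather_axes)
  axes.append(axis)
  return axes


# Copy the results so later calls cannot change what we recorded.
first_bad = list(collect_axes_bad("fsdp"))
second_bad = list(collect_axes_bad("fsdp_transpose"))  # still carries "fsdp"
first_ok = collect_axes("fsdp")
second_ok = collect_axes("fsdp_transpose")
```

With a `None` default, the function body can still treat a missing argument as an empty list, which gives the behavior the `[]` suggestion was after without the shared-state hazard.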


def remove_fsdp_pspec(pspec):
  """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec."""
  if isinstance(pspec, jax.sharding.PartitionSpec):
Collaborator

in what scenario is it not a jax.sharding.PartitionSpec type?

Collaborator Author

In some cases it might be None. Since it is a shared function in sharding.py, I tried to make this function as general-purpose as possible.

Collaborator

nit, can we change it to:

if pspec is None:
  return pspec
new_spec = []
...
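
A minimal sketch of the early-return shape suggested above, under stated assumptions: plain tuples stand in for `jax.sharding.PartitionSpec` so the example runs without JAX, `remove_fsdp_entries` is a hypothetical name, and the handling of tuple entries is a guess at reasonable behavior rather than the PR's actual implementation. The axis names "expert" and "tensor" are illustrative only.

```python
FSDP_AXES = ("fsdp", "fsdp_transpose")


def remove_fsdp_entries(entries):
  """Drops FSDP axis names from a sequence of partition-spec entries."""
  if entries is None:
    # Early return keeps the main path unindented.
    return entries
  new_spec = []
  for entry in entries:
    if isinstance(entry, str):
      # A single axis name: an FSDP axis becomes None (unsharded).
      new_spec.append(None if entry in FSDP_AXES else entry)
    elif isinstance(entry, (tuple, list)):
      # Several mesh axes share this dimension: keep the non-FSDP ones.
      kept = tuple(axis for axis in entry if axis not in FSDP_AXES)
      new_spec.append(kept if kept else None)
    else:
      # None (already unsharded) passes through unchanged.
      new_spec.append(entry)
  return tuple(new_spec)


spec = ("expert", ("fsdp", "tensor"), "fsdp_transpose", None)
cleaned = remove_fsdp_entries(spec)
untouched = remove_fsdp_entries(None)
```

In real code the result would presumably be rebuilt as a spec object, for example `jax.sharding.PartitionSpec(*cleaned)`.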

Comment thread src/maxtext/layers/moe.py
w1_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
wo_pspec = self._logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
# Update kernel pspec for FSDP AG
w0_pspec = remove_fsdp_pspec(w0_pspec)
Collaborator

For my own understanding, why remove fsdp from pspec?

Collaborator Author

For the sparse matmul wrapper function, we all-gather the weights along the FSDP axes before entering the shard map, so the kernel pspec passed to it no longer includes fsdp. I think it is just a decision made earlier.

@github-actions

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @NuojCheng, but I was unable to process your request. Please see the logs for more details.

