Add shard exp on fsdp custom mesh#3988
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
0331b14 to
11f85f7
Compare
11f85f7 to
7385158
Compare
| return lhs_quantize_dtype, rhs_quantize_dtype | ||
|
|
||
| def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes): | ||
| def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes=None): |
There was a problem hiding this comment.
can we do weight_gather_axes = []
There was a problem hiding this comment.
Tried it first but got some pylint warning..
|
|
||
| def remove_fsdp_pspec(pspec): | ||
| """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" | ||
| if isinstance(pspec, jax.sharding.PartitionSpec): |
There was a problem hiding this comment.
in what scenario is it not a jax.sharding.PartitionSpec type?
There was a problem hiding this comment.
in some cases it might be none. Since it is a shared funciton in sharding.py, I tried to make this funciton as general-purpose as possible.
There was a problem hiding this comment.
nit, can we change it to:
if pspec == None:
return psepc
new_spec = []
...
| w1_pspec = self._logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp")) | ||
| wo_pspec = self._logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose")) | ||
| # Update kernel pspec for FSDP AG | ||
| w0_pspec = remove_fsdp_pspec(w0_pspec) |
There was a problem hiding this comment.
For my own understanding, why remove fsdp from pspec?
There was a problem hiding this comment.
For the sparse matmul wrapper function, we FSDP all gather weights before starting the shard map. I think it is just a decision made earlier.
|
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
|
🤖 I'm sorry @NuojCheng, but I was unable to process your request. Please see the logs for more details. |
|
|
||
| def remove_fsdp_pspec(pspec): | ||
| """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" | ||
| if isinstance(pspec, jax.sharding.PartitionSpec): |
There was a problem hiding this comment.
nit, can we change it to:
if pspec == None:
return psepc
new_spec = []
...
Description
This PR introduces a new custom mesh and rule that enabling FSDP sharding exp dimension for weights in MoE component. When non-elementwise optimization is enabled, e.g. Muon, this custom mesh show better benefits.
Compared with previous implementation of
shard_exp_on_fsdp, this custom mesh support mixture of FSDP and EP.Use
custom_mesh_and_rule=shard-exp-on-fsdpto enable.Tests
Performance Regression
Tpu7x-8 on DSv2-16b, the losses of 10 steps perfectly match and performance are similar, see https://diff.googleplex.com/#key=r2vyDl7870kN.
Support on FSDP + EP
Support on explicit sharding
Performance improvement on Muon optimizer
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.