[JAX] Grouped quant+GEMM custom partitioning rules #3058
jberchtold-nvidia wants to merge 4 commits into
…mm-custom-partition-rules
Greptile Summary
This PR adds JAX custom partitioning rules (
Confidence Score: 3/5
The core partitioning logic is new and complex; the backward FSDP path in dense.py has a silent EP-axis overwrite that could produce an incorrect sharding constraint in an edge case, and the 1-D size heuristic in gemm.py assumes a specific sharding orientation without enforcing it. The EP-axis overwrite in
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
    participant JAX as JAX Compiler
    participant QP as GroupedQuantizePrimitive.partition
    participant GP as GroupedGemmPrimitive.partition
    participant SI as sharded_impl (per device)
    JAX->>QP: (x, scale, group_sizes) with EP+FSDP specs
    QP->>QP: _parse_partition_specs() derive flat/group output specs
    QP-->>JAX: arg_shardings, out_shardings, sharded_impl
    JAX->>SI: local x shard
    SI->>SI: GroupedQuantizePrimitive.impl()
    SI->>SI: _pad_or_slice_to_shape(scale_inv, local_shape)
    SI->>SI: pmax amax over dp/fsdp
    SI-->>JAX: rowwise_out, colwise_out, scale_invs, amax
    JAX->>GP: lhs, rhs_weight with EP+FSDP specs
    GP->>GP: inject EP into rhs dim-0 if missing
    GP->>GP: strip FSDP from rhs gather_rhs_fsdp
    GP-->>JAX: arg_shardings, out_sharding, sharded_impl
    JAX->>SI: local lhs/rhs shards
    SI->>SI: GroupedGemmPrimitive.impl() local GEMM
    SI->>SI: psum(out, reduce_axis) if needed
    SI-->>JAX: out
    Note over JAX,SI: Backward pass inside shard_map only
    JAX->>SI: _grouped_dense_bwd_rule
    SI->>SI: _is_manual_mesh_axis check
    alt FSDP axis is manual
        SI->>SI: allgather or psum dgrad
        SI->>SI: psum_scatter or psum wgrad
    end
    SI->>SI: with_sharding_constraint wgrad EP+FSDP spec
```
```python
if len(wgrad_spec) > 0:
    wgrad_spec[0] = ep_resource
if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
    wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis
```
EP axis silently dropped when kernel_fsdp_axis_idx == 0
wgrad_spec[0] is first set to ep_resource and then immediately overwritten by kernel_fsdp_mesh_axis when kernel_fsdp_axis_idx == 0. The sequential assignments silently discard the EP constraint in that edge case. While kernel_fsdp_axis_idx == 0 is uncommon, raising a ValueError would be safer than silently dropping the EP sharding constraint.
```diff
 if len(wgrad_spec) > 0:
     wgrad_spec[0] = ep_resource
 if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
+    if kernel_fsdp_axis_idx == 0 and ep_resource is not None and ep_resource != kernel_fsdp_mesh_axis:
+        raise ValueError(
+            f"kernel_fsdp_axis_idx=0 conflicts with ep_resource={ep_resource!r}; "
+            f"FSDP and EP cannot share the same kernel dimension."
+        )
     wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis
```
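To make the edge case concrete, here is a minimal, framework-free sketch of the two assignments. The function `build_wgrad_spec`, its `ndim` and `strict` parameters, and the axis names `"ep"` and `"fsdp"` are hypothetical and used only for illustration; they are not part of the PR. Plain Python lists stand in for the partition spec.

```python
def build_wgrad_spec(ndim, ep_resource, kernel_fsdp_axis_idx,
                     kernel_fsdp_mesh_axis, strict=False):
    """Hypothetical helper that mirrors the assignments under review.

    With strict=False it reproduces the current behavior; with strict=True
    it applies the guard proposed in the suggestion.
    """
    wgrad_spec = [None] * ndim
    if len(wgrad_spec) > 0:
        wgrad_spec[0] = ep_resource
    if 0 <= kernel_fsdp_axis_idx < len(wgrad_spec):
        conflict = (
            kernel_fsdp_axis_idx == 0
            and ep_resource is not None
            and ep_resource != kernel_fsdp_mesh_axis
        )
        if strict and conflict:
            raise ValueError(
                f"kernel_fsdp_axis_idx=0 conflicts with ep_resource={ep_resource!r}"
            )
        # Without the guard, this overwrites the EP entry when the index is 0.
        wgrad_spec[kernel_fsdp_axis_idx] = kernel_fsdp_mesh_axis
    return tuple(wgrad_spec)


# Usual case: EP on dim 0, FSDP on dim 1. Both constraints survive.
print(build_wgrad_spec(3, "ep", 1, "fsdp"))  # ('ep', 'fsdp', None)

# Edge case: FSDP index 0 replaces the EP entry without any warning.
print(build_wgrad_spec(3, "ep", 0, "fsdp"))  # ('fsdp', None, None)
```

Calling the same helper with `strict=True` in the edge case raises `ValueError` instead of returning a spec that has lost its EP entry.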
```python
def _local_shape_from_spec(global_shape, spec, mesh):
    local_shape = []
    for dim, axis_spec in zip(global_shape, spec):
        axis_size = _axis_spec_size(axis_spec, mesh)
        local_shape.append(dim // axis_size)
    return tuple(local_shape)


def _axis_spec_size(axis_spec, mesh):
    axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,)
    axis_size = 1
    for axis in axis_tuple:
        if axis is not None:
            axis_size *= mesh.shape[axis]
    return axis_size
```
Utility functions duplicated between gemm.py and quantization.py
_axis_spec_size and _local_shape_from_spec are defined identically in both files. These should live in a shared module and be imported by both callers to avoid silent drift.
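As a rough illustration of the shared-module idea, the sketch below puts both helpers in one place and exercises them against a stand-in mesh. The public names `axis_spec_size` and `local_shape_from_spec`, the `SimpleNamespace` stand-in for a mesh, and the axis names and sizes are assumptions made for this example; the only property relied on is that `mesh.shape` maps axis names to sizes, as the quoted code already assumes.

```python
from types import SimpleNamespace


def axis_spec_size(axis_spec, mesh):
    """Product of mesh axis sizes named by one partition-spec entry."""
    axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,)
    axis_size = 1
    for axis in axis_tuple:
        if axis is not None:
            axis_size *= mesh.shape[axis]
    return axis_size


def local_shape_from_spec(global_shape, spec, mesh):
    """Per-device shape obtained by dividing each dim by its axis size."""
    return tuple(
        dim // axis_spec_size(axis_spec, mesh)
        for dim, axis_spec in zip(global_shape, spec)
    )


# Stand-in for a device mesh: only the name -> size mapping is needed here.
mesh = SimpleNamespace(shape={"ep": 2, "fsdp": 4})

# One mesh axis per dimension; None means the dimension is replicated.
print(local_shape_from_spec((8, 64, 128), ("ep", "fsdp", None), mesh))  # (4, 16, 128)

# A tuple entry shards one dimension over several mesh axes at once.
print(local_shape_from_spec((8, 64), (("ep", "fsdp"), None), mesh))  # (1, 64)
```

Both gemm.py and quantization.py could then import these two functions from a single module, so that a later fix (for example, rejecting dimensions that are not divisible by the axis size) only has to be made once.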