```diff
@@ -4072,6 +4072,7 @@ def attn_forward_func_with_cp(
     enable_mla = k.shape[-1] != v.shape[-1]
     assert not enable_mla or cp_comm_type in [
         "p2p",
+        "a2a",
         "a2a+p2p",
     ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
```
P1: No tests added for the new MLA + a2a combination

The PR author acknowledges in the checklist that tests have not been added. The new "a2a" entry in the MLA allowlist enables a code path (AttnFuncWithCPAndQKVOA2A + MLA) that has not been exercised by any automated test. While the implementation appears structurally sound (the A2A communication handles q, k, v independently so asymmetric head dims are preserved), the lack of coverage means regressions in forward pass accuracy, backward pass gradients, or FP8-quantized variants could go undetected.

Consider adding a test case (similar to existing CP x MLA tests) that covers:

  • cp_comm_type="a2a" with k.shape[-1] != v.shape[-1]
  • Both use_fused_attention=True and use_fused_attention=False variants
  • Gradient correctness (torch.autograd.gradcheck or compare against a non-CP reference)
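A test along those lines would drive the real CP attention path, which requires a multi-GPU setup. As a minimal standalone sketch of just the allowlist logic being changed, the check from the diff can be mirrored with plain shape tuples; the helper name and shapes below are hypothetical, not part of the library:

```python
# Hypothetical standalone sketch (not the library's actual test suite): it
# mirrors the MLA allowlist check from the diff using plain shape tuples,
# since exercising the real AttnFuncWithCPAndQKVOA2A path needs multiple GPUs.
MLA_CP_COMM_TYPES = ["p2p", "a2a", "a2a+p2p"]  # allowlist after this change

def check_mla_cp_support(k_shape, v_shape, cp_comm_type):
    """Return True when MLA is active (k/v head dims differ); assert support."""
    enable_mla = k_shape[-1] != v_shape[-1]
    assert not enable_mla or cp_comm_type in MLA_CP_COMM_TYPES, (
        f"Context parallelism does not support MLA with {cp_comm_type=}!"
    )
    return enable_mla

# MLA case: k head dim 192 != v head dim 128, now accepted with "a2a"
assert check_mla_cp_support((2, 8, 16, 192), (2, 8, 16, 128), "a2a") is True
```

A real test would additionally parametrize over `use_fused_attention` and compare outputs and gradients against a single-GPU (non-CP) reference run, as the bullets above suggest.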
