Wzj tpsp#1269

Open
hiworldwzj wants to merge 25 commits into main from wzj_tpsp

Conversation

@hiworldwzj
Collaborator

No description provided.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces "TPSP mix mode" to integrate sequence parallelism with tensor parallelism, optimizing microbatch overlap performance. It refactors BaseLayerInfer with unified communication helpers and updates CudaGraph to support TP-aligned batching. Feedback points out a calculation error in microbatch_overlap_prefill regarding prefix tokens, a logically flawed assertion in the sp_pad_copy kernel that causes unavoidable failures, and a missing assignment of gathered data in Llama's post-inference. Additionally, caching environment flags is recommended to improve efficiency in hot inference paths.

Comment on lines +673 to +674
infer_handle_token_num0 = triton.cdiv(model_input0.total_token_num, self.tp_world_size_) * self.tp_world_size_
infer_handle_token_num1 = triton.cdiv(model_input1.total_token_num, self.tp_world_size_) * self.tp_world_size_
Contributor


high

In microbatch_overlap_prefill, the calculation of infer_handle_token_num0 and infer_handle_token_num1 uses model_input.total_token_num instead of the actual handle token count (total_token_num - prefix_total_token_num). This is inconsistent with the logic in _prefill (line 485) and will lead to incorrect padding if prefix tokens are present, as the sequence parallelism split should be aligned based on the tokens being processed in the current forward pass.

Suggested change
infer_handle_token_num0 = triton.cdiv(model_input0.total_token_num, self.tp_world_size_) * self.tp_world_size_
infer_handle_token_num1 = triton.cdiv(model_input1.total_token_num, self.tp_world_size_) * self.tp_world_size_
infer_handle_token_num0 = triton.cdiv(origin_handle_token_num0, self.tp_world_size_) * self.tp_world_size_
infer_handle_token_num1 = triton.cdiv(origin_handle_token_num1, self.tp_world_size_) * self.tp_world_size_
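The alignment idiom under discussion, `triton.cdiv(n, d) * d`, rounds a token count up to the next multiple of the TP world size. A minimal sketch of that arithmetic, with illustrative numbers (the variable values are assumptions, not taken from the PR):

```python
# Hedged sketch: triton.cdiv(n, d) * d rounds n up to the next multiple of d.
def round_up_to_multiple(n: int, d: int) -> int:
    """Equivalent to triton.cdiv(n, d) * d: ceil-divide, then multiply back."""
    return -(-n // d) * d  # ceil division via negated floor division

# The review's point: align on the tokens handled in this forward pass,
# not on the raw total that still includes prefix (cached) tokens.
total_token_num = 1000          # illustrative value
prefix_total_token_num = 37     # illustrative value
tp_world_size = 8

origin_handle_token_num = total_token_num - prefix_total_token_num  # 963
infer_handle_token_num = round_up_to_multiple(origin_handle_token_num, tp_world_size)
print(infer_handle_token_num)  # 968: 963 padded up to a multiple of 8
```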

Comment on lines +51 to +53
assert (
in_token_num % sp_world_size == 0
), f"in_token_num % sp_world_size != 0, in_token_num: {in_token_num}, sp_world_size: {sp_world_size}"
Contributor


high

This assertion is logically flawed. If in_token_num % sp_world_size == 0, the function returns at line 50. Therefore, if execution reaches line 51, the condition in_token_num % sp_world_size == 0 is guaranteed to be false, causing the assertion to always fail. If the intention is to enforce alignment, the padding logic below (lines 54-73) becomes dead code and should be removed. If padding is intended to be supported, this assertion should be removed or corrected.
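If padding is indeed intended, the corrected control flow would return early on the aligned case and otherwise pad up, with no assertion in between. A minimal list-based sketch of that shape (the function name and signature are illustrative, not the repo's actual kernel API, which operates on tensors):

```python
# Hedged sketch of the implied fix: early return when already aligned,
# otherwise pad the input up to a multiple of sp_world_size. The always-false
# assert from the review is simply gone, so the padding path is reachable.
def sp_pad_copy(tokens: list, sp_world_size: int) -> list:
    in_token_num = len(tokens)
    if in_token_num % sp_world_size == 0:
        return tokens  # already aligned: no padding needed
    # Padding path (dead code in the original, guarded by the flawed assert).
    padded_num = -(-in_token_num // sp_world_size) * sp_world_size
    return tokens + [0] * (padded_num - in_token_num)

print(len(sp_pad_copy(list(range(10)), 4)))  # 12
```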

)
# len(infer_state.input_ids) gives the real input length
input_embdings = gather_data[0 : len(infer_state.input_ids)]
self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
Contributor


high

The result of self._tpsp_allgather is ignored. Since _tpsp_allgather returns a new tensor when gathering is performed, input_embdings must be updated with the return value to ensure the subsequent _token_forward call uses the gathered data.

Suggested change
self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
input_embdings = self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
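This is a common bug class with gather-style helpers: when the helper returns a new tensor rather than mutating in place, calling it without rebinding silently keeps the pre-gather shard. A toy illustration (the names are illustrative, not LightLLM's actual API):

```python
# Hedged illustration: a helper that returns a NEW object when it gathers.
# Calling it without binding the result discards the gathered data.
def tpsp_allgather(chunk: list, world_size: int) -> list:
    # Stand-in for an all-gather: the local chunk repeated world_size times.
    return chunk * world_size

local = [1, 2]

tpsp_allgather(local, 2)          # bug: gathered result is thrown away
print(local)                      # [1, 2] -- still only the local shard

local = tpsp_allgather(local, 2)  # fix: rebind to the returned value
print(local)                      # [1, 2, 1, 2]
```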

raise Exception("need to impl")

def _tpsp_allgather(self, input: torch.Tensor, infer_state: InferStateInfo):
if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode:
Contributor


medium

Calling get_env_start_args() repeatedly in the hot path (inside _tpsp_allgather, _tpsp_reduce, and _tpsp_sp_split) for every layer is inefficient. It is better to cache the enable_tpsp_mix_mode flag in the __init__ method of BaseLayerInfer or pass it through the infer_state.
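The suggested caching pattern can be sketched as follows. Here `get_env_start_args` is stubbed for self-containment; in the repo it returns the parsed server start arguments, and the `BaseLayerInfer` constructor shown is a simplified assumption, not the actual class:

```python
# Hedged sketch of the suggestion: read enable_tpsp_mix_mode once in __init__
# instead of calling get_env_start_args() on every layer in the hot path.
class _StartArgs:
    enable_tpsp_mix_mode = True

def get_env_start_args():  # stub standing in for the real env-args accessor
    return _StartArgs()

class BaseLayerInfer:
    def __init__(self, tp_world_size: int):
        self.tp_world_size_ = tp_world_size
        # Cached once; hot-path methods read the attribute, not the env helper.
        self.enable_tpsp_mix_mode_ = get_env_start_args().enable_tpsp_mix_mode

    def _tpsp_allgather_enabled(self) -> bool:
        return self.tp_world_size_ > 1 and self.enable_tpsp_mix_mode_

layer = BaseLayerInfer(tp_world_size=4)
print(layer._tpsp_allgather_enabled())  # True
```

Passing the flag through `infer_state`, as the review also suggests, trades one attribute lookup for another but keeps per-request configuration possible.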
