Conversation
Code Review
This pull request introduces "TPSP mix mode" to integrate sequence parallelism with tensor parallelism, optimizing microbatch overlap performance. It refactors BaseLayerInfer with unified communication helpers and updates CudaGraph to support TP-aligned batching. Feedback points out a calculation error in microbatch_overlap_prefill regarding prefix tokens, a logically flawed assertion in the sp_pad_copy kernel that causes unavoidable failures, and a missing assignment of gathered data in Llama's post-inference. Additionally, caching environment flags is recommended to improve efficiency in hot inference paths.
```python
infer_handle_token_num0 = triton.cdiv(model_input0.total_token_num, self.tp_world_size_) * self.tp_world_size_
infer_handle_token_num1 = triton.cdiv(model_input1.total_token_num, self.tp_world_size_) * self.tp_world_size_
```
In microbatch_overlap_prefill, the calculation of infer_handle_token_num0 and infer_handle_token_num1 uses model_input.total_token_num instead of the actual handle token count (total_token_num - prefix_total_token_num). This is inconsistent with the logic in _prefill (line 485) and will lead to incorrect padding if prefix tokens are present, as the sequence parallelism split should be aligned based on the tokens being processed in the current forward pass.
```diff
- infer_handle_token_num0 = triton.cdiv(model_input0.total_token_num, self.tp_world_size_) * self.tp_world_size_
- infer_handle_token_num1 = triton.cdiv(model_input1.total_token_num, self.tp_world_size_) * self.tp_world_size_
+ infer_handle_token_num0 = triton.cdiv(origin_handle_token_num0, self.tp_world_size_) * self.tp_world_size_
+ infer_handle_token_num1 = triton.cdiv(origin_handle_token_num1, self.tp_world_size_) * self.tp_world_size_
```
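The alignment math at issue can be checked with a small sketch. This is a hypothetical illustration, not the repository's code: `cdiv` mirrors `triton.cdiv`, and `padded_handle_token_num` is an assumed helper name showing why padding must be computed from the tokens handled this pass (total minus prefix) rather than from the raw total.

```python
def cdiv(a: int, b: int) -> int:
    # Mirrors triton.cdiv: ceiling division.
    return (a + b - 1) // b

def padded_handle_token_num(total_token_num: int, prefix_total_token_num: int, tp_world_size: int) -> int:
    # The review's point: the SP split must be aligned on the tokens actually
    # processed in this forward pass, i.e. total - prefix, then rounded up to
    # a multiple of tp_world_size.
    origin_handle_token_num = total_token_num - prefix_total_token_num
    return cdiv(origin_handle_token_num, tp_world_size) * tp_world_size
```

With 10 total tokens, 3 of them prefix, and a TP world size of 4, the correct padded count is 8, whereas padding the raw total would give 12.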
```python
assert (
    in_token_num % sp_world_size == 0
), f"in_token_num % sp_world_size != 0, in_token_num: {in_token_num}, sp_world_size: {sp_world_size}"
```
This assertion is logically flawed. If in_token_num % sp_world_size == 0, the function returns at line 50. Therefore, if execution reaches line 51, the condition in_token_num % sp_world_size == 0 is guaranteed to be false, causing the assertion to always fail. If the intention is to enforce alignment, the padding logic below (lines 54-73) becomes dead code and should be removed. If padding is intended to be supported, this assertion should be removed or corrected.
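The control-flow problem described above can be illustrated with a pure-Python sketch. The function name matches the kernel under review, but the body is an assumed simplification (a Python list standing in for the tensor) used only to show why the assertion is unreachable without failing:

```python
def sp_pad_copy(tokens, sp_world_size, pad_value=0):
    # Hypothetical sketch of the kernel's intended logic.
    in_token_num = len(tokens)
    if in_token_num % sp_world_size == 0:
        return tokens  # already aligned: early return (the review's "line 50")
    # Reaching this point guarantees in_token_num % sp_world_size != 0, so an
    # `assert in_token_num % sp_world_size == 0` here would always fail and
    # make the padding below dead code. Dropping the assert lets padding run:
    pad = sp_world_size - in_token_num % sp_world_size
    return tokens + [pad_value] * pad
```

With the assertion removed, an unaligned input is padded up to the next multiple of `sp_world_size`, which appears to be what the padding branch was written to do.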
```python
)
# len(infer_state.input_ids) gives the real input length
input_embdings = gather_data[0 : len(infer_state.input_ids)]
self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
```
The result of self._tpsp_allgather is ignored. Since _tpsp_allgather returns a new tensor when gathering is performed, input_embdings must be updated with the return value to ensure the subsequent _token_forward call uses the gathered data.
```diff
- self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
+ input_embdings = self._tpsp_allgather(input=input_embdings, infer_state=infer_state)
```
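Why the reassignment matters can be shown with a stand-in for the helper. This is an assumed simplification, not `_tpsp_allgather` itself: a gather's output is `world_size` times larger than each rank's input, so it cannot be written in place and the caller must rebind the name to the returned object.

```python
def tpsp_allgather_sketch(local_chunk, world_size):
    # Hypothetical stand-in: simulate concatenating per-rank chunks by
    # repeating the local chunk world_size times. The point is that the
    # result is a NEW object, so the return value must not be dropped.
    return local_chunk * world_size

chunk = [1.0, 2.0]
tpsp_allgather_sketch(chunk, 2)            # bug: return value dropped, chunk unchanged
chunk = tpsp_allgather_sketch(chunk, 2)    # fix: rebind to the gathered result
```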
```python
raise Exception("need to impl")


def _tpsp_allgather(self, input: torch.Tensor, infer_state: InferStateInfo):
    if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode:
```
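The review summary recommends caching environment flags rather than re-reading them in hot inference paths, and the snippet above calls `get_env_start_args().enable_tpsp_mix_mode` on every `_tpsp_allgather` invocation. A minimal sketch of the caching pattern (the class, constructor parameters, and cached attribute name are all assumptions for illustration):

```python
class BaseLayerInferSketch:
    # Hypothetical sketch: read the env flag once at construction instead of
    # calling get_env_start_args() inside the per-token inference path.
    def __init__(self, tp_world_size: int, enable_tpsp_mix_mode: bool):
        self.tp_world_size_ = tp_world_size
        # Cached once; the snippet above re-reads it on every call.
        self._enable_tpsp_mix_mode = enable_tpsp_mix_mode

    def _needs_tpsp_allgather(self) -> bool:
        return self.tp_world_size_ > 1 and self._enable_tpsp_mix_mode
```

Since start-up arguments do not change after launch, hoisting the read out of the hot path trades one attribute lookup for a repeated function call per forward step.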