@tsu-bin commented Jan 16, 2026

The Context Parallel implementation in diffusers cannot handle the case where the sequence length is not divisible by the mesh size. This issue affects all models. Taking Qwen-Image as an example, before this PR:

  • test script:
from accelerate import PartialState
from diffusers import QwenImagePipeline, ContextParallelConfig
import torch

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
cpc = ContextParallelConfig(ulysses_degree=2)

pipe = QwenImagePipeline.from_pretrained(
    "/hostShare/models/Qwen-Image-Fp8-LoraFused/", 
    torch_dtype={"default": torch.bfloat16},
    use_safetensors=True
)
distributed_state = PartialState()
pipe.transformer.enable_parallelism(config=cpc)
pipe.to(distributed_state.device)

seed = torch.Generator().manual_seed(66)
positive_magic = ", Ultra HD, 4K, cinematic composition."
prompt = '''A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197".'''
image = pipe(
    prompt=prompt + positive_magic,
    width=1664,
    height=928,
    num_inference_steps=8,
    # true_cfg_scale=1.0,
    negative_prompt=None,
    generator=seed).images[0]
if rank == 0:
    image.save("cafa.png")
  • test cmd:
accelerate launch --num_processes=2 ./qwen_image_fp8_lora_fused_CP.py
  • error log snippet:
[rank1]: Traceback (most recent call last):
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/xb_demo/qwen_image_fp8_lora_fused_CP.py", line 28, in <module>
[rank1]:     image = pipe(
[rank1]:             ^^^^^
[rank1]:   File "/hostShare/env_dit0/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py", line 686, in __call__
[rank1]:     noise_pred = self.transformer(
[rank1]:                  ^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/env_dit0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/env_dit0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/models/transformers/transformer_qwenimage.py", line 962, in forward
[rank1]:     image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
[rank1]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/env_dit0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/env_dit0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank1]:     return function_reference.post_forward(module, output)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/hooks/context_parallel.py", line 199, in post_forward
[rank1]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/hooks/context_parallel.py", line 211, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/hostShare/diffusers_all/diffusers_main/src/diffusers/hooks/context_parallel.py", line 261, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
  0%|                                                                                                                                                                                 | 0/8 [00:00<?, ?it/s]
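For context, the assertion that fires above can be illustrated without a multi-GPU run. The snippet below is a minimal, single-process sketch of the precondition checked in EquipartitionSharder.shard; the sequence length 4211 is an arbitrary odd number chosen for illustration, not the one produced by the pipeline above.

import torch

def shard_evenly(tensor: torch.Tensor, dim: int, mesh_size: int):
    # Same precondition as EquipartitionSharder.shard in
    # src/diffusers/hooks/context_parallel.py: every rank must receive
    # an equally sized chunk along `dim`.
    assert tensor.size(dim) % mesh_size == 0, (
        "Tensor size along dimension to be sharded must be divisible by mesh size"
    )
    return tensor.chunk(mesh_size, dim=dim)

x = torch.randn(1, 4211, 64)          # odd sequence length (illustrative)
shard_evenly(x, dim=1, mesh_size=2)   # raises the AssertionError shown above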

The design of this bugfix is rather simple, yet robust and performant:

  • exchange tensor shape info before the all_to_all exchange of the actual tensor data
  • use pre-allocated tensors to cache the tensor shape info and avoid redundant shape exchanges later; note that the shape info is exchanged only once per request, which is the only (and minimal) overhead introduced
  • with the shape info available, tensor data with different shapes can be exchanged via all_to_all (see the sketch after this list)
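
The sketch below illustrates this shape-exchange-then-data-exchange pattern. It is a hypothetical, simplified illustration, not the PR's actual code: the helper name all_to_all_uneven is made up, all_to_all_single is assumed for the size exchange, and the per-request caching of shape info described above is omitted.

import torch
import torch.distributed as dist

def all_to_all_uneven(local_chunks, dim, group=None):
    # Hypothetical sketch, not the PR's implementation.
    world_size = dist.get_world_size(group)
    device = local_chunks[0].device

    # Step 1: exchange only the sizes along `dim`, so each rank knows
    # the shape of every tensor it is about to receive.
    send_sizes = torch.tensor(
        [c.size(dim) for c in local_chunks], device=device, dtype=torch.int64
    )
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes, group=group)

    # Step 2: pre-allocate receive buffers whose size along `dim` matches
    # what each peer will send; all other dims are assumed to be identical.
    recv_chunks = []
    for i in range(world_size):
        shape = list(local_chunks[0].shape)
        shape[dim] = int(recv_sizes[i].item())
        recv_chunks.append(
            torch.empty(shape, dtype=local_chunks[0].dtype, device=device)
        )

    # Step 3: a single all_to_all moves the actual data even though the
    # chunks have different shapes on different ranks.
    dist.all_to_all(recv_chunks, local_chunks, group=group)
    return recv_chunks

The key point is that torch.distributed.all_to_all accepts per-peer tensors of different shapes, so once every rank knows the incoming sizes it can pre-allocate the receive buffers and complete the exchange in a single collective.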
