32 changes: 32 additions & 0 deletions docs/source/en/api/pipelines/qwenimage.md
@@ -95,6 +95,8 @@ image.save("qwen_fewsteps.png")

With [`QwenImageEditPlusPipeline`], one can provide multiple images as input references.

### Single prompt with multiple reference images

```py
import torch
from PIL import Image
@@ -114,6 +116,36 @@ image = pipe(
).images[0]
```

### Batch processing with multiple prompts

The pipeline also supports batch processing, so you can edit multiple images with different prompts in a single call. Use a nested list of the form `[[img1], [img2]]` to provide the input images for each prompt:

```py
import torch
from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
).to("cuda")

# Load input images
mountain_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mountain.jpg")

# Process two different edits in a single batch
images = pipe(
image=[[mountain_image], [mountain_image]], # Nested list for batch_size=2
prompt=[
"Transform into a sunset scene with warm orange and pink sky",
"Add snow and make it a winter scene"
],
num_inference_steps=50
).images

# images[0] contains the sunset version
# images[1] contains the winter version
```
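
The reference images need not be the same for every prompt; each sublist can hold a different image. Below is a minimal sketch continuing the example above (the second image URL is a hypothetical placeholder, swap in your own asset):

```py
# Continuing with `pipe` and `mountain_image` from the previous snippet.
city_image = load_image("https://path/to/your/city.jpg")  # hypothetical URL

images = pipe(
    image=[[mountain_image], [city_image]],  # one sublist of reference images per prompt
    prompt=[
        "Make the mountain erupt like a volcano",
        "Turn the city into a rainy night scene",
    ],
    num_inference_steps=50
).images

images[0].save("volcano.png")
images[1].save("rainy_city.png")
```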

## Performance

### torch.compile
155 changes: 107 additions & 48 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -410,6 +410,32 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

return latents

def _preprocess_image_list(self, images):
"""
Preprocess a list of PIL images for both the condition encoder and the VAE.

Args:
images: List of PIL images

Returns:
Tuple of (condition_sizes, condition_images, vae_sizes, vae_images)
"""
condition_sizes = []
condition_images = []
vae_sizes = []
vae_images = []

for img in images:
image_width, image_height = img.size
condition_width, condition_height = calculate_dimensions(CONDITION_IMAGE_SIZE, image_width / image_height)
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
condition_sizes.append((condition_width, condition_height))
vae_sizes.append((vae_width, vae_height))
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

return condition_sizes, condition_images, vae_sizes, vae_images

# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
@@ -434,6 +460,18 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):

return image_latents

def _encode_and_pack_image(self, image, num_channels_latents, device, dtype, generator):
"""Encode a single image and pack it. Returns packed latents."""
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
img_latents = self._encode_vae_image(image=image, generator=generator)
else:
img_latents = image

image_latent_height, image_latent_width = img_latents.shape[3:]
img_latents = self._pack_latents(img_latents, 1, num_channels_latents, image_latent_height, image_latent_width)
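# After packing (2x2 patchification), the latents for a single image have shape
# (1, (H/2) * (W/2), num_channels_latents * 4); shown here for illustration.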
return img_latents

def prepare_latents(
self,
images,
@@ -457,30 +495,28 @@
if images is not None:
if not isinstance(images, list):
images = [images]
all_image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)

image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
)
all_image_latents.append(image_latents)
image_latents = torch.cat(all_image_latents, dim=1)
# Check if nested list (batch_size > 1): [[img1, img2], [img3, img4]]
is_nested = images and isinstance(images[0], list)

if is_nested:
# batch_size > 1: Process each batch item separately
batch_image_latents = []
for batch_images in images:
batch_item_latents = [
self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator)
for img in batch_images
]
# Concatenate all images for this batch item along sequence dimension
batch_image_latents.append(torch.cat(batch_item_latents, dim=1))
# Concatenate batch items along dim 0 to form the final batch dimension
image_latents = torch.cat(batch_image_latents, dim=0)
else:
# batch_size == 1: Process flat list [img1, img2]
all_image_latents = [
self._encode_and_pack_image(img, num_channels_latents, device, dtype, generator) for img in images
]
image_latents = torch.cat(all_image_latents, dim=1)
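# Resulting shapes (illustrative): a nested input yields (batch_size, seq_len, C)
# with one row per prompt, while a flat list yields (1, total_seq_len, C) with
# all reference images concatenated along the sequence dimension.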

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
Expand Down Expand Up @@ -546,12 +582,15 @@ def __call__(
Function invoked when calling the pipeline for generation.

Args:
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, or `List[List[PIL.Image.Image]]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
latents as `image`, but if passing latents directly it is not encoded again. For batch processing with
multiple prompts (batch_size > 1), provide a nested list where each sublist contains the input images
for that prompt: `[[img1_for_prompt1], [img2_for_prompt2]]`. For a single prompt with multiple
reference images (batch_size == 1), use a flat list: `[img1, img2]`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
@@ -630,7 +669,17 @@ def __call__(
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
image_size = image[-1].size if isinstance(image, list) else image.size
# Handle both flat list [img1, img2] and nested list [[img1, img2], [img3, img4]]
if isinstance(image, list):
# Check if nested list (batch_size > 1)
if isinstance(image[0], list):
# Use last image from first batch item
image_size = image[0][-1].size
else:
# Flat list (batch_size == 1)
image_size = image[-1].size
else:
image_size = image.size
calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
height = height or calculated_height
width = width or calculated_width
@@ -666,32 +715,38 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]

# QwenImageEditPlusPipeline does not currently support batch_size > 1
if batch_size > 1:
raise ValueError(
f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
"Please process prompts one at a time."
)

device = self._execution_device
# 3. Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
if not isinstance(image, list):
image = [image]
condition_image_sizes = []
condition_images = []
vae_image_sizes = []
vae_images = []
for img in image:
image_width, image_height = img.size
condition_width, condition_height = calculate_dimensions(
CONDITION_IMAGE_SIZE, image_width / image_height

# Check if nested list (batch_size > 1) or flat list (batch_size == 1)
is_nested = isinstance(image[0], list)

if is_nested:
if batch_size > 1 and len(image) != batch_size:
raise ValueError(
f"Image batch_size ({len(image)}) must match batch_size for prompts ({batch_size}) for batch inference."
)
# batch_size > 1: image = [[img1, img2], [img3, img4]]
# Process each batch item separately
condition_image_sizes = []
condition_images = []
vae_image_sizes = []
vae_images = []

for batch_images in image:
cond_sizes, cond_imgs, vae_szs, vae_imgs = self._preprocess_image_list(batch_images)
condition_image_sizes.append(cond_sizes)
condition_images.append(cond_imgs)
vae_image_sizes.append(vae_szs)
vae_images.append(vae_imgs)
else:
# batch_size == 1: image = [img1, img2]
condition_image_sizes, condition_images, vae_image_sizes, vae_images = self._preprocess_image_list(
image
)
vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
condition_image_sizes.append((condition_width, condition_height))
vae_image_sizes.append((vae_width, vae_height))
condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))

has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -740,15 +795,19 @@ def __call__(
generator,
latents,
)
# Build img_shapes for each batch item (avoid shared references!)
# Normalize vae_image_sizes to nested list format for uniform processing
sizes_list = vae_image_sizes if is_nested else [vae_image_sizes]
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
*[
(1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
for vae_width, vae_height in vae_image_sizes
for vae_width, vae_height in batch_vae_sizes
],
]
] * batch_size
for batch_vae_sizes in sizes_list
]
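# Illustrative example (hypothetical values): with batch_size=2, one square
# reference image per prompt, height=width=1024, and vae_scale_factor=8, each
# batch item gets [(1, 64, 64), (1, 64, 64)] -- the output latent grid followed
# by one entry per reference image.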

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
54 changes: 46 additions & 8 deletions tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -240,14 +240,52 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt():
@pytest.mark.xfail(
condition=True,
reason="num_images_per_prompt > 1 is not yet supported for EditPlus pipeline",
strict=True,
)
def test_num_images_per_prompt(self):
super().test_num_images_per_prompt()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_consistent():
super().test_inference_batch_consistent()
def test_inference_batch_single_identical(self):
# Test that batch_size=1 gives identical results to non-batched inference
self._test_inference_batch_single_identical(expected_max_diff=1e-3)

def test_inference_batch_consistent(self):
# Test that batched inference gives consistent results
self._test_inference_batch_consistent()

def test_batch_processing_multiple_prompts(self):
# Test batch processing with multiple prompts (batch_size > 1)
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

if str(device).startswith("mps"):
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=device).manual_seed(0)

image = Image.new("RGB", (32, 32))

# Test with nested list format for batch_size=2
inputs = {
"prompt": ["dance monkey", "jump around"],
"image": [[image], [image]], # Nested list for batch_size=2
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}

images = pipe(**inputs).images

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_single_identical():
super().test_inference_batch_single_identical()
# Should return 2 images (batch_size=2)
self.assertEqual(len(images), 2)
self.assertEqual(images[0].shape, (3, 32, 32))
self.assertEqual(images[1].shape, (3, 32, 32))
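
# A sketch of a companion test one might add (hypothetical, not part of this
# change): the pipeline should raise a ValueError when the number of image
# sublists does not match the number of prompts.
def test_batch_size_mismatch_raises(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to("cpu")
    pipe.set_progress_bar_config(disable=None)

    image = Image.new("RGB", (32, 32))
    with self.assertRaises(ValueError):
        pipe(
            prompt=["dance monkey", "jump around"],
            image=[[image]],  # one sublist for two prompts -> mismatch
            num_inference_steps=2,
            height=32,
            width=32,
            max_sequence_length=16,
            output_type="pt",
        )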