
feat: support ring attention with arbitrary KV sequence lengths#13545

Open
songh11 wants to merge 4 commits into huggingface:main from songh11:ring_attn_anything

Conversation

@songh11
Contributor

@songh11 songh11 commented Apr 22, 2026

What does this PR do?

Adds a new "Ring Anything" context-parallel attention mode that supports arbitrary
(non-evenly divisible) KV sequence lengths across ring-degree workers.

Motivation

Existing TemplatedRingAttention requires KV to be equipartitioned across ranks,
which is impractical for real-world workloads where per-rank sequence lengths can
differ (e.g., variable-length prompts, packed batches, token pruning). This PR
mirrors the existing ulysses_anything design but applies it to the ring path.

Changes

  • ContextParallelConfig: add ring_anything flag with validation
    (ring_degree > 1 and ulysses_degree == 1).
  • TemplatedRingAnythingAttention: new autograd Function that
    • all-gathers per-rank KV sequence lengths,
    • pads local KV to the global max length,
    • all-gathers padded KV and slices back to each rank's true length,
    • runs online-softmax merging as in ring attention.
  • _templated_context_parallel_attention: dispatch to the new class when
    ring_anything is enabled.
  • ContextParallelSplitHook: route through PartitionAnythingSharder.shard_anything
    when ring_anything is set.
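
The pad → all-gather → slice scheme described above can be sketched in a single process. This is only an illustration of the idea, not code from the PR: `gather_varlen_kv` is a hypothetical name, and a plain Python list of per-rank tensors stands in for the `dist.all_gather` collectives.

```python
import torch
import torch.nn.functional as F


def gather_varlen_kv(local_kvs):
    # Single-process stand-in for the distributed scheme: the list index
    # plays the role of the rank, and list operations replace collectives.
    lengths = [kv.shape[0] for kv in local_kvs]        # "all-gather" per-rank lengths
    s_max = max(lengths)                               # global max KV length
    # pad each rank's KV along the sequence dim to the global max
    padded = [F.pad(kv, (0, 0, 0, s_max - kv.shape[0])) for kv in local_kvs]
    gathered = torch.stack(padded)                     # "all-gather" the padded KV
    # slice back to each rank's true length so padding zeros never enter attention
    return [gathered[r, :lengths[r]] for r in range(len(local_kvs))]
```

The round trip is lossless for any mix of lengths, which is what lets the ring loop run online-softmax merging on true-length chunks afterwards.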

Reproducible example

import argparse

import torch
import torch.distributed as dist
from diffusers import QwenImagePipeline
from diffusers.models._modeling_parallel import ContextParallelConfig


model_name = "Qwen/Qwen-Image-2512/"

# Generate image
prompt = '''A 20-year-old East Asian girl with delicate, charming features and large, bright brown eyes—expressive and lively, with a cheerful or subtly smiling expression. Her naturally wavy long hair is either loose or tied in twin ponytails. She has fair skin and light makeup accentuating her youthful freshness. She wears a modern, cute dress or relaxed outfit in bright, soft colors—lightweight fabric, minimalist cut. She stands indoors at an anime convention, surrounded by banners, posters, or stalls. Lighting is typical indoor illumination—no staged lighting—and the image resembles a casual iPhone snapshot: unpretentious composition, yet brimming with vivid, fresh, youthful charm.'''
negative_prompt = "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。"  # "Low resolution, low quality, deformed limbs, deformed fingers, oversaturated, waxy look, faces without detail, overly smooth, AI-looking image. Chaotic composition. Blurry, distorted text."

# Generate with different aspect ratios
aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1104),
    "3:4": (1104, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}

def main():
    parser = argparse.ArgumentParser()
    # defaults match the launch command below; ring_anything requires
    # ring_degree > 1 and ulysses_degree == 1
    parser.add_argument("--ring",    type=int, default=4)
    parser.add_argument("--ulysses", type=int, default=1)
    args = parser.parse_args()

    # 1. init dist — torchrun injects RANK / LOCAL_RANK / WORLD_SIZE automatically
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # 2. load model
    pipe = QwenImagePipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    pipe.to(f"cuda:{local_rank}")

    # 3. attn backend
    if args.ring > 1:
        pipe.transformer.set_attention_backend("_native_cudnn")
        print("enable _native_cudnn backend")

    # 4. build the context-parallel config and enable CP on the transformer
    parallelism_config = ContextParallelConfig(
        ring_degree=args.ring,
        ulysses_degree=args.ulysses,
        ulysses_anything=False,
        ring_anything=True,
    )
    pipe.transformer.enable_parallelism(config=parallelism_config)

    # 5. inference — all ranks must enter together for collectives to align
    width, height = aspect_ratios["16:9"]
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=50,
        true_cfg_scale=1.0,
        generator=torch.Generator(device="cuda").manual_seed(42)
    ).images[0]

    # 6. save on rank 0 only
    if dist.get_rank() == 0:
        image.save(f"qwen_image_cp_ring_{args.ring}_ulysses_{args.ulysses}.png")
        print(f"Saved qwen_image_cp_ring_{args.ring}_ulysses_{args.ulysses}.png")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Launch

torchrun --nproc_per_node=4 demo.py --ulysses 1 --ring 4

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @asomoza @sayakpaul

@github-actions github-actions Bot added the models, hooks, and size/M (PR with diff < 200 LOC) labels on Apr 22, 2026
@sayakpaul
Member

Thanks for your PR. Could you also provide some visual output examples and comparison with other backends we support (latency, memory overhead, etc.)?

@sayakpaul sayakpaul added the performance (Anything related to performance improvements, profiling and benchmarking) label on Apr 22, 2026
@github-actions github-actions Bot added the documentation (Improvements or additions to documentation) and size/L (PR with diff > 200 LOC) labels and removed the size/M (PR with diff < 200 LOC) label on Apr 23, 2026
@songh11
Contributor Author

songh11 commented Apr 23, 2026

Thanks for your PR. Could you also provide some visual output examples and comparison with other backends we support (latency, memory overhead, etc.)?

@sayakpaul Thanks for the review! I just pushed two follow-up commits:

  • 894e289 — fix: align ring_anything with the ulysses_anything path (use shared gather_size_by_comm, handle ring_anything in unshard hook).
  • 09b3995 — docs: add ### Ring Anything Attention section in distributed_inference.md with config example, gloo TIP, current limitations, and the benchmark below.

FLUX.1-dev benchmarks on 4 × RTX 4090 (48GB)

| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|---|---|---|---|---|
| ulysses | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring | 338.98 | 2.95 | 33.83 | 1024x1024 |
| unified_balanced | 321.54 | 3.11 | 33.83 | 1024x1024 |
| ulysses_anything | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring_anything | 340.14 | 2.94 | 33.83 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything | 253.16 | 3.95 | 33.75 | 1008x1008 |
| ring_anything | 335.57 | 2.98 | 33.75 | 1008x1008 |

ring_anything matches ring to within ~0.4% latency, with identical peak memory, at evenly divisible shapes, and it is the only ring path that works at non-divisible shapes such as 1008×1008. The full table is in the docs.

Visual outputs

I'll attach side-by-side generated images (same prompt + seed: single-GPU baseline vs. ring_anything) in a follow-up comment shortly.

@songh11
Contributor Author

songh11 commented Apr 23, 2026

Here are the visual outputs across all backends (same prompt, same seed).

Outputs

Comparison of generated images across different context-parallel backends.

| Resolution | ulysses | ring | unified_balanced | ulysses_anything | ring_anything |
|---|---|---|---|---|---|
| 1024×1024 | output_1024x1024_ulysses | output_1024x1024_ring | output_1024x1024_unified | output_1024x1024_ulysses_anything | output_1024x1024_ring_anything |
| 1008×1008 | failed | failed | failed | output_1008x1008_ulysses_anything | output_1008x1008_ring_anything |

Reproduction

demo script
import os
import time
import argparse
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def make_step_timing_callback() -> Tuple[Any, List[float]]:
    step_times: List[float] = []
    last_time: Optional[float] = None

    def callback(pipe, step: int, timestep: int, callback_kwargs: Dict[str, Any]):
        nonlocal last_time
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        now = time.perf_counter()
        if last_time is not None:
            step_times.append(now - last_time)
        last_time = now
        return callback_kwargs

    return callback, step_times


def measure_memory():
    """Measure current GPU memory usage"""
    allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    reserved = torch.cuda.memory_reserved() / 1024**3  # GB
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified", "ulysses_anything", "ring_anything"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="Width of the generated image.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        os.environ.get(
            "FLUX_DIR",
            "black-forest-labs/FLUX.1-dev",
        ),
        torch_dtype=torch.bfloat16,
    ).to(device)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

    # Always use the cuDNN backend, since `ring` doesn't support the default
    # backend; this also ensures a fair comparison across backends.
    pipeline.transformer.set_attention_backend("_native_cudnn")

    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    elif args.cp_backend == "ulysses":
        cp_config = ContextParallelConfig(ulysses_degree=world_size)
    elif args.cp_backend == "ulysses_anything":
        cp_config = ContextParallelConfig(ulysses_degree=world_size, ulysses_anything=True)
    elif args.cp_backend == "ring_anything":
        cp_config = ContextParallelConfig(ring_degree=world_size, ring_anything=True)
    else:
        raise ValueError(f"Unsupported cp_backend: {args.cp_backend}")

    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    def run_pipe(pipeline, steps=50, callback=None):
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(42),
            height=args.height,
            width=args.width,
            callback_on_step_end=callback,
        ).images[0]
        return image

    # warmup
    _ = run_pipe(pipeline, steps=10)

    callback, step_times = make_step_timing_callback()

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    image = run_pipe(pipeline, steps=50, callback=callback)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.perf_counter()

    if dist.get_rank() == 0:
        print(f"{args.cp_backend=}")
        save_path = f"output_{args.height}x{args.width}_{args.cp_backend}.png"
        image.save(save_path)
        print(f"Saved image to {save_path}, time taken: {end - start:.2f} seconds")
        total_time = sum(step_times)
        iters = len(step_times)
        time_per_iter = (total_time / iters) * 1000  # in milliseconds
        steps_per_sec = iters / total_time
        print(f"Time per Iter: {time_per_iter:.2f} ms, Steps per Sec: {steps_per_sec:.2f}")

        max_allocated_gb = measure_memory()["max_allocated_gb"]
        print(f"{max_allocated_gb=}")


    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

# Example usage:
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ring # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend unified # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ulysses_anything # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --cp-backend ring_anything # success

# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ring # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend unified # failed
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ulysses_anything # success
# torchrun --nproc_per_node=4 check_ulysses_anything.py --height 1008 --width 1008 --cp-backend ring_anything # success

@sayakpaul
Member

Thanks for the updates -- they are very comprehensive. It seems like ulysses_anything is still the best but users who might prefer ring would want to use ring_anything for non-divisible shapes? Is that the motivation here?

@songh11
Contributor Author

songh11 commented Apr 23, 2026

Thanks for the updates -- they are very comprehensive. It seems like ulysses_anything is still the best but users who might prefer ring would want to use ring_anything for non-divisible shapes? Is that the motivation here?

Thanks for the review!

You're spot on — the primary motivation is to provide a seamless path for users on a ring topology to handle non-divisible sequence lengths, similar to how ulysses_anything works.

While ulysses/ulysses_anything is often the go-to, I believe ring_anything is a valuable addition for a few key reasons:

  1. Zero Head-count Constraints: Unlike Ulysses, Ring doesn't require num_heads to be divisible by the degree. For models with irregular or small head counts, this provides a much-needed "fallthrough" option.
  2. Performance on PCIe/P2P: In my experiments with dual-GPU PCIe setups, P2P ring communication can be more efficient than Ulysses' all-to-all due to reduced bandwidth contention across the PCIe switch.
  3. Future work: This PR focuses on ensuring correctness and supporting general shapes. I plan to follow up with another PR to introduce a proper P2P implementation to fully unlock the computation-communication overlap benefits.

So ulysses/ulysses_anything remains the recommended choice for most setups, while ring_anything covers head-constrained or ring-first setups: correctness today, with clear headroom for P2P overlap in the next step.

@sayakpaul
Member

Your reasoning is spot on! Let's include them in the docs too.

@sayakpaul
Member

@claude could you do a review here?

@github-actions
Contributor

github-actions Bot commented Apr 23, 2026

Claude finished @sayakpaul's task in 2m 29s


Review of PR #13545: Ring Anything Attention

  • Read repository guidelines (AGENTS.md, models.md, modular.md)
  • Read changed files and understand the implementation
  • Compare with existing ulysses_anything and TemplatedRingAttention patterns
  • Post detailed review feedback

Overall this is a well-structured PR that cleanly mirrors the existing patterns. The core algorithm (pad → all-gather → slice → online-softmax merge) is correct, and the config validation + hook routing are consistent with ulysses_anything. A few items below.


Duplicate hook branches should be merged

src/diffusers/hooks/context_parallel.py — both _prepare_cp_input (lines 213–220) and post_forward (lines 246–253) now have two identical branches:

if self.parallel_config.ulysses_anything:
    return PartitionAnythingSharder.shard_anything(...)
elif self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)

Since the bodies are identical, these should be merged:

if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
    return PartitionAnythingSharder.shard_anything(...)

Same for the unshard_anything path in post_forward. This reduces duplication and makes it clearer that both modes use the same shard/unshard strategy.

Fix this →


Incomplete docstring comment for ring_anything field

src/diffusers/models/_modeling_parallel.py:88:

# Whether to enable ring anything attention to support sequence lengths.

This should read "to support arbitrary sequence lengths" (matching the docstring at line 68 which already says "arbitrary sequence lengths"). Minor, but worth keeping the inline comment consistent.


Memory considerations worth documenting

The current implementation pads all local KV to s_max before the all-gather. When sequence lengths are highly uneven across ranks, this communicates and allocates extra zeros. This is fine for a correctness-first approach (and the benchmarks confirm negligible overhead for near-uniform splits), but it would be good to add a brief inline comment noting this trade-off, especially since the PR description mentions a follow-up P2P implementation.
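
As a rough illustration of that trade-off (the lengths below are assumed for illustration only, not taken from the benchmark):

```python
# Hypothetical per-rank KV lengths on a ring of 4 ranks
lengths = [1024, 1024, 1024, 960]
s_max = max(lengths)
# zero-padding tokens communicated and allocated by the pad-to-max scheme
wasted = sum(s_max - n for n in lengths)
frac = wasted / (s_max * len(lengths))
print(wasted, round(frac, 4))  # 64 padded tokens, ~1.6% of the gathered volume
```

For near-uniform splits the overhead stays small, consistent with the benchmark numbers; it grows only when per-rank lengths are highly uneven.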


Minor: validation could be more explicit about mutual exclusivity

The current constraints in __post_init__ (lines 117–126) implicitly prevent ulysses_anything and ring_anything from being enabled simultaneously (since ulysses_anything requires ring_degree == 1 while ring_anything requires ulysses_degree == 1). Consider adding an explicit check:

if self.ulysses_anything and self.ring_anything:
    raise ValueError("ulysses_anything and ring_anything cannot both be enabled.")

This makes the constraint immediately obvious to readers without requiring them to reason through the degree constraints. It's also more robust against future changes to the degree validation.
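
A minimal sketch of what that could look like (a hypothetical stand-in for `ContextParallelConfig`, not the actual class):

```python
from dataclasses import dataclass


@dataclass
class CPConfigSketch:
    # Hypothetical config showing the proposed explicit mutual-exclusivity
    # check alongside the degree constraint for ring_anything.
    ring_degree: int = 1
    ulysses_degree: int = 1
    ulysses_anything: bool = False
    ring_anything: bool = False

    def __post_init__(self):
        if self.ulysses_anything and self.ring_anything:
            raise ValueError("ulysses_anything and ring_anything cannot both be enabled.")
        if self.ring_anything and not (self.ring_degree > 1 and self.ulysses_degree == 1):
            raise ValueError("ring_anything requires ring_degree > 1 and ulysses_degree == 1.")
```

The first check fires before any degree reasoning is needed, which is the readability win being suggested.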


Overall assessment

The core attention logic is correct:

  • Pad → all-gather → slice correctly reconstructs each rank's true KV before calling the attention op (zeros never enter attention).
  • The online-softmax merging loop is identical to TemplatedRingAttention and handles the variable-length chunks correctly.
  • Config validation prevents invalid combinations.
  • The hook routing through PartitionAnythingSharder correctly handles non-uniform splits.
  • Backward raising NotImplementedError is consistent with TemplatedUlyssesAnythingAttention.
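
The online-softmax merge referenced above can be sketched as combining per-chunk partial outputs using their log-sum-exps. This is a minimal single-device illustration of the merging rule, not the PR's actual loop; `merge_attention_chunks` is a hypothetical name.

```python
import torch


def merge_attention_chunks(outs, lses):
    # outs: list of [q_len, dim] partial attention outputs, one per KV chunk
    # lses: list of [q_len] log-sum-exp values over the same chunks' scores
    lse = torch.logsumexp(torch.stack(lses), dim=0)  # global log-sum-exp
    # reweight each partial output by its chunk's share of the global softmax mass
    return sum(o * torch.exp(l - lse).unsqueeze(-1) for o, l in zip(outs, lses))
```

The merged result equals single-shot softmax attention for any chunking of KV, which is why variable-length chunks pose no correctness problem for this step.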

The benchmarks are thorough and confirm that ring_anything matches ring at divisible shapes (~0.4% overhead) while being the only ring-based mode that works at non-divisible shapes. The motivation (head-count constraints, PCIe/P2P topology advantages) is compelling.

The main actionable items are: (1) merging the duplicate hook branches, and (2) adding the explicit mutual-exclusivity check. The rest are minor polish.


@songh11
Contributor Author

songh11 commented Apr 24, 2026

@sayakpaul Thanks! Addressed the feedback by merging duplicate hooks, adding the exclusivity check, and polishing comments. Pushed in the latest commit.


Labels

  • documentation: Improvements or additions to documentation
  • hooks
  • models
  • performance: Anything related to performance improvements, profiling and benchmarking
  • size/L: PR with diff > 200 LOC
