
[Feat] Adds LongCat-AudioDiT pipeline #13390

Open
RuixiangMa wants to merge 6 commits into huggingface:main from RuixiangMa:longcataudiodit

Conversation


@RuixiangMa RuixiangMa commented Apr 2, 2026

What does this PR do?

Adds LongCat-AudioDiT model support to diffusers.

Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.

Test

import soundfile as sf
import torch
from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)

Result

longcat.wav

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Longcataudiodit [Feat] Adds LongCat-AudioDiT support Apr 2, 2026
@RuixiangMa RuixiangMa changed the title [Feat] Adds LongCat-AudioDiT support [Feat] Adds LongCat-AudioDiT pipeline Apr 2, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@dg845 dg845 requested review from dg845 and yiyixuxu April 4, 2026 00:31
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
Collaborator


Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
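For context on what is being inlined: a 1D pixel shuffle trades channel depth for temporal resolution, analogous to `torch.nn.PixelShuffle` in 2D. A minimal self-contained sketch, assuming the usual `(batch, channels, length)` layout (the PR's actual helper may differ in layout or naming):

```python
import torch


def pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
    # (batch, channels, length) -> (batch, channels // factor, length * factor)
    batch, channels, length = hidden_states.shape
    # Split the channel dim into (channels // factor, factor) ...
    hidden_states = hidden_states.view(batch, channels // factor, factor, length)
    # ... then interleave the factor sub-channels along the time axis.
    hidden_states = hidden_states.permute(0, 1, 3, 2)
    return hidden_states.reshape(batch, channels // factor, length * factor)
```

Under this layout, `out[b, c, t * factor + f] == in[b, c * factor + f, t]`, which is the 1D analogue of the 2D pixel-shuffle index mapping.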

Comment on lines +515 to +519
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
Collaborator


Suggested change
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(

See #13390 (comment).

Comment on lines +584 to +589
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
    timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
Collaborator


Suggested change
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)

Can you also refactor forward here so that it is better organized, following #13390 (comment)? See for example the QwenImageTransformer2DModel.forward method:

Author


Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.

for idx in range(len(timesteps) - 1):
    curr_t = timesteps[idx]
    dt = timesteps[idx + 1] - timesteps[idx]
    sample = sample + model_step(curr_t, sample) * dt
Collaborator


I think we should use a scheduler, most likely FlowMatchEulerDiscreteScheduler, here instead of implementing the sampling algorithm inside __call__.

prev_sample = sample + dt * model_output

Author


I checked this further, but I don’t think it is worth changing right now.

Using FlowMatchEulerDiscreteScheduler changed the fixed-seed output, so it was not behavior-preserving. Since the current sampling loop is already very small, I kept the existing implementation to avoid changing generation behavior. Do you think it still makes sense to use a scheduler here?

Collaborator


I think we should use a scheduler for the following reasons:

  1. Abstracting the sampling logic into a scheduler is part of the diffusers pipeline design; almost all (if not all) pipelines follow this design.
  2. Since the model is trained using flow matching, running inference with ODE solvers other than the Euler solver should still result in valid samples. Separating out the ODE solver logic into a scheduler would allow users to easily change solvers by changing the scheduler if they wish. For example, if a user wanted to use a higher-order solver, they could use FlowMatchHeunDiscreteScheduler instead of FlowMatchEulerDiscreteScheduler.
  3. By using a scheduler, users could also use variants of the Euler solver if they wished. For example, users would be able to change the shift without changing the pipeline code by using a new FlowMatchEulerDiscreteScheduler with the desired shift value.

FlowMatchEulerDiscreteScheduler should be able to support this use case. I think it uses a diffusion-style time convention where 1 is noise and 0 is data by default, but a flow matching-style time convention where 0 is noise and 1 is data (which the pipeline currently uses) should be supported through invert_sigmas=True.
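The refactor is mechanical because the hand-rolled loop is a plain forward-Euler integrator and `step()` only needs to own the `dt` bookkeeping. A minimal self-contained sketch of the equivalence; `SimpleEulerScheduler` and `model_step` here are toy stand-ins (not the diffusers `FlowMatchEulerDiscreteScheduler` or the PR's transformer):

```python
import torch


class SimpleEulerScheduler:
    """Toy stand-in illustrating the scheduler-style interface (not the diffusers class)."""

    def set_timesteps(self, timesteps: torch.Tensor) -> None:
        self.timesteps = timesteps
        self._step_index = 0

    def step(self, model_output: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
        # Forward Euler: x_{i+1} = x_i + (t_{i+1} - t_i) * v(t_i, x_i)
        dt = self.timesteps[self._step_index + 1] - self.timesteps[self._step_index]
        self._step_index += 1
        return sample + dt * model_output


def model_step(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    return -x  # toy velocity field standing in for the transformer


timesteps = torch.linspace(0.0, 1.0, 17)  # flow-matching convention: 0 = noise, 1 = data

# Hand-rolled loop, as currently written in the pipeline's __call__.
sample_a = torch.ones(3)
for idx in range(len(timesteps) - 1):
    curr_t = timesteps[idx]
    dt = timesteps[idx + 1] - timesteps[idx]
    sample_a = sample_a + model_step(curr_t, sample_a) * dt

# The same integration behind a scheduler-style step() interface.
scheduler = SimpleEulerScheduler()
scheduler.set_timesteps(timesteps)
sample_b = torch.ones(3)
for curr_t in timesteps[:-1]:
    sample_b = scheduler.step(model_step(curr_t, sample_b), sample_b)

assert torch.allclose(sample_a, sample_b)
```

Because the two loops perform the identical sequence of updates, this kind of refactor can be behavior-preserving; any fixed-seed difference when switching to the real scheduler would come from its timestep/sigma schedule (e.g. shift, time convention), not from the step rule itself.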

Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
@github-actions github-actions bot added documentation Improvements or additions to documentation models tests pipelines size/L PR with diff > 200 LOC labels Apr 7, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 8, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 8, 2026
@RuixiangMa
Author

Thanks for the PR! Left an initial design review :).

Thanks for the comments, I have made the changes. PTAL.

Comment on lines +191 to +192
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,
Collaborator


Suggested change
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,
attention_mask: torch.BoolTensor | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,

nit: follow naming in AudioDiTCrossAttnProcessor.

Comment on lines +264 to +268
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,
) -> torch.Tensor:
if encoder_hidden_states is None:
return self.processor(self, hidden_states, mask=mask, rope=rope)
Collaborator


Suggested change
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,
) -> torch.Tensor:
if encoder_hidden_states is None:
return self.processor(self, hidden_states, mask=mask, rope=rope)
) -> torch.Tensor:
if encoder_hidden_states is None:
return self.processor(self, hidden_states, attention_mask, audio_rotary_emb)

nit: I think this would be a more clear way to support both AudioDiTSelfAttnProcessor and AudioDiTCrossAttnProcessor. Depends on the suggestion from #13390 (comment).

Comment on lines +174 to +175
self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048)
self.vae_scale_factor = self.latent_hop
Collaborator


Suggested change
self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048)
self.vae_scale_factor = self.latent_hop
self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048)

nit: I think we can remove self.latent_hop as it is no longer directly used.

from diffusers import LongCatAudioDiTTransformer


def test_longcat_audio_transformer_forward_shape():
Collaborator


Can you also add transformer tests via

python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_longcat_audio_dit.py

? This adds standard tests for model inference, torch.compile compatibility, etc.

)


class LongCatAudioDiTPipelineFastTests(unittest.TestCase):
Collaborator


Suggested change
class LongCatAudioDiTPipelineFastTests(unittest.TestCase):
class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

Fast test classes should inherit from PipelineTesterMixin, which adds standard tests for pipelines.

from transformers import UMT5Config, UMT5EncoderModel

from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae
from tests.testing_utils import require_torch_accelerator, slow, torch_device
Collaborator

@dg845 dg845 Apr 9, 2026


Suggested change
from tests.testing_utils import require_torch_accelerator, slow, torch_device
from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()

Follow up change to #13390 (comment). We also import and call enable_full_determinism, which helps ensure that the tests are deterministic and thus have more consistent behavior. (enable_full_determinism should be used in the transformer tests as well.)
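For context, a hedged sketch of the kind of global switches such a determinism helper typically flips; this is an illustration, not diffusers' exact `enable_full_determinism` implementation:

```python
import os
import random

import torch


def enable_full_determinism_sketch(seed: int = 0) -> None:
    # Seed the RNGs a test is likely to touch.
    random.seed(seed)
    torch.manual_seed(seed)
    # Prefer deterministic kernels; warn (rather than error) on ops
    # that have no deterministic implementation.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Required by CUDA for deterministic cuBLAS GEMMs.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
```

With this in place, repeated runs of the same fast test should produce bit-identical outputs on the same hardware, which is what makes shape/value assertions in the generated tests reliable.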


@slow
@require_torch_accelerator
def test_longcat_audio_pipeline_from_pretrained_real_local_weights():
Collaborator


Can you refactor test_longcat_audio_pipeline_from_pretrained_real_local_weights to be part of a LongCatAudioDiTPipelineSlowTests class? For reference, see the Stable Diffusion 3 slow tests:

@slow
@require_big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
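A hypothetical sketch of how the requested class could look; the class name mirrors the fast-test class in this PR and the repo id comes from the PR description, but the exact attributes are assumptions, and the real class would carry the `@slow` / `@require_torch_accelerator` decorators and the diffusers imports:

```python
import unittest


class LongCatAudioDiTPipelineSlowTests(unittest.TestCase):
    # pipeline_class = LongCatAudioDiTPipeline  # real import would come from diffusers
    repo_id = "meituan-longcat/LongCat-AudioDiT-1B"

    def test_from_pretrained_real_local_weights(self):
        # The body of the existing free function would move here unchanged.
        self.skipTest("sketch only")
```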

Collaborator

@dg845 dg845 left a comment


Thanks for iterating! I left some follow-up comments.
