[Feat] Adds LongCat-AudioDiT pipeline #13390
RuixiangMa wants to merge 6 commits into huggingface:main from
Conversation
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Force-pushed from 9c4613f to d2a2621
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py (outdated review threads, resolved)
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
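For readers unfamiliar with the operation being inlined: a 1-D pixel shuffle rearranges a (channels * factor, length) array into (channels, length * factor) by interleaving channel groups along the time axis. A minimal plain-Python sketch (hypothetical helper name, list-of-lists instead of tensors; the actual torch implementation may differ):

```python
def pixel_shuffle_1d(x, factor):
    # x: list of `channels * factor` rows, each of length `length`.
    # Returns `channels` rows, each of length `length * factor`, where
    # out[c][t * factor + f] == x[c * factor + f][t].
    channels = len(x)
    assert channels % factor == 0, "channel count must be divisible by factor"
    out_channels = channels // factor
    out = []
    for c in range(out_channels):
        row = []
        for t in range(len(x[0])):
            for f in range(factor):
                row.append(x[c * factor + f][t])
        out.append(row)
    return out
```

For example, two channels [[1, 2], [3, 4]] with factor 2 collapse into one channel [1, 3, 2, 4].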
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py (multiple outdated review threads, resolved)
src/diffusers/models/transformers/transformer_longcat_audio_dit.py (multiple outdated review threads, resolved)
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
See #13390 (comment).
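As background for the rotary_embed line above: a standard rotary embedding precomputes a table of inverse frequencies from the head dimension and base. A minimal sketch of that computation (hypothetical helper; the actual AudioDiTRotaryEmbedding implementation may differ in details):

```python
def rope_inv_freq(dim, base=100000.0):
    # Standard RoPE inverse frequencies for an even head dimension `dim`:
    # inv_freq[i] = base ** (-2i / dim), one frequency per rotation pair.
    assert dim % 2 == 0, "head dimension must be even for RoPE"
    return [base ** (-(2 * i) / dim) for i in range(dim // 2)]
```

A larger base (here 100000.0 instead of the common 10000.0) stretches the lowest frequencies, which helps with longer sequence lengths such as the 2048 positions configured above.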
src/diffusers/models/transformers/transformer_longcat_audio_dit.py (outdated review thread, resolved)
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
    timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
Can you also refactor forward here so that it is better organized, following #13390 (comment)? See for example the QwenImageTransformer2DModel.forward method:
Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.
src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py (multiple outdated review threads, resolved)
for idx in range(len(timesteps) - 1):
    curr_t = timesteps[idx]
    dt = timesteps[idx + 1] - timesteps[idx]
    sample = sample + model_step(curr_t, sample) * dt
I think we should use a scheduler, most likely FlowMatchEulerDiscreteScheduler, here instead of implementing the sampling algorithm inside __call__.
I checked this further, but I don’t think it is worth changing right now.
Using FlowMatchEulerDiscreteScheduler changed the fixed-seed output, so it was not behavior-preserving. Since the current sampling loop is already very small, I kept the existing implementation to avoid changing generation behavior. Do you think it still makes sense to use a scheduler here?
I think we should use a scheduler for the following reasons:
- Abstracting the sampling logic into a scheduler is part of the diffusers pipeline design; almost all (if not all) pipelines follow this design.
- Since the model is trained using flow matching, running inference with ODE solvers other than the Euler solver should still result in valid samples. Separating out the ODE solver logic into a scheduler would allow users to easily change solvers by changing the scheduler if they wish. For example, if a user wanted to use a higher-order solver, they could use FlowMatchHeunDiscreteScheduler instead of FlowMatchEulerDiscreteScheduler.
- By using a scheduler, users could also use variants of the Euler solver if they wished. For example, users would be able to change the shift without changing the pipeline code by using a new FlowMatchEulerDiscreteScheduler with the desired shift value.
FlowMatchEulerDiscreteScheduler should be able to support this use case. I think it uses a diffusion-style time convention where 1 is noise and 0 is data by default, but a flow matching-style time convention where 0 is noise and 1 is data (which the pipeline currently uses) should be supported through invert_sigmas=True.
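For reference, the fixed-step Euler integration the pipeline currently performs (and which a scheduler would abstract) can be sketched in plain Python under the flow-matching convention where t runs from 0 (noise) to 1 (data). Function and argument names here are illustrative, with a scalar state and no batching:

```python
def euler_flow_sample(velocity_fn, x0, timesteps):
    # Integrate dx/dt = velocity_fn(t, x) with explicit Euler steps.
    # `timesteps` runs from 0 (noise) toward 1 (data), matching the
    # pipeline's current time convention.
    sample = x0
    for i in range(len(timesteps) - 1):
        t = timesteps[i]
        dt = timesteps[i + 1] - timesteps[i]
        sample = sample + velocity_fn(t, sample) * dt
    return sample
```

Swapping solvers then only means changing the update rule (or the scheduler object), not the loop that calls the model.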
src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py (outdated review thread, resolved)
dg845 left a comment:
Thanks for the PR! Left an initial design review :).
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Thx for the comments, I have made the changes, PTAL.
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,

Suggested change:

attention_mask: torch.BoolTensor | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
nit: follow naming in AudioDiTCrossAttnProcessor.
mask: torch.BoolTensor | None = None,
rope: tuple | None = None,
) -> torch.Tensor:
    if encoder_hidden_states is None:
        return self.processor(self, hidden_states, mask=mask, rope=rope)

Suggested change:

) -> torch.Tensor:
    if encoder_hidden_states is None:
        return self.processor(self, hidden_states, attention_mask, audio_rotary_emb)
nit: I think this would be a more clear way to support both AudioDiTSelfAttnProcessor and AudioDiTCrossAttnProcessor. Depends on the suggestion from #13390 (comment).
self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048)
self.vae_scale_factor = self.latent_hop

Suggested change:

self.vae_scale_factor = getattr(vae.config, "downsampling_ratio", 2048)
nit: I think we can remove self.latent_hop as it is no longer directly used.
from diffusers import LongCatAudioDiTTransformer


def test_longcat_audio_transformer_forward_shape():
Can you also add transformer tests via python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_longcat_audio_dit.py? This adds standard tests for model inference, torch.compile compatibility, etc.
)


class LongCatAudioDiTPipelineFastTests(unittest.TestCase):

Suggested change:

class LongCatAudioDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
Fast test classes should inherit from PipelineTesterMixin, which adds standard tests for pipelines.
from transformers import UMT5Config, UMT5EncoderModel

from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae
from tests.testing_utils import require_torch_accelerator, slow, torch_device

Suggested change:

from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
from ..test_pipelines_common import PipelineTesterMixin

enable_full_determinism()
Follow up change to #13390 (comment). We also import and call enable_full_determinism, which helps ensure that the tests are deterministic and thus have more consistent behavior. (enable_full_determinism should be used in the transformer tests as well.)
@slow
@require_torch_accelerator
def test_longcat_audio_pipeline_from_pretrained_real_local_weights():
Can you refactor test_longcat_audio_pipeline_from_pretrained_real_local_weights to be part of a LongCatAudioDiTPipelineSlowTests class? For reference, see the Stable Diffusion 3 slow tests:
dg845 left a comment:
Thanks for iterating! I left some follow-up comments.
What does this PR do?
Adds LongCat-AudioDiT model support to diffusers.
Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.
Test
Result
longcat.wav
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.