Add chunk_gated_delta_rule triton kernel for CUDA backend #18138

Open
mergennachin wants to merge 2 commits into main from mergennachin/fla_linear_attention

Conversation

@mergennachin
Contributor

Registers FLA's chunk_gated_delta_rule as a @triton_op, following the
same pattern as the existing SDPA triton kernel. Six FLA triton kernels
are launched via wrap_triton() so AOTInductor compiles them directly
into the generated .so — no C++ shim needed.

Key trick: FLA kernels use @triton.heuristics which wrap_triton doesn't
support. We unwrap via kernel.fn to get the inner @triton.autotune
kernel and pass heuristic values (USE_G, IS_VARLEN, etc.) explicitly.

Requires: pip install flash-linear-attention

Copilot AI review requested due to automatic review settings March 12, 2026 19:40
@pytorch-bot

pytorch-bot bot commented Mar 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18138

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 8 Unrelated Failures

As of commit 530ddb2 with merge base cc27e6b:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 12, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.



@chunk_gated_delta_rule.register_fake
def _chunk_gated_delta_rule_fake(
Contributor

Why fake instead of meta? What's the difference?

Contributor Author

The SDPA kernel in this repo uses register_fake, so I followed the same convention.
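For readers hitting the same question: `register_fake` is the current `torch.library` name for what older APIs called a meta kernel; either way, the registered function never touches data, it only maps input metadata to output metadata so the export tracer can propagate shapes and dtypes. A torch-free conceptual sketch, where the output shapes shown are hypothetical and not taken from FLA's actual contract:

```python
# Conceptual sketch (no torch required): a "fake" impl is pure shape/dtype
# propagation. The (B, T, H, K) / (B, T, H, V) layout and the final-state
# shape below are assumed for illustration only.

def chunk_gated_delta_rule_fake_shapes(q_shape, v_shape):
    B, T, H, K = q_shape
    _, _, _, V = v_shape
    o_shape = (B, T, H, V)            # attention output
    final_state_shape = (B, H, K, V)  # recurrent state carried out
    return o_shape, final_state_shape

o_shape, state_shape = chunk_gated_delta_rule_fake_shapes(
    (2, 128, 4, 64), (2, 128, 4, 64)
)
```

The real `_chunk_gated_delta_rule_fake` would return empty tensors with these shapes (e.g. via `torch.empty`) rather than plain tuples.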

CHUNK_SIZE = 64


def _unwrap(kernel):
Contributor

What's going on with the unwrap stuff?

Contributor

Copilot AI left a comment

Pull request overview

Adds an ExecuTorch CUDA Triton custom op wrapper for Flash-Linear-Attention (FLA)’s chunk_gated_delta_rule, plus end-to-end validation via Python export tests and a small C++ runner that executes the exported .pte with the CUDA delegate in CI.

Changes:

  • Introduces triton::chunk_gated_delta_rule as a @triton_op, wrapping multiple FLA Triton kernels via wrap_triton() for AOTInductor compilation.
  • Adds Python tests to validate eager correctness vs FLA and to export/lower the op to an ExecuTorch program.
  • Adds a C++ e2e runner and wires CI to install FLA, run tests, export a model, build the runner, and execute it.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
backends/cuda/triton/kernels/chunk_gated_delta_rule.py Registers the new Triton op and launches the underlying FLA kernels via wrap_triton().
backends/cuda/triton/kernels/__init__.py Conditionally imports/registers the new op when FLA is available.
backends/cuda/tests/test_chunk_gated_delta_rule.py Adds eager + export/lower tests for the new op (skipped when FLA isn’t installed).
backends/cuda/tests/chunk_gated_delta_runner/main.cpp Adds a minimal C++ program to load and run the exported model with CUDA delegate.
backends/cuda/tests/chunk_gated_delta_runner/CMakeLists.txt Adds build configuration for the new C++ runner.
.github/workflows/cuda.yml Installs FLA and runs the new Python + C++ e2e coverage in CUDA CI.


Comment on lines +143 to +147
V=V,
BT=BT,
BK=64,
BV=64,
USE_G=True,
Copilot AI Mar 12, 2026

BK and BV are hard-coded to 64 when launching recompute_w_u_fwd_kernel, which implicitly constrains supported head/value dims. Either derive these launch-time constexprs from K/V (if the FLA kernels support it) or validate/document the required K/V constraints in the @triton_op docstring and input checks to prevent confusing shape-dependent failures.
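One way to act on this review comment is a launch-time input check. This is a hedged sketch only: `check_head_dims` is a hypothetical helper, and whether the real constraint is divisibility, an upper bound, or something else depends on the FLA kernels themselves:

```python
# Hedged sketch of the input validation Copilot suggests. BK/BV match the
# hard-coded launch constants; the divisibility constraint is assumed for
# illustration and must be checked against the actual FLA kernel contract.

BK = BV = 64

def check_head_dims(K: int, V: int) -> None:
    """Fail fast with a clear message instead of a confusing
    shape-dependent kernel error."""
    if K % BK != 0:
        raise ValueError(f"head dim K={K} must be a multiple of BK={BK}")
    if V % BV != 0:
        raise ValueError(f"value dim V={V} must be a multiple of BV={BV}")

check_head_dims(128, 64)  # ok for these dims
```

The same constraint should also be stated in the `@triton_op` docstring so users see it before export time.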

@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 0550e5a to 43ee833 Compare March 12, 2026 20:45
Copilot AI review requested due to automatic review settings March 12, 2026 21:16
@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 43ee833 to 8cd3f9c Compare March 12, 2026 21:16
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.



@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 8cd3f9c to 7d132d1 Compare March 12, 2026 21:31
Copilot AI review requested due to automatic review settings March 12, 2026 21:41
@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 7d132d1 to 4f37bc7 Compare March 12, 2026 21:41
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.




pte_path = os.path.join(output_dir, "chunk_gated_delta.pte")
with open(pte_path, "wb") as f:
f.write(et_program.buffer)
Copilot AI Mar 12, 2026

export_chunk_gated_delta() writes the program using et_program.buffer, which forces materializing the entire serialized program into a contiguous bytes object. ExecuTorchProgramManager explicitly recommends write_to_file() to avoid extra copies and reduce peak memory. Consider switching to et_program.write_to_file(f) here.

Suggested change
f.write(et_program.buffer)
et_program.write_to_file(f)

@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 4f37bc7 to 97beeeb Compare March 12, 2026 22:07
Copilot AI review requested due to automatic review settings March 12, 2026 22:08
@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 97beeeb to 6ee0408 Compare March 12, 2026 22:08
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.



@mergennachin mergennachin force-pushed the mergennachin/fla_linear_attention branch from 6ee0408 to 21c5dd7 Compare March 12, 2026 22:34
return o, final_state


def _make_inputs_from_fla(
Contributor

Maybe we need to test NaN, inf, and -inf here, following what we learned before.
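The edge-case generation the reviewer asks for can be sketched in pure Python. `make_edge_case_inputs` is a hypothetical helper named for illustration; the real test would build torch tensors the same way and then assert the kernel's outputs stay finite (or match FLA eager) for each case:

```python
# Sketch of the suggested non-finite edge cases for the input generator.
# Pure-Python stand-in for what would be done with torch tensors.
import math

EDGE_VALUES = [math.nan, math.inf, -math.inf]

def make_edge_case_inputs(base):
    """Yield one poisoned copy of `base` per non-finite edge value."""
    for v in EDGE_VALUES:
        poisoned = list(base)
        poisoned[0] = v  # inject the edge value at a fixed position
        yield v, poisoned

cases = list(make_edge_case_inputs([1.0, 2.0, 3.0]))
```

Parameterizing the existing export test over these cases would cover the failure mode the reviewer alludes to.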
