Add chunk_gated_delta_rule triton kernel for CUDA backend #18138

mergennachin wants to merge 2 commits into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18138

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 8 Unrelated Failures

As of commit 530ddb2 with merge base cc27e6b:

- NEW FAILURES - The following jobs have failed.
- FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
- BROKEN TRUNK - The following jobs failed but were present on the merge base.

👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
@chunk_gated_delta_rule.register_fake
def _chunk_gated_delta_rule_fake(
```
why fake instead of meta? what's the difference?
The SDPA kernel in this repo uses register_fake so I followed the same convention.
```python
CHUNK_SIZE = 64
```

```python
def _unwrap(kernel):
```
what's going on with the unwrap stuff?
Pull request overview
Adds an ExecuTorch CUDA Triton custom op wrapper for Flash-Linear-Attention (FLA)’s chunk_gated_delta_rule, plus end-to-end validation via Python export tests and a small C++ runner that executes the exported .pte with the CUDA delegate in CI.
Changes:
- Introduces `triton::chunk_gated_delta_rule` as a `@triton_op`, wrapping multiple FLA Triton kernels via `wrap_triton()` for AOTInductor compilation.
- Adds Python tests to validate eager correctness vs FLA and to export/lower the op to an ExecuTorch program.
- Adds a C++ e2e runner and wires CI to install FLA, run tests, export a model, build the runner, and execute it.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| backends/cuda/triton/kernels/chunk_gated_delta_rule.py | Registers the new Triton op and launches the underlying FLA kernels via wrap_triton(). |
| backends/cuda/triton/kernels/__init__.py | Conditionally imports/registers the new op when FLA is available. |
| backends/cuda/tests/test_chunk_gated_delta_rule.py | Adds eager + export/lower tests for the new op (skipped when FLA isn’t installed). |
| backends/cuda/tests/chunk_gated_delta_runner/main.cpp | Adds a minimal C++ program to load and run the exported model with CUDA delegate. |
| backends/cuda/tests/chunk_gated_delta_runner/CMakeLists.txt | Adds build configuration for the new C++ runner. |
| .github/workflows/cuda.yml | Installs FLA and runs the new Python + C++ e2e coverage in CUDA CI. |
```python
    V=V,
    BT=BT,
    BK=64,
    BV=64,
    USE_G=True,
```
BK and BV are hard-coded to 64 when launching recompute_w_u_fwd_kernel, which implicitly constrains supported head/value dims. Either derive these launch-time constexprs from K/V (if the FLA kernels support it) or validate/document the required K/V constraints in the @triton_op docstring and input checks to prevent confusing shape-dependent failures.
Force-push: 0550e5a → 43ee833
Force-push: 43ee833 → 8cd3f9c
Force-push: 8cd3f9c → 7d132d1
Force-push: 7d132d1 → 4f37bc7
```python
pte_path = os.path.join(output_dir, "chunk_gated_delta.pte")
with open(pte_path, "wb") as f:
    f.write(et_program.buffer)
```
export_chunk_gated_delta() writes the program using et_program.buffer, which forces materializing the entire serialized program into a contiguous bytes object. ExecuTorchProgramManager explicitly recommends write_to_file() to avoid extra copies and reduce peak memory. Consider switching to et_program.write_to_file(f) here.
```diff
-    f.write(et_program.buffer)
+    et_program.write_to_file(f)
```
Force-push: 4f37bc7 → 97beeeb
Force-push: 97beeeb → 6ee0408
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
Registers FLA's chunk_gated_delta_rule as a @triton_op, following the same pattern as the existing SDPA triton kernel. Six FLA triton kernels are launched via wrap_triton() so AOTInductor compiles them directly into the generated .so — no C++ shim needed.

Key trick: FLA kernels use @triton.heuristics, which wrap_triton doesn't support. We unwrap via kernel.fn to get the inner @triton.autotune kernel and pass heuristic values (USE_G, IS_VARLEN, etc.) explicitly.

Requires: pip install flash-linear-attention
Force-push: 6ee0408 → 21c5dd7
```python
    return o, final_state
```

```python
def _make_inputs_from_fla(
```
maybe we need to test NaN, inf, and -inf here, following what we learned before.
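A sketch of the kind of special-value coverage being suggested (hypothetical helper; in the actual test these values would be injected into the q/k/v input tensors before running the kernel):

```python
import math

# Non-finite values worth probing kernel numerics with.
SPECIAL_VALUES = [float("nan"), float("inf"), float("-inf")]

def with_special_value(values, index, special):
    """Return a copy of `values` with one entry replaced by a
    non-finite special value, for probing kernel edge cases."""
    out = list(values)
    out[index] = special
    return out

cases = [with_special_value([0.1, 0.2, 0.3], 0, s) for s in SPECIAL_VALUES]
assert math.isnan(cases[0][0])
assert math.isinf(cases[1][0]) and cases[1][0] > 0
assert math.isinf(cases[2][0]) and cases[2][0] < 0
```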