🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18141
Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures
As of commit b29e89e with merge base cc27e6b:
- NEW FAILURES: the following jobs have failed
- FLAKY: the following jobs failed but were likely due to flakiness present on trunk
- BROKEN TRUNK: the following jobs failed but were present on the merge base
👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
Adds a Triton-based topk kernel for the ExecuTorch CUDA backend, replacing aten.topk.default during graph transformation. The kernel uses iterative argmax/argmin with masking and is registered via @triton_op.
Changes:
- New Triton topk kernel implementation with iterative max/min and masking algorithm
- Registration of the kernel in the edge-to-triton replacement pass
- Tests (eager correctness, export validation, E2E C++ runner) and a dedicated C++ test runner
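The iterative argmax-with-masking algorithm is simple to state: find the current maximum, emit its value and position, then overwrite that position with -inf so the next pass finds the runner-up. A minimal pure-Python sketch of the idea (illustrative only, not the Triton kernel itself):

```python
import math

def topk_iterative(row, k):
    """Illustrative reference for iterative argmax with masking:
    repeatedly locate the current maximum, record its value and index,
    then mask it with -inf in a working copy so the next iteration
    finds the next-largest element."""
    work = list(row)
    values, indices = [], []
    for _ in range(k):
        idx = max(range(len(work)), key=lambda j: work[j])  # argmax
        values.append(row[idx])
        indices.append(idx)
        work[idx] = -math.inf  # mask out the found maximum
    return values, indices
```

On ties this picks the first occurrence. Per the PR description, the actual kernel applies the same trick with argmin for the smallest-k direction.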
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| backends/cuda/triton/kernels/topk.py | New Triton topk kernel and its abstract/fake implementation |
| backends/cuda/triton/kernels/__init__.py | Export the new topk symbol |
| backends/cuda/triton/replacement_pass.py | Map aten.topk.default to the Triton kernel |
| backends/cuda/tests/test_topk.py | Eager correctness, export, and E2E tests |
| backends/cuda/tests/topk_runner/main.cpp | C++ runner for E2E testing |
| backends/cuda/tests/topk_runner/CMakeLists.txt | Build config for the C++ runner |
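The "abstract/fake implementation" listed for topk.py is what lets export trace the op without running the kernel: it only has to report output shapes and dtypes. A hedged sketch of the shape contract such a fake implementation must encode (the helper name is hypothetical, not the PR's code):

```python
def topk_output_shape(input_shape, k, dim=-1):
    """Hypothetical helper showing the shape contract a fake/meta
    implementation of topk must satisfy: both outputs (values and
    indices) share the input's shape, with size k along `dim`."""
    out = list(input_shape)
    out[dim] = k
    return tuple(out), tuple(out)  # (values_shape, indices_shape)
```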
Add a Triton-based topk kernel that replaces aten.topk during graph
transformation, compiled directly into the AOTInductor .so via
wrap_triton (no C++ fallback shim needed).
The kernel uses iterative argmax with masking, adapted from
FlagGems/aiter. It is registered via @triton_op("triton::topk") and
auto-substituted for aten.topk.default through ReplaceEdgeOpWithTritonOpPass.
Tests follow the chunk_gated_delta_rule pattern: eager correctness
across 8 configs, export validation, and E2E C++ runner comparison.
This PR was authored with the assistance of Claude Code.
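The replacement pass operates at the graph level: walk the exported graph's call nodes and swap any target found in a lookup table mapping edge ops to their Triton counterparts. A schematic sketch of that pattern (the node representation and names are illustrative, not ExecuTorch's actual ReplaceEdgeOpWithTritonOpPass):

```python
# Illustrative replacement table; the real pass maps edge-dialect ops
# to @triton_op-registered kernels such as triton::topk.
REPLACEMENTS = {"aten.topk.default": "triton::topk"}

def replace_edge_ops(nodes):
    """Rewrite each node's target in place when it appears in the
    replacement table; unmatched nodes pass through untouched."""
    for node in nodes:
        if node.get("target") in REPLACEMENTS:
            node["target"] = REPLACEMENTS[node["target"]]
    return nodes
```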
Force-pushed from fb5d204 to 00165ab.