
Add topk Triton kernel for CUDA backend #18141

Open
mergennachin wants to merge 2 commits into main from mergennachin/topk-triton-kernel

Conversation

@mergennachin
Contributor

Add a Triton-based topk kernel that replaces aten.topk during graph
transformation, compiled directly into the AOTInductor .so via
wrap_triton (no C++ fallback shim needed).

The kernel uses iterative argmax with masking, adapted from
FlagGems/aiter. It is registered via @triton_op("triton::topk") and
auto-substituted for aten.topk.default through ReplaceEdgeOpWithTritonOpPass.
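
The iterative argmax-with-masking idea can be illustrated in plain Python. This is a minimal sketch of the selection loop only, not the Triton kernel itself: the real kernel performs this per row on the GPU with block-level reductions, and the names here are illustrative.

```python
import math

def topk_iterative(values, k):
    """Top-k via iterative argmax with masking: repeatedly take the
    largest remaining element, record it, then mask it out so the
    next pass finds the runner-up. Sketch of the kernel's loop in
    scalar Python; the Triton version vectorizes the argmax."""
    vals = list(values)              # working copy we can mask in place
    out_vals, out_idx = [], []
    for _ in range(k):
        best = max(range(len(vals)), key=lambda i: vals[i])  # argmax
        out_vals.append(vals[best])
        out_idx.append(best)
        vals[best] = -math.inf       # mask the winner for the next pass
    return out_vals, out_idx
```

For example, `topk_iterative([3.0, 1.0, 4.0, 1.5, 9.0], 3)` returns `([9.0, 4.0, 3.0], [4, 2, 0])`, matching the values/indices pair that `aten.topk` produces for a sorted descending top-k.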

Tests follow the chunk_gated_delta_rule pattern: eager correctness
across 8 configs, export validation, and E2E C++ runner comparison.

Copilot AI review requested due to automatic review settings March 12, 2026 22:31
@pytorch-bot

pytorch-bot bot commented Mar 12, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18141

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit b29e89e with merge base cc27e6b:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 12, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Adds a Triton-based topk kernel for the ExecuTorch CUDA backend, replacing aten.topk.default during graph transformation. The kernel uses iterative argmax/argmin with masking and is registered via @triton_op.

Changes:

  • New Triton topk kernel implementation with iterative max/min and masking algorithm
  • Registration of the kernel in the edge-to-triton replacement pass
  • Tests (eager correctness, export validation, E2E C++ runner) and a dedicated C++ test runner
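
The edge-to-triton replacement step can be sketched as a table-driven graph pass. The sketch below is hypothetical: `Node` and `REPLACEMENTS` are illustrative stand-ins, not ExecuTorch's actual `ReplaceEdgeOpWithTritonOpPass` API, but they show the substitution this PR extends with a topk entry.

```python
from dataclasses import dataclass

@dataclass
class Node:
    target: str          # fully qualified op name, e.g. "aten.topk.default"
    args: tuple = ()

# Mapping the pass consults; this PR's change is conceptually one new row.
REPLACEMENTS = {"aten.topk.default": "triton::topk"}

def replace_edge_ops(nodes):
    """Rewrite each node whose target has a Triton-backed replacement;
    nodes with no entry in the table are left untouched."""
    for node in nodes:
        if node.target in REPLACEMENTS:
            node.target = REPLACEMENTS[node.target]
    return nodes
```

Running the pass over `[Node("aten.topk.default"), Node("aten.add.default")]` rewrites only the first node's target to `triton::topk`, leaving unrelated ops alone.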

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Summary per file:

backends/cuda/triton/kernels/topk.py: New Triton topk kernel and its abstract/fake implementation
backends/cuda/triton/kernels/__init__.py: Export the new topk symbol
backends/cuda/triton/replacement_pass.py: Map aten.topk.default to the Triton kernel
backends/cuda/tests/test_topk.py: Eager correctness, export, and E2E tests
backends/cuda/tests/topk_runner/main.cpp: C++ runner for E2E testing
backends/cuda/tests/topk_runner/CMakeLists.txt: Build config for the C++ runner


This PR was authored with the assistance of Claude Code.
@mergennachin mergennachin force-pushed the mergennachin/topk-triton-kernel branch from fb5d204 to 00165ab Compare March 12, 2026 22:50
