# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Export and validate topk triton kernel on CUDA backend.

Usage:
    python -m pytest backends/cuda/tests/test_topk.py -v

    # Standalone export (produces .pte + .ptd):
    python backends/cuda/tests/test_topk.py --output-dir /tmp/exports
"""

import argparse
import os
import subprocess
import sys
import tempfile
import unittest

import torch
import torch.nn as nn

from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from torch.export import export


# C++ runner built from backends/cuda/tests/topk_runner/CMakeLists.txt.
RUNNER_PATH = os.path.join(
    os.path.dirname(__file__),
    "../../../cmake-out/backends/cuda/tests/topk_runner/topk_runner",
)

# Test configurations: (seed, rows, cols, k, dim, largest, description)
TEST_CONFIGS = [
    (42, 4, 8, 2, -1, True, "basic_4x8_k2"),
    (0, 1, 16, 3, -1, True, "single_row_k3"),
    (7, 8, 4, 1, -1, True, "8x4_k1"),
    (99, 4, 8, 2, -1, False, "smallest_k2"),
    (13, 2, 32, 5, -1, True, "wide_k5"),
    (55, 4, 8, 8, -1, True, "k_equals_n"),
    (77, 1, 4, 2, -1, True, "tiny_1x4_k2"),
    (123, 16, 8, 2, -1, True, "many_rows"),
]


class TopKModel(nn.Module):
    """Linear projection followed by topk.

    The linear layer forces the exported graph to contain a delegated
    region around the topk op rather than a bare single-op graph.
    """

    def __init__(self, dim_in=8, k=2, topk_dim=-1, largest=True):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_in, bias=False)
        self.k = k
        self.topk_dim = topk_dim
        self.largest = largest

    def forward(self, x):
        x = self.linear(x)
        values, indices = torch.topk(x, self.k, dim=self.topk_dim, largest=self.largest)
        return values, indices


def _make_inputs(seed, rows, cols, dtype=torch.bfloat16, device="cuda"):
    """Deterministic random (rows, cols) input tuple for export/eager runs."""
    torch.manual_seed(seed)
    return (torch.randn(rows, cols, dtype=dtype, device=device),)


def _save_tensor(t, path):
    """Dump a tensor's raw storage bytes so the C++ runner can mmap/read it."""
    t_cpu = t.cpu().contiguous()
    with open(path, "wb") as f:
        f.write(bytes(t_cpu.untyped_storage()))


def _load_output(path, shape, dtype):
    """Read a raw tensor dump written by the C++ runner.

    bytearray() gives torch.frombuffer a writable buffer, avoiding the
    read-only-buffer warning it emits for immutable bytes.
    """
    with open(path, "rb") as f:
        raw = f.read()
    return torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape)


def export_topk(output_dir, cols=8, k=2, largest=True):
    """Export a TopKModel to .pte + .ptd.

    Returns (pte_path, eager_model) so callers can compare against eager.
    """
    torch.manual_seed(42)
    model = (
        TopKModel(dim_in=cols, k=k, largest=largest)
        .to(device="cuda", dtype=torch.bfloat16)
        .eval()
    )
    inputs = _make_inputs(42, 4, cols)

    with torch.no_grad():
        ref_vals, ref_idx = model(*inputs)
        print(f"Eager output: values {ref_vals.shape}, indices {ref_idx.shape}")

    with torch.no_grad():
        ep = export(model, inputs, strict=True)
        print("Export OK")

    os.makedirs(output_dir, exist_ok=True)

    specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    et_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[CudaPartitioner(specs)],
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False, _skip_dim_order=True
        ),
    )
    et_program = et_prog.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            do_quant_fusion_and_const_prop=True,
            # Inputs are fed from外部 buffers by the runner, so don't plan them.
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )

    pte_path = os.path.join(output_dir, "topk.pte")
    with open(pte_path, "wb") as f:
        f.write(et_program.buffer)

    # The CUDA delegate stores weights in a separate .ptd blob when present.
    if hasattr(et_program, "_tensor_data") and et_program._tensor_data:
        et_program.write_tensor_data_to_file(output_dir)

    print(f"Saved to {pte_path} ({os.path.getsize(pte_path) / 1024:.0f} KB)")
    return pte_path, model


def _run_cpp_runner(runner_path, pte_path, ptd_path, input_dir, output_dir):
    """Invoke the C++ runner binary; returns the CompletedProcess."""
    cmd = [
        runner_path,
        f"--model_path={pte_path}",
        f"--data_path={ptd_path}",
        f"--input_dir={input_dir}",
        f"--output_dir={output_dir}",
    ]
    return subprocess.run(cmd, capture_output=True, text=True)


class TestTopK(unittest.TestCase):
    def setUp(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def test_eager(self):
        """Triton topk produces correct shapes and dtypes."""
        model = TopKModel().to(device="cuda", dtype=torch.bfloat16).eval()
        inputs = _make_inputs(42, 4, 8)
        with torch.no_grad():
            vals, idx = model(*inputs)
        self.assertEqual(vals.shape, torch.Size([4, 2]))
        self.assertEqual(idx.shape, torch.Size([4, 2]))
        self.assertEqual(vals.dtype, torch.bfloat16)
        self.assertEqual(idx.dtype, torch.int64)

    def test_eager_correctness(self):
        """Triton topk matches torch.topk across multiple configs."""
        for seed, rows, cols, k, dim, largest, desc in TEST_CONFIGS:
            with self.subTest(desc=desc):
                torch.manual_seed(seed)
                x = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")

                ref_vals, ref_idx = torch.topk(x, k, dim=dim, largest=largest)

                from executorch.backends.cuda.triton.kernels.topk import (
                    topk as triton_topk,
                )

                tri_vals, tri_idx = triton_topk(x, k, dim=dim, largest=largest)

                v_diff = (tri_vals.float() - ref_vals.float()).abs().max().item()
                self.assertLess(v_diff, 1e-3, f"{desc}: value diff {v_diff}")
                self.assertTrue(
                    torch.equal(tri_idx, ref_idx),
                    f"{desc}: indices mismatch",
                )

    def test_export_cuda(self):
        """Export succeeds and produces non-empty .pte."""
        with tempfile.TemporaryDirectory() as tmpdir:
            pte_path, _ = export_topk(tmpdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)

    @unittest.skipUnless(os.path.exists(RUNNER_PATH), "C++ runner not built")
    def test_e2e_cpp_runner(self):
        """Export, run C++ runner, compare with eager."""
        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir = os.path.join(tmpdir, "export")
            pte_path, model = export_topk(export_dir)
            ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd")

            for seed, rows, cols, k, _dim, largest, desc in TEST_CONFIGS:
                # Skip configs that don't match the exported model shape
                if cols != 8 or k != 2 or not largest or rows != 4:
                    continue

                with self.subTest(desc=desc):
                    inputs = _make_inputs(seed, rows, cols)

                    with torch.no_grad():
                        ref_vals, ref_idx = model(*inputs)

                    input_dir = os.path.join(tmpdir, f"inputs_{desc}")
                    output_dir = os.path.join(tmpdir, f"outputs_{desc}")
                    os.makedirs(input_dir)
                    os.makedirs(output_dir)

                    _save_tensor(inputs[0], os.path.join(input_dir, "x.bin"))

                    result = _run_cpp_runner(
                        RUNNER_PATH, pte_path, ptd_path, input_dir, output_dir
                    )
                    self.assertEqual(
                        result.returncode,
                        0,
                        f"{desc}: C++ runner failed:\n{result.stderr}",
                    )

                    cpp_vals = _load_output(
                        os.path.join(output_dir, "output_0.bin"),
                        (rows, k),
                        torch.bfloat16,
                    )
                    cpp_idx = _load_output(
                        os.path.join(output_dir, "output_1.bin"),
                        (rows, k),
                        torch.int64,
                    )

                    v_diff = (
                        (cpp_vals.float() - ref_vals.cpu().float()).abs().max().item()
                    )
                    self.assertLess(v_diff, 0.01, f"{desc}: value diff {v_diff}")
                    self.assertTrue(
                        torch.equal(cpp_idx, ref_idx.cpu()),
                        f"{desc}: indices mismatch\n"
                        f"  cpp: {cpp_idx}\n  ref: {ref_idx.cpu()}",
                    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", default=None)
    args, remaining = parser.parse_known_args()

    if args.output_dir:
        export_topk(args.output_dir)
    else:
        # Forward any remaining CLI args (e.g. -v) to unittest.
        sys.argv = [sys.argv[0]] + remaining
        unittest.main()
+cmake_minimum_required(VERSION 3.24) +project(topk_runner) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../third-party/gflags) +find_package(gflags REQUIRED) + +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +list( + APPEND + link_libraries + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor + extension_named_data_map +) + +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + endif() +endif() + +add_executable(topk_runner main.cpp) +target_include_directories(topk_runner PUBLIC ${_common_include_directories}) +target_link_libraries(topk_runner PUBLIC ${link_libraries}) + +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(topk_runner) + if(NOT APPLE AND NOT MSVC) + target_link_options(topk_runner PRIVATE "LINKER:-s") + endif() +endif() diff --git a/backends/cuda/tests/topk_runner/main.cpp b/backends/cuda/tests/topk_runner/main.cpp new file mode 100644 index 00000000000..2389c0d2c1e --- /dev/null +++ b/backends/cuda/tests/topk_runner/main.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include + +#include +#include + +DEFINE_string(model_path, "", "Path to .pte file"); +DEFINE_string(data_path, "", "Path to .ptd file (for CUDA delegate)"); +DEFINE_string(input_dir, "", "Directory with input .bin files"); +DEFINE_string(output_dir, "", "Directory to write output .bin files"); + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +static std::vector read_file(const std::string& path) { + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f) { + fprintf(stderr, "Cannot open %s\n", path.c_str()); + exit(1); + } + std::size_t size = static_cast(f.tellg()); + f.seekg(0); + std::vector buf(size); + f.read(buf.data(), static_cast(size)); + return buf; +} + +static void write_file(const std::string& path, const void* data, size_t len) { + std::ofstream f(path, std::ios::binary); + f.write(static_cast(data), len); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_model_path.empty()) { + fprintf(stderr, "Error: --model_path required\n"); + return 1; + } + + std::unique_ptr module; + if (!FLAGS_data_path.empty()) { + module = std::make_unique( + FLAGS_model_path, + FLAGS_data_path, + Module::LoadMode::MmapUseMlockIgnoreErrors); + } else { + module = std::make_unique( + FLAGS_model_path, Module::LoadMode::MmapUseMlockIgnoreErrors); + } + + auto load_err = module->load(); + if (load_err != Error::Ok) { + fprintf(stderr, "Failed to load model: 0x%x\n", static_cast(load_err)); + return 1; + } + + std::vector inputs; + + if (!FLAGS_input_dir.empty()) { + std::string path = FLAGS_input_dir + "/x.bin"; + static std::vector input_buf = read_file(path); + + // Infer rows from file size: each row is 8 bf16 elements = 16 bytes + constexpr int kCols = 8; + constexpr int kElemSize = 2; // bf16 + int rows = static_cast(input_buf.size()) / (kCols * 
kElemSize); + + static executorch::extension::TensorPtr input_tensor; + input_tensor = from_blob( + input_buf.data(), {rows, kCols}, exec_aten::ScalarType::BFloat16); + inputs.push_back(*input_tensor); + } else { + fprintf(stderr, "Error: --input_dir required\n"); + return 1; + } + + auto result = module->execute("forward", inputs); + if (!result.ok()) { + fprintf(stderr, "Forward failed: 0x%x\n", static_cast(result.error())); + return 1; + } + + auto outputs = result.get(); + for (size_t i = 0; i < outputs.size(); i++) { + if (!outputs[i].isTensor()) + continue; + const auto& t = outputs[i].toTensor(); + printf("Output %zu: [", i); + for (int d = 0; d < t.dim(); d++) + printf("%d%s", static_cast(t.size(d)), d < t.dim() - 1 ? "," : ""); + printf("] dtype=%d\n", static_cast(t.scalar_type())); + + if (!FLAGS_output_dir.empty()) { + std::string path = + FLAGS_output_dir + "/output_" + std::to_string(i) + ".bin"; + write_file(path, t.const_data_ptr(), t.nbytes()); + printf(" Saved to %s (%zu bytes)\n", path.c_str(), (size_t)t.nbytes()); + } + } + + printf("SUCCESS\n"); + return 0; +} diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index 5bd582679c4..15acff87b49 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.cuda.triton.kernels.sdpa import sdpa +from executorch.backends.cuda.triton.kernels.topk import topk __all__ = [ "sdpa", + "topk", ] diff --git a/backends/cuda/triton/kernels/topk.py b/backends/cuda/triton/kernels/topk.py new file mode 100644 index 00000000000..ae846b3b017 --- /dev/null +++ b/backends/cuda/triton/kernels/topk.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
"""
Triton Top-K Kernel for ExecuTorch CUDA Backend.

Replaces aten.topk with a Triton implementation so the op is compiled
directly into the AOTInductor .so (no C++ fallback shim needed).

Algorithm: iterative argmax/argmin with masking, adapted from
FlagGems / aiter topk (1-stage path for moderate row sizes).
"""

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


def _next_power_of_2(x: int) -> int:
    """Return the smallest power of 2 >= x (x >= 1 expected)."""
    # bit_length gives the answer in O(1) instead of a doubling loop.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()


@triton.jit
def _topk_kernel(
    X,
    OUT_V,
    OUT_I,
    stride_xn,
    stride_ovn,
    stride_oin,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK: tl.constexpr,
    LARGEST: tl.constexpr,
):
    """Single-block topk: one program per row, iterative max/min with masking.

    Loads the whole row (padded to BLOCK with +/-inf), then K times selects
    the extremum, records it, and masks it out. Output is therefore always
    emitted in sorted order.
    NOTE(review): rows containing NaN may not match torch.topk's NaN
    ordering — the `vals == vsel` tie test never matches NaN; confirm if
    NaN inputs are possible.
    """
    pid = tl.program_id(0)
    row_ptr = X + pid * stride_xn
    offs = tl.arange(0, BLOCK)
    mask = offs < N

    if LARGEST:
        FILL: tl.constexpr = float("-inf")
    else:
        FILL: tl.constexpr = float("inf")

    vals = tl.load(row_ptr + offs, mask=mask, other=FILL).to(tl.float32)
    idxs = offs.to(tl.int64)

    out_v_ptr = OUT_V + pid * stride_ovn
    out_i_ptr = OUT_I + pid * stride_oin

    for j in tl.static_range(0, K):
        if LARGEST:
            vsel = tl.max(vals, axis=0)
        else:
            vsel = tl.min(vals, axis=0)

        eq = vals == vsel
        # For ties, pick the smallest index: add BLOCK to non-equal positions
        big = tl.where(eq, tl.zeros_like(idxs), tl.zeros_like(idxs) + BLOCK)
        arg = tl.min(idxs + big, axis=0)

        tl.store(out_v_ptr + j, vsel)
        tl.store(out_i_ptr + j, arg)

        # Mask out the selected element
        vals = tl.where(idxs == arg, FILL, vals)


@triton_op("triton::topk", mutates_args={})
def topk(
    self: torch.Tensor,
    k: int,
    dim: int = -1,
    largest: bool = True,
    sorted: bool = True,  # noqa: A002 — name fixed by the aten.topk schema
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Triton top-k implementation.

    Supports arbitrary dim by transposing so the target dimension is last,
    running the kernel, then transposing back.

    Args:
        self: Input tensor on CUDA.
        k: Number of top elements.
        dim: Dimension to operate on (default -1).
        largest: If True return largest, else smallest.
        sorted: If True return in sorted order (inherent in iterative algo;
            the flag is accepted for schema compatibility but not consulted).

    Returns:
        (values, indices) tensors with the topk dimension replaced by k.
    """
    # Normalize dim
    ndim = self.dim()
    if dim < 0:
        dim = dim + ndim

    # Move target dim to last position for contiguous row access
    if dim != ndim - 1:
        self = self.transpose(dim, ndim - 1).contiguous()
    elif not self.is_contiguous():
        self = self.contiguous()

    # Flatten all batch dims into one
    orig_shape = self.shape
    N = orig_shape[-1]  # row length
    num_rows = self.numel() // N
    x_flat = self.reshape(num_rows, N)

    # Allocate outputs
    values = torch.empty(num_rows, k, dtype=self.dtype, device=self.device)
    indices = torch.empty(num_rows, k, dtype=torch.int64, device=self.device)

    if k == 0 or num_rows == 0:
        # Degenerate case: skip the kernel launch entirely; reshape and
        # transpose the (empty) outputs back.
        out_shape = list(orig_shape)
        out_shape[-1] = k
        values = values.reshape(out_shape)
        indices = indices.reshape(out_shape)
        if dim != ndim - 1:
            values = values.transpose(dim, ndim - 1).contiguous()
            indices = indices.transpose(dim, ndim - 1).contiguous()
        return values, indices

    # The kernel loads a full row per program, padded to a power of two.
    BLOCK = _next_power_of_2(N)

    grid = (num_rows,)
    wrap_triton(_topk_kernel)[grid](
        x_flat,
        values,
        indices,
        x_flat.stride(0),
        values.stride(0),
        indices.stride(0),
        N=N,
        K=k,
        BLOCK=BLOCK,
        LARGEST=largest,
    )

    # Reshape back to original batch shape with k replacing dim size
    out_shape = list(orig_shape)
    out_shape[-1] = k
    values = values.reshape(out_shape)
    indices = indices.reshape(out_shape)

    # Transpose back if we moved dim
    if dim != ndim - 1:
        values = values.transpose(dim, ndim - 1).contiguous()
        indices = indices.transpose(dim, ndim - 1).contiguous()

    return values, indices


@topk.register_fake
def _topk_abstract(
    self: torch.Tensor,
    k: int,
    dim: int = -1,
    largest: bool = True,
    sorted: bool = True,  # noqa: A002 — name fixed by the aten.topk schema
) -> tuple[torch.Tensor, torch.Tensor]:
    """Abstract/fake implementation for torch.export: shape/dtype only."""
    ndim = self.dim()
    if dim < 0:
        dim = dim + ndim
    out_shape = list(self.shape)
    out_shape[dim] = k
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.int64)
    return values, indices


# NOTE(review): the surrounding patch also registers the kernel in
# backends/cuda/triton/replacement_pass.py:
#     EDGE_TO_TRITON_KERNELS[exir_ops.edge.aten.topk.default] = triton.topk