Skip to content

Commit 3f33e54

Browse files
authored
Add topk Triton kernel for CUDA backend (pytorch#18141)
Add topk Triton kernel for CUDA backend

Replaces aten.topk with a Triton implementation compiled directly into the AOTInductor .so.

Algorithm: iterative argmax/argmin with masking.

- Replacement pass skips N > 4096 (kernel loads entire rows into one thread block); falls back to aten for vocab-sized topk
- NaN handling matches torch.topk: NaN treated as larger than all finite values for both largest=True and largest=False
- Handles empty dimensions (N=0, k=0)
- Tests: eager correctness, NaN, empty, 3D non-last dim, export, e2e

Naive implementation, slower than torch.topk:

```
┌──────────────────────────┬────────────┬─────────────┬─────────┐
│ Config                   │ Eager (us) │ Runner (us) │ Speedup │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=4 cols=8 k=2        │ 73.8       │ 210.4       │ 0.35x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=16 cols=8 k=2       │ 79.5       │ 224.6       │ 0.35x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=4 cols=32 k=5       │ 70.1       │ 228.0       │ 0.31x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=32 cols=64 k=10     │ 73.9       │ 299.4       │ 0.25x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=64 cols=128 k=5     │ 76.5       │ 265.2       │ 0.29x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=128 cols=256 k=10   │ 81.2       │ 239.4       │ 0.34x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=256 cols=512 k=20   │ 83.1       │ 352.0       │ 0.24x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=512 cols=32 k=2     │ 79.5       │ 258.1       │ 0.31x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=1024 cols=16 k=4    │ 75.1       │ 244.7       │ 0.31x   │
├──────────────────────────┼────────────┼─────────────┼─────────┤
│ rows=1024 cols=1024 k=10 │ 297.5      │ 623.0       │ 0.48x   │
└──────────────────────────┴────────────┴─────────────┴─────────┘
```
1 parent 1925873 commit 3f33e54

5 files changed

Lines changed: 553 additions & 5 deletions

File tree

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ jobs:
132132
# Build executor_runner (needed by CUDA backend e2e tests)
133133
cmake --build cmake-out --target executor_runner
134134
135-
# Run all CUDA backend Python tests (including chunk_gated_delta e2e)
135+
# Run CUDA backend Python tests
136136
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="
137137
138138
export-model-cuda-artifact:

backends/cuda/tests/test_topk.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Export and validate topk triton kernel on CUDA backend.
9+
10+
Usage:
11+
python -m pytest backends/cuda/tests/test_topk.py -v
12+
13+
# Standalone export (produces .pte + .ptd):
14+
python backends/cuda/tests/test_topk.py --output-dir /tmp/exports
15+
"""
16+
17+
import argparse
18+
import os
19+
import subprocess
20+
import sys
21+
import tempfile
22+
import unittest
23+
24+
import numpy as np
25+
import torch
26+
import torch.nn as nn
27+
28+
from executorch.backends.cuda.cuda_backend import CudaBackend
29+
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
30+
31+
from executorch.backends.cuda.triton.kernels.topk import topk as triton_topk
32+
from executorch.exir import (
33+
EdgeCompileConfig,
34+
ExecutorchBackendConfig,
35+
to_edge_transform_and_lower,
36+
)
37+
from executorch.exir.passes import MemoryPlanningPass
38+
from torch.export import export
39+
40+
# Directory three levels above the one containing this file.
EXECUTORCH_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "../../..")
)
# Where the end-to-end test looks for the runner binary
# (built with: cmake --build cmake-out --target executor_runner).
RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner")
42+
43+
# Parameter sets for the eager-correctness test.
# Tuple layout: (seed, rows, cols, k, dim, largest, description)
TEST_CONFIGS = [
    (42, 4, 8, 2, -1, True, "basic_4x8_k2"),
    (0, 1, 16, 3, -1, True, "single_row_k3"),
    (7, 8, 4, 1, -1, True, "8x4_k1"),
    (99, 4, 8, 2, -1, False, "smallest_k2"),  # the only largest=False entry
    (13, 2, 32, 5, -1, True, "wide_k5"),
    (55, 4, 8, 8, -1, True, "k_equals_n"),  # k equals the column count
    (77, 1, 4, 2, -1, True, "tiny_1x4_k2"),
    (123, 16, 8, 2, -1, True, "many_rows"),
]
54+
55+
56+
class TopKModel(nn.Module):
    """Bias-free square linear layer whose output is passed to ``torch.topk``.

    Args:
        dim_in: Input and output feature size of the linear layer.
        k: Number of entries ``torch.topk`` selects.
        topk_dim: Dimension along which ``torch.topk`` is applied.
        largest: Forwarded as ``torch.topk``'s ``largest`` argument.
    """

    def __init__(self, dim_in=8, k=2, topk_dim=-1, largest=True):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_in, bias=False)
        self.k, self.topk_dim, self.largest = k, topk_dim, largest

    def forward(self, x):
        """Project ``x`` and return the ``(values, indices)`` pair from topk."""
        projected = self.linear(x)
        top_values, top_indices = torch.topk(
            projected,
            self.k,
            dim=self.topk_dim,
            largest=self.largest,
        )
        return top_values, top_indices
70+
71+
72+
def _make_inputs(seed, rows, cols, dtype=torch.bfloat16, device="cuda"):
    """Return a one-element tuple holding a seeded ``(rows, cols)`` random tensor.

    The global torch RNG is re-seeded with ``seed`` before sampling, so the
    same arguments always yield the same tensor.
    """
    torch.manual_seed(seed)
    sample = torch.randn(rows, cols, dtype=dtype, device=device)
    return (sample,)
75+
76+
77+
def _save_tensor(t, path):
    """Write the elements of ``t`` to ``path`` as raw bytes.

    Only the tensor's own elements are written, in row-major order and native
    byte order. Dumping the backing storage instead would also emit bytes
    that lie outside the tensor whenever ``t`` is a contiguous CPU view into
    a larger buffer (for example ``base[2:6]``).

    Args:
        t: Tensor to serialize; it is detached, moved to CPU and made
            contiguous first.
        path: Destination file, opened in binary write mode (overwritten).
    """
    # reshape(-1) lets the byte reinterpretation below work for 0-dim tensors.
    flat = t.detach().cpu().contiguous().reshape(-1)
    # Reinterpret as uint8 so dtypes NumPy lacks (e.g. bfloat16) still export.
    raw = flat.view(torch.uint8).numpy().tobytes()
    with open(path, "wb") as f:
        f.write(raw)
81+
82+
83+
def _load_output(path, shape, dtype):
    """Read a raw binary file and return it as a tensor of ``dtype`` and ``shape``.

    The file's bytes are reinterpreted as ``dtype`` elements and then reshaped,
    so the file must hold exactly the number of bytes ``shape`` requires.
    """
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    # Copy into a bytearray so the tensor is backed by a writable buffer.
    flat = torch.frombuffer(bytearray(raw_bytes), dtype=dtype)
    return flat.reshape(shape)
86+
87+
88+
def export_topk(output_dir):
    """Export a TopKModel (rows=4, cols=8, k=2, largest=True) to .pte + .ptd.

    Builds a bfloat16 CUDA model, exports it, lowers it with the CUDA
    partitioner and writes ``topk.pte`` into ``output_dir``. When the
    resulting program carries tensor data, that data is written into
    ``output_dir`` as well.

    Args:
        output_dir: Directory for the exported artifacts; created if missing.

    Returns:
        ``(pte_path, model)``: the path of the written ``.pte`` file and the
        eager model that was exported.
    """
    # Seed before building the model so construction sees a fixed RNG state.
    torch.manual_seed(42)
    model = TopKModel(dim_in=8, k=2, largest=True)
    model = model.to(device="cuda", dtype=torch.bfloat16).eval()
    example_inputs = _make_inputs(42, 4, 8)

    with torch.no_grad():
        exported = export(model, example_inputs, strict=True)

    os.makedirs(output_dir, exist_ok=True)

    compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]
    partitioner = CudaPartitioner(compile_specs)
    edge_config = EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True)
    lowered = to_edge_transform_and_lower(
        exported,
        partitioner=[partitioner],
        compile_config=edge_config,
    )

    backend_config = ExecutorchBackendConfig(
        extract_delegate_segments=True,
        do_quant_fusion_and_const_prop=True,
        memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
    )
    et_program = lowered.to_executorch(config=backend_config)

    pte_path = os.path.join(output_dir, "topk.pte")
    with open(pte_path, "wb") as pte_file:
        et_program.write_to_file(pte_file)

    # Tensor data is written only when the program actually has some.
    if getattr(et_program, "_tensor_data", None):
        et_program.write_tensor_data_to_file(output_dir)

    return pte_path, model
127+
128+
129+
def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
    """Run executor_runner and return subprocess result.

    Args:
        runner_path: Executable to launch.
        pte_path: Value for ``--model_path``.
        ptd_path: Value for ``--data_path``.
        input_files: Input file paths, joined with commas for ``--inputs``.
        output_base: Value for ``--output_file``.

    Returns:
        The ``subprocess.CompletedProcess`` with stdout/stderr captured as
        text. A non-zero exit code is reported via ``returncode``, not raised.
    """
    joined_inputs = ",".join(input_files)
    argv = [
        runner_path,
        f"--model_path={pte_path}",
        f"--data_path={ptd_path}",
        f"--inputs={joined_inputs}",
        f"--output_file={output_base}",
    ]
    completed = subprocess.run(argv, capture_output=True, text=True)
    return completed
139+
140+
141+
class TestTopK(unittest.TestCase):
    """Tests for the Triton topk kernel: eager parity, export and e2e runs.

    Every test is skipped when CUDA is not available (see ``setUp``).
    """

    def setUp(self):
        """Skip the current test unless a CUDA device is available."""
        if torch.cuda.is_available():
            return
        self.skipTest("CUDA is not available")

    def test_eager(self):
        """Triton topk produces correct shapes and dtypes."""
        sample = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
        values, indices = triton_topk(sample, 2)

        expected_shape = torch.Size([4, 2])
        self.assertEqual(values.shape, expected_shape)
        self.assertEqual(indices.shape, expected_shape)
        self.assertEqual(values.dtype, torch.bfloat16)
        self.assertEqual(indices.dtype, torch.int64)

    def test_eager_correctness(self):
        """Triton topk matches torch.topk across multiple configs.

        Each entry of TEST_CONFIGS seeds the RNG, draws a bfloat16 CUDA
        tensor and compares values (tolerance 1e-3) and indices (exact).
        """
        for seed, rows, cols, k, dim, largest, desc in TEST_CONFIGS:
            with self.subTest(desc=desc):
                torch.manual_seed(seed)
                data = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")

                expected_vals, expected_idx = torch.topk(
                    data, k, dim=dim, largest=largest
                )
                actual_vals, actual_idx = triton_topk(
                    data, k, dim=dim, largest=largest
                )

                # Compare values in float32.
                gap = torch.abs(actual_vals.float() - expected_vals.float())
                v_diff = gap.max().item()
                self.assertLess(v_diff, 1e-3, f"{desc}: value diff {v_diff}")

                indices_match = torch.equal(actual_idx, expected_idx)
                self.assertTrue(indices_match, f"{desc}: indices mismatch")

    def test_empty_dimension(self):
        """N=0 with k=0 returns empty tensors (matches torch.topk)."""
        empty = torch.empty(4, 0, dtype=torch.bfloat16, device="cuda")
        actual_vals, actual_idx = triton_topk(empty, 0, dim=-1)
        expected_vals, expected_idx = torch.topk(empty, 0, dim=-1)

        self.assertEqual(actual_vals.shape, expected_vals.shape)
        self.assertEqual(actual_idx.shape, expected_idx.shape)

    def test_nan_handling(self):
        """NaN treated as larger than all finite values (matches torch.topk)."""
        nan = float("nan")
        mixed = [1.0, nan, 3.0, nan, 2.0]
        # Each case: (description, row data, k, largest)
        cases = [
            ("all_nan_largest", [nan] * 3, 2, True),
            ("mixed_largest", mixed, 3, True),
            ("mixed_smallest", mixed, 3, False),
            ("mixed_smallest_all", mixed, 5, False),
        ]
        for desc, data, k, largest in cases:
            with self.subTest(desc=desc):
                x = torch.tensor([data], dtype=torch.float32, device="cuda")
                tv, ti = triton_topk(x, k, largest=largest)
                rv, ri = torch.topk(x, k, largest=largest)

                tri_finite_mask = ~tv.isnan()
                ref_finite_mask = ~rv.isnan()

                # NaN count must match
                self.assertEqual(
                    tv.isnan().sum().item(),
                    rv.isnan().sum().item(),
                    f"{desc}: NaN count mismatch",
                )

                # Finite values and indices must match
                tv_finite = tv[tri_finite_mask]
                rv_finite = rv[ref_finite_mask]
                if tv_finite.numel() > 0:
                    v_diff = torch.abs(tv_finite - rv_finite).max().item()
                    self.assertLess(v_diff, 1e-3, f"{desc}: value diff {v_diff}")
                    finite_idx_match = torch.equal(
                        ti[tri_finite_mask], ri[ref_finite_mask]
                    )
                    self.assertTrue(
                        finite_idx_match, f"{desc}: finite indices mismatch"
                    )

    def test_3d_non_last_dim(self):
        """Topk on non-last dimension of 3D tensor."""
        torch.manual_seed(42)
        cube = torch.randn(2, 5, 3, dtype=torch.bfloat16, device="cuda")
        actual_vals, actual_idx = triton_topk(cube, 2, dim=1)
        expected_vals, expected_idx = torch.topk(cube, 2, dim=1)

        self.assertEqual(actual_vals.shape, expected_vals.shape)
        v_diff = torch.abs(actual_vals.float() - expected_vals.float()).max().item()
        self.assertLess(v_diff, 1e-3)
        self.assertTrue(torch.equal(actual_idx, expected_idx))

    def test_export_cuda(self):
        """Export succeeds and produces non-empty .pte."""
        with tempfile.TemporaryDirectory() as workdir:
            pte_path, _ = export_topk(workdir)
            self.assertTrue(os.path.exists(pte_path))
            self.assertGreater(os.path.getsize(pte_path), 0)

    def test_e2e_cpp_runner(self):
        """Export once, run executor_runner with multiple inputs, compare.

        For every seed the eager model output is the reference; the runner's
        output files are read back and compared against it.
        """
        self.assertTrue(
            os.path.exists(RUNNER_PATH),
            f"executor_runner not found at {RUNNER_PATH}. "
            "Build with: cmake --build cmake-out --target executor_runner",
        )

        # Exported model: rows=4, cols=8, k=2, largest=True
        rows, cols, k = 4, 8, 2
        e2e_seeds = [0, 7, 42, 99, 123, 2024]
        out_shape = (rows, k)

        with tempfile.TemporaryDirectory() as tmpdir:
            export_dir = os.path.join(tmpdir, "export")
            pte_path, model = export_topk(export_dir)
            ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd")

            for seed in e2e_seeds:
                with self.subTest(seed=seed):
                    inputs = _make_inputs(seed, rows, cols)

                    with torch.no_grad():
                        ref_vals, ref_idx = model(*inputs)

                    run_dir = os.path.join(tmpdir, f"run_seed{seed}")
                    os.makedirs(run_dir)

                    # One raw file per model input, named by position.
                    input_files = []
                    for position, tensor in enumerate(inputs):
                        bin_path = os.path.join(run_dir, f"{position}.bin")
                        _save_tensor(tensor, bin_path)
                        input_files.append(bin_path)

                    output_base = os.path.join(run_dir, "output")
                    result = _run_cpp_runner(
                        RUNNER_PATH, pte_path, ptd_path, input_files, output_base
                    )
                    self.assertEqual(
                        result.returncode,
                        0,
                        f"seed={seed}: executor_runner failed:\n{result.stderr}",
                    )

                    # Outputs are read from "<output_base>-<index>.bin".
                    cpp_vals = _load_output(
                        f"{output_base}-0.bin", out_shape, torch.bfloat16
                    )
                    cpp_idx = _load_output(
                        f"{output_base}-1.bin", out_shape, torch.int64
                    )

                    value_gap = torch.abs(cpp_vals.float() - ref_vals.cpu().float())
                    v_diff = value_gap.max().item()
                    self.assertLess(v_diff, 0.01, f"seed={seed}: value diff {v_diff}")
                    self.assertTrue(
                        torch.equal(cpp_idx, ref_idx.cpu()),
                        f"seed={seed}: indices mismatch\n"
                        f" cpp: {cpp_idx}\n ref: {ref_idx.cpu()}",
                    )
297+
298+
299+
if __name__ == "__main__":
    # With --output-dir, do a standalone export; otherwise run the test suite.
    cli = argparse.ArgumentParser()
    cli.add_argument("--output-dir", default=None)
    known, passthrough = cli.parse_known_args()

    if not known.output_dir:
        # Forward every argument argparse did not recognise to unittest.
        sys.argv = [sys.argv[0], *passthrough]
        unittest.main()
    else:
        export_topk(known.output_dir)

backends/cuda/triton/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from executorch.backends.cuda.triton.kernels.sdpa import sdpa
8+
from executorch.backends.cuda.triton.kernels.topk import topk
89

910
__all__ = [
1011
"sdpa",
12+
"topk",
1113
]
1214

1315
try:

0 commit comments

Comments
 (0)