Merged
Commits
23 commits
a4ac5c4
docs: add TileIR backend usage guide to helion-hackathon.md
yf225 Mar 13, 2026
8bbd20d
docs: restructure ACF + TileIR as optional performance knobs
yf225 Mar 13, 2026
4882ec0
docs: remove "(Booster Pack)" from ACF heading
yf225 Mar 13, 2026
f10d235
docs: consolidate TileIR env var instructions
yf225 Mar 13, 2026
e4cb531
docs: clarify TileIR tunables come from autotuner output
yf225 Mar 13, 2026
115daeb
docs: shorten "Which should I use?" section
yf225 Mar 13, 2026
2b1bef6
docs: be explicit about ENABLE_TILE=0 vs ENABLE_TILE=1
yf225 Mar 13, 2026
1a127c1
docs: simplify TileIR comparison table to just backend names
yf225 Mar 13, 2026
10c9ba6
docs: add scoring system, rules, and open-ended contribution track
yf225 Mar 13, 2026
19c69ac
docs: allow unlimited submissions, best one counts
yf225 Mar 13, 2026
4fac3a8
docs: clarify rules to match actual submission format
yf225 Mar 13, 2026
eaeebed
Revert "docs: clarify rules to match actual submission format"
yf225 Mar 13, 2026
7cb3b40
Add per-shape config dispatch pattern to all submissions
yf225 Mar 13, 2026
36aaba5
docs: update example to show all shapes, remove DEFAULT_CONFIG
yf225 Mar 13, 2026
b1aacca
docs: use Config(...) placeholders with distinct TODO comments for te…
yf225 Mar 13, 2026
30159b1
docs: remove references to single-config-for-all-shapes pattern
yf225 Mar 13, 2026
f143d61
docs: remove references to default config in rules section
yf225 Mar 13, 2026
92fb7f7
docs: add tips for version control, tmux, and machine reboots
yf225 Mar 14, 2026
c608d4b
docs: move GPU machine tips to standalone section
yf225 Mar 14, 2026
5f814e2
docs: fix performance metric description to match actual eval method
yf225 Mar 14, 2026
4fa0303
Replace hard 30% LOC limit with judges' discretion for inline triton/asm
yf225 Mar 14, 2026
6a2a409
Add spawn mode tip for autotuning in GPU machine section
yf225 Mar 14, 2026
3ba5b4a
Clarify that spawn mode is slower than fork mode
yf225 Mar 14, 2026
211 changes: 157 additions & 54 deletions docs/helion-hackathon.md
Submit [Helion](https://github.com/pytorch/helion) kernels to the GPU MODE leaderboard
| 4 | `gated_deltanet_chunk_fwd_o` | Output computation for Gated DeltaNet |
| 5 | `gated_deltanet_recompute_w_u` | WY-transform forward kernel for Gated DeltaNet |

## Scoring

### Point Allocation

| Kernel | Correctness Points | Performance Points |
|---|---|---|
| **FP8 Quantization** | 100 | 0 (unscored) |
| **Causal Depthwise 1D Convolution** | 100 | 1000 |
| **Gated DeltaNet chunk\_fwd\_h** | 100 | 1000 |
| **Gated DeltaNet chunk\_fwd\_o** | 100 | 1000 |
| **Gated DeltaNet recompute\_w\_u** | 100 | 1000 |

### Scoring Rules

- **Performance Metric**: For each benchmark shape, the kernel is captured in a CUDA graph and replayed with L2 cache clearing before each invocation. The graph unrolls enough calls to fill ~100ms of GPU time, and this is repeated 10 times. The runtime is the arithmetic mean of those 10 measurements.
- **Ranking**: Participants are ranked per kernel by runtime (fastest = rank 1).
- **Formula**: Score = CorrectnessPoints + (PerformancePoints × [1 − (rank - 1) / 10])
- CorrectnessPoints are earned if the submission passes all test input shapes.
- Only the top 10 performers per kernel (who pass all tests) can earn PerformancePoints.
- Rank 1 → 100% of PerformancePoints, Rank 2 → 90%, …, Rank 10 → 10%.
- **Tiebreaker**: If two participants have the same metric value, the earlier submission wins.
- **Test case shapes**: Provided in `task.yml`; input data sampled from a random distribution.

**Total score** = Sum of points for all kernels.

## Rules & Requirements

- Kernel must pass all test input shapes (numerical accuracy within tolerance) with participant-provided config
- All benchmark shapes must have their best configs submitted for that kernel to be scored
- Implementations must use Helion DSL. `hl.inline_triton()`, `hl.triton_kernel()`, and `hl.inline_asm_elementwise()` are allowed as escape hatches, but the majority of your kernel should be written in Helion. Submissions that are predominantly inline Triton/ASM with a thin Helion wrapper may be disqualified at judges' discretion
- Unlimited submissions per participant per kernel. Only your best submission counts. Each submission should include: your Helion kernel implementation, one config per test input shape, and one best autotuned config per benchmark input shape

## Quick Start

```bash
Replace `causal_conv1d_py/` with any problem directory.

Your submission must be a single Python file that defines `custom_kernel(data: input_t) -> output_t`. To use Helion, write a `@helion.kernel` decorated function and call it from `custom_kernel`.

Use **per-shape configs** to optimize for each benchmark shape independently. The per-shape config pattern uses a factory function to create kernel variants with different configs, and dispatches based on input tensor shapes:

```python
from task import input_t, output_t
import torch
import helion
import helion.language as hl

# Map input shapes to optimized configs (autotune each shape locally).
# Include all test and benchmark shapes from task.yml.
SHAPE_CONFIGS: dict[tuple, helion.Config] = {
    # Test shapes
    (1, 64, 64, 4): helion.Config(...),  # TODO: replace with default config or any config that passes correctness check
    (2, 128, 128, 4): helion.Config(...),  # TODO: replace with default config or any config that passes correctness check
    # ... one entry per test shape
    # Benchmark shapes
    (1, 768, 512, 4): helion.Config(...),  # TODO: replace with your autotuned config
    (1, 768, 2048, 4): helion.Config(...),  # TODO: replace with your autotuned config
    # ... one entry per benchmark shape
}


def _make_kernel(config: helion.Config):
    @helion.kernel(static_shapes=True, config=config)
    def causal_conv1d_kernel(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # Your Helion kernel implementation
        ...

    return causal_conv1d_kernel


_KERNELS = {shape: _make_kernel(cfg) for shape, cfg in SHAPE_CONFIGS.items()}


def custom_kernel(data: input_t) -> output_t:
    x, weight, bias = data
    B, D, S = x.shape
    W = weight.shape[1]
    kernel = _KERNELS[(B, D, S, W)]
    return kernel(x, weight, bias)
```
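Note that the dispatch above is an exact dict lookup: a shape missing from `SHAPE_CONFIGS` fails at runtime, so include every shape from `task.yml`. A minimal sketch of that lookup behavior, with placeholder strings standing in for `helion.Config` objects:

```python
# Placeholder configs; in a real submission these are helion.Config(...) objects.
SHAPE_CONFIGS = {
    (1, 64, 64, 4): "test-shape config",
    (1, 768, 512, 4): "benchmark-shape config",
}

def pick_config(shape):
    # Exact-match lookup: every shape in task.yml needs its own entry.
    if shape not in SHAPE_CONFIGS:
        raise KeyError(f"no config for shape {shape}; add it to SHAPE_CONFIGS")
    return SHAPE_CONFIGS[shape]
```

A missing entry surfaces as a `KeyError` naming the offending shape, which is much easier to debug than a silently wrong config.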

## Do NOT Autotune on KernelBot

When submitting to KernelBot, you must hardcode configs in your `@helion.kernel` decorator. Do **not** rely on Helion's autotuner at submission time.

KernelBot runs your submission on shared infrastructure with timeouts. If your kernel triggers autotuning (which can take 10+ minutes and hundreds of trial runs), your submission will time out and fail.

### Getting a default config (no autotuning)

During early development, you can use `autotune_effort="none"` to skip autotuning and use Helion's default config. When you run the kernel, Helion prints the default config to stderr:

```
Using default config: @helion.kernel(config=helion.Config(block_sizes=[64, 64], num_warps=4, num_stages=1), static_shapes=True)
```

Copy the `helion.Config(...)` portion into your `SHAPE_CONFIGS` dict. The default config is usually good enough for test input shapes to pass correctness checks, but won't be competitive for benchmark shapes on the leaderboard.

### Autotuning for benchmark shapes

1. **Autotune locally on your Nebius-provided B200 compute.** Run your Helion kernel without a fixed config (or with `autotune_effort="quick"`) to find the best configuration for each benchmark shape.

2. **Copy the best config** from the autotuner output. When autotuning completes, Helion prints:
```
One can hardcode the best config and skip autotuning with:
@helion.kernel(config=helion.Config(block_sizes=[64, 64, 64], ...))
```


3. **Hardcode the config in your submission.** Copy the `helion.Config(...)` from step 2 into the corresponding benchmark shape entry in `SHAPE_CONFIGS`. Repeat steps 1-3 for each benchmark shape in `task.yml`.

4. **Submit the file** with the hardcoded configs to KernelBot.

## Submitting All 5 Problems

Expand Down Expand Up @@ -193,22 +239,24 @@ popcorn submit causal_conv1d_py/submission.py --gpu B200_Nebius --leaderboard ca

This returns GPU throughput, pipe utilization, and warp stall metrics, plus a downloadable `.ncu-rep` trace file you can open in the Nsight Compute GUI. See [profiling.md](profiling.md) for details on interpreting the output.

## Optional: Extra Performance Knobs

The sections below describe two **optional** techniques that can squeeze extra performance out of your kernels. Neither is required — you can place on the leaderboard without them. Try them after you have a working kernel with a tuned config.

### ACF Files

Each B200 instance has pre-tuned **PTXAS Advanced Controls Files (ACFs)** at `/opt/booster_pack/`. ACFs are low-level NVIDIA PTX assembler configurations that can improve performance beyond what Helion's standard autotuner finds. Available files:

```
/opt/booster_pack/
├── causal_conv_*.acf (3 files)
├── chunk_fwd_h_*.acf (2 files)
├── chunk_fwd_o_*.acf (7 files)
├── fp8_group_quant_*.acf (7 files)
└── recompute_w_u_fwd_*.acf (5 files)
```

**Step 1: Autotune with ACFs.** Pass `autotune_search_acf` to include ACFs in the search space. Helion tries each ACF file (plus the default `-O3` baseline) as another tunable parameter:

```python
from pathlib import Path
import helion

# Sketch of the decorator call; the glob matches the booster-pack files listed above.
@helion.kernel(
    autotune_search_acf=[str(p) for p in Path("/opt/booster_pack").glob("causal_conv_*.acf")],
)
def my_kernel(...):
    ...
```

> **Note:** `autotune_search_acf` only takes effect when the autotuner actually runs. It is ignored with `autotune_effort="none"` or a fixed `config=`.

**Step 2: Hardcode the best ACF in your submission.** After autotuning, look for the `advanced_controls_file` field in the best config and copy it:

```python
@helion.kernel(config=helion.Config(
    advanced_controls_file="/opt/booster_pack/causal_conv_0.acf",  # path from the autotuner output
    # ... rest of your tuned config
))
def my_kernel(...):
    ...
```

This is the approach you should use for KernelBot submissions — a fixed config with a fixed ACF, no autotuning at runtime.
### TileIR Backend

The B200 instances also ship with **nvtriton**, NVIDIA's extended Triton compiler that includes a **TileIR** backend — an alternative compilation pipeline that bypasses LLVM and compiles directly to CUBIN via NVIDIA's `tileiras` compiler.

| | `ENABLE_TILE=0` (default) | `ENABLE_TILE=1` |
|---|---|---|
| **Helion backend** | `triton` | `tileir` |

**Step 1: Enable TileIR and autotune.** Set the env vars before importing Helion, then autotune as usual. Helion automatically adjusts the search space for the TileIR backend.

**Step 2: Hardcode the TileIR config in your submission.** Copy the best config from the autotuner output (it will include TileIR-specific fields like `num_ctas` and `occupancy`). The env vars must be set before imports:

```python
import os
os.environ["ENABLE_TILE"] = "1"
os.environ["HELION_BACKEND"] = "tileir"

import helion # must be imported after setting env vars
import helion.language as hl

@helion.kernel(config=helion.Config(
    block_sizes=[64, 64],
    num_ctas=1,
    num_stages=5,
    occupancy=4,
    # ... rest of your tuned config
))
def my_kernel(...):
    ...
```

### Which should I use?

Try both `ENABLE_TILE=0` and `ENABLE_TILE=1`, with and without ACFs, then submit whichever gives the best benchmark numbers.
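One way to run the A/B comparison, using the env vars documented above. The benchmark script name here is a placeholder for whatever harness you use locally:

```shell
# Triton backend (the default)
ENABLE_TILE=0 python my_benchmark.py

# TileIR backend
ENABLE_TILE=1 HELION_BACKEND=tileir python my_benchmark.py
```

Keep the autotuned config for each backend separate; a config tuned under one backend is not valid for the other.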

## Tips

- **Iterate locally first.** Use your Nebius B200 to develop and autotune. Only submit to KernelBot once you have a hardcoded config that works.
- **Check the reference.** Each `reference.py` shows the baseline implementation you're trying to beat. Understanding it helps you write a better kernel.
- **Use `--mode test` first.** Verify correctness before submitting to the leaderboard. This saves time and leaderboard quota.
- **Profile your kernels.** Use `--mode profile` to get Nsight Compute metrics and identify bottlenecks.
- **One config per shape.** Use the per-shape config pattern to provide an optimized config for each benchmark shape in `task.yml`.

## Working on Your GPU Machine

- **Use a GitHub repo for your kernels.** Push your work to a private GitHub repo so you don't lose progress if the GPU machine goes offline or loses data.
- **Use tmux for autotuning.** Autotuning can take a long time. Run it inside a `tmux` session so it survives SSH disconnections.
- **Use spawn mode for autotuning if you hit issues.** By default, Helion's autotuner uses `fork` mode for precompilation, which is faster but can hang or crash if a bad config corrupts process state. If that happens, switch to `spawn` mode, which runs each trial in an isolated subprocess with timeout protection — slower due to subprocess overhead, but one bad config can't take down your entire autotuning run. Enable it via environment variable or decorator:
  ```bash
  export HELION_AUTOTUNE_PRECOMPILE=spawn
  ```
  ```python
  @helion.kernel(autotune_precompile="spawn")
  def my_kernel(...):
      ...
  ```
  You can also control parallelism with `HELION_AUTOTUNE_PRECOMPILE_JOBS` (defaults to CPU count).
- **Machine frozen or crashed?** If your GPU machine becomes unresponsive and needs a reboot, let us know and we can reboot it for you.

## Open-Ended Contribution Track

In addition to the kernel competition, there is a separate open-ended contribution track. Participants can earn recognition and prizes for contributions to Helion beyond kernel implementations. This track is scored independently and does not affect kernel competition standings. Examples:

| Contribution Type | Description |
|---|---|
| Autotuner Improvements | Enhancements to Helion's autotuning system |
| Bug Fixes | Bug fixes in Helion |
| Tooling/Infrastructure | Improvements to debugging, profiling, or developer experience |
| Documentation | Significant documentation contributions |
| Other Novel Contributions | Other impactful contributions at judges' discretion |

Contributions are uncapped and evaluated by a panel of judges based on impact and quality. Prizes for this track are awarded separately from the kernel competition.

## Resources

- [Helion Documentation](https://helionlang.com)