43 changes: 42 additions & 1 deletion README.md
@@ -31,6 +31,47 @@ We then extend `KernelBenchEnv` to support:
- **Batching**: `KernelBenchEnvGroupBuilder` groups multiple rollouts for the same problem, enabling **GRPO-style** training where rewards are normalized within groups.
- **Dataset Construction**: `KernelBenchDatasetBuilder` handles the iteration over KernelBench levels and problems, partitioning them into training and evaluation sets. You are welcome to extend it to support more problems beyond what is currently in KernelBench.

### Multi-Turn RL

We extend the single-turn pipeline with multi-turn iterative refinement, following the approach in [Kevin](https://arxiv.org/abs/2507.11948). Instead of generating one kernel per problem, the model generates a kernel, receives evaluation feedback (compilation errors, correctness failures, or speedup results), and refines its solution over multiple turns.

`MultiTurnKernelBenchEnv` manages the multi-turn loop:
- **History management**: Prior turns (prompt, response, feedback) are kept in context with token-based truncation to stay within the context window.
- **Evaluation feedback**: Structured feedback tells the model what went wrong (compilation error, incorrect output, or correct but slow) so it can fix specific issues.
- **Early stopping**: Optionally stop the episode when the kernel passes all correctness tests.
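
The history-management loop above can be sketched as follows. This is a minimal illustration, not the repo's actual API: `truncate_history` and the `(prompt, response, feedback)` tuple shape are hypothetical names, and the tokenizer is abstracted as a `count_tokens` callable.

```python
def truncate_history(turns, max_tokens, count_tokens):
    """Keep the most recent (prompt, response, feedback) turns that fit
    within a token budget, dropping the oldest turns first."""
    kept = []
    total = 0
    for turn in reversed(turns):  # walk newest-first
        cost = sum(count_tokens(part) for part in turn)
        if total + cost > max_tokens:
            break  # this turn (and everything older) no longer fits
        kept.append(turn)
        total += cost
    return list(reversed(kept))  # restore chronological order
```

Dropping whole turns (rather than clipping mid-turn) keeps each remaining prompt/response/feedback triple intact, so the model always sees complete evaluation feedback for the turns it can see at all.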

Training uses GRPO with discounted returns across turns:
- Per-turn scores are computed as `S = 0.3 * correct + speedup`, with the speedup term counted only for correct kernels (incorrect kernels score 0).
- Discounted returns: `R_t = S_t + γ * R_{t+1}` (backward recursion, γ=0.4 by default).
- Advantages are normalized across all `group_size × max_turns` turn-level samples: `(R - mean) / (std + ε)`.
- PPO with asymmetric clipping (Clip-Higher, ε_low=0.2, ε_high=0.28) and constant length normalization.
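
The scoring, discounted-return, and advantage formulas above can be sketched as plain Python. This is an illustration of the math, not the repo's training code; function names are hypothetical.

```python
import statistics

def turn_scores(results, correctness_weight=0.3):
    """Per-turn score S = 0.3 * correct + speedup, with the speedup
    term counted only for correct kernels."""
    return [
        correctness_weight * 1.0 + speedup if correct else 0.0
        for correct, speedup in results
    ]

def discounted_returns(scores, gamma=0.4):
    """Backward recursion R_t = S_t + gamma * R_{t+1}."""
    returns = [0.0] * len(scores)
    running = 0.0
    for t in reversed(range(len(scores))):
        running = scores[t] + gamma * running
        returns[t] = running
    return returns

def normalized_advantages(all_returns, eps=1e-6):
    """Normalize (R - mean) / (std + eps) across all
    group_size x max_turns turn-level returns."""
    mean = statistics.fmean(all_returns)
    std = statistics.pstdev(all_returns)
    return [(r - mean) / (std + eps) for r in all_returns]
```

Because later turns' scores flow backward through `gamma`, an early turn that sets up a correct, fast kernel in a later turn still receives credit, which is what makes the reward signal denser than single-turn GRPO.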

Enable multi-turn via config:
```yaml
multiturn:
enabled: true
max_turns: 4 # Refinement turns per trajectory
gamma: 0.4 # Discount factor
aggregation: "sum" # "sum" or "max"
```

Or via CLI:
```bash
uv run python -m kernelbench_tinker.scripts.train_kernel_rl \
--config src/kernelbench_tinker/config/rl_kernelbench.yaml \
multiturn.enabled=true \
log_path=./runs/my_multiturn_experiment
```

Multi-turn inference is also supported via the eval script:
```bash
uv run python -m kernelbench_tinker.scripts.eval_kernel_rl \
checkpoint_path=<your_checkpoint> \
multiturn_enabled=true \
multiturn_max_turns=8 \
level=1
```


### Directory Structure
```text
@@ -54,6 +95,7 @@ src/kernelbench_tinker/
envs/
kernelbench_client.py # KernelBench Python API wrapper
kernelbench_env.py # Single-turn RL environment
multiturn_kernelbench_env.py # Multi-turn RL environment
training/
models.py # Model/renderer configuration
reward.py # Reward shaping
@@ -282,7 +324,6 @@ Note the scope of this repo is an open-source implementation of KernelBench-Tink

* More reward examples leveraging more fine-grained metrics
* More reward hack checking
* Multi-turn RL to have denser reward signal like [Kevin](https://arxiv.org/abs/2507.11948)
* Improve Step time and training efficiency


54 changes: 54 additions & 0 deletions src/kernelbench_tinker/config/configs.py
@@ -81,3 +81,57 @@ class DatasetConfig:

# Train/test split
test_fraction: float = 0.1


@dataclass
class MultiTurnConfig:
"""
Configuration for multi-turn RL training.

Controls the iterative refinement loop where the model receives
evaluation feedback and can fix errors across multiple turns.
"""

# Enable multi-turn mode (False = single-turn)
enabled: bool = False

# Maximum refinement turns per trajectory
max_turns: int = 4

# Discount factor for multi-turn returns: R_t = S_t + gamma * R_{t+1}
gamma: float = 0.4

# Return aggregation mode: "sum" or "max"
# sum: R_t = Σ γ^(i-t) × S_i (reward turns leading to many good kernels)
# max: R_t = max{ γ^(i-t) × S_i } (reward turns leading to one great kernel)
aggregation: str = "sum"

# Stop the episode early when the kernel is correct.
# Default False for training: model needs post-correctness turns to
# learn speedup optimization. Set True at eval time if desired.
early_stop_on_correct: bool = False

# Optional: require this speedup before early stopping
speedup_threshold: float | None = None

# Prompt
prompt_max_tokens: int | None = None # Token budget for history truncation (None = char fallback)
inject_think_token: bool = False # Append <think>\n to generation prompts

# Generation
temperature: float = 0.9
top_p: float = 1.0
seed: int | None = None

# Response length extension mid-training (0 = disabled)
max_tokens_extended: int = 22000
max_tokens_extend_after_step: int = 30

# Training
loss_fn: str = "ppo"
max_grad_norm: float = 0.05
warmup_ratio: float = 0.03
clip_epsilon_low: float = 0.2
clip_epsilon_high: float = 0.28
constant_length_norm: int = 16384
num_substeps: int = 2
31 changes: 31 additions & 0 deletions src/kernelbench_tinker/config/rl_kernelbench.yaml
@@ -26,6 +26,33 @@ learning_rate: 0.000002 # 2e-6 as explicit float
max_tokens: 16384
temperature: 1.0

# =============================================================================
# Multi-turn Configuration (disabled by default)
# =============================================================================
multiturn:
enabled: false # true to enable iterative refinement
max_turns: 4 # Maximum refinement turns per trajectory
gamma: 0.4 # Discount factor for multi-turn returns
aggregation: "sum" # "sum" (reward many good kernels) or "max" (reward one great kernel)
early_stop_on_correct: false # Stop episode when kernel passes all tests
speedup_threshold: null # Required speedup before early stopping (null = any correct)
# Prompt
prompt_max_tokens: null # Token budget for history truncation (null = char fallback)
inject_think_token: false # Append <think>\n to generation prompts
# Generation
temperature: 0.9 # Generation temperature
top_p: 1.0 # Nucleus sampling (1.0 = disabled)
seed: null # Random seed for generation (null = random)
max_tokens_extended: 22000 # Extend max_tokens mid-training (0 = disabled)
max_tokens_extend_after_step: 30 # Step at which to switch
# Training
loss_fn: "ppo" # Loss function (single-turn uses top-level loss_fn)
max_grad_norm: 0.05 # Gradient clipping (0.0 = disabled)
warmup_ratio: 0.03 # Linear LR warmup fraction
clip_epsilon_low: 0.2 # PPO clip lower bound
  clip_epsilon_high: 0.28 # PPO clip upper bound (Clip-Higher)
constant_length_norm: 16384 # GRPO constant length normalization (0 = disabled)

# =============================================================================
# Training Configuration
# =============================================================================
@@ -57,6 +84,7 @@ dataset_builder:
# Problem Selection
# ---------------------------------------------------------------------------
level: 1 # KernelBench level (1, 2, 3, or 4)
levels: null # Train on multiple levels (e.g. [1, 2]); overrides level when set
start_problem: null # First problem ID (null = start from 1)
end_problem: null # Last problem ID (null = all problems)
dataset_src: "huggingface" # "huggingface" or "local"
@@ -107,6 +135,9 @@ dataset_builder:
reward_correctness_weight: 0.3
reward_speed_weight: 1.0
reward_length_weight: 0.0
reward_speed_max_reward: 10.0 # Cap on speed reward component (set high to uncap)
reward_clip_min: null # Lower bound on total reward (null = no clipping)
reward_clip_max: null # Upper bound on total reward (null = no clipping)

# ---------------------------------------------------------------------------
# Reward Hacking Detection (Static Checker)
150 changes: 150 additions & 0 deletions src/kernelbench_tinker/envs/env_utils.py
@@ -0,0 +1,150 @@
"""
Shared utilities for KernelBench environments.

Contains helpers used by both the single-turn and multi-turn environments:
- System prompt construction
- Step evaluation (parse → evaluate → reward → metrics)
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from tinker_cookbook import renderers
from tinker_cookbook.rl.types import Action, Metrics

from kernelbench_tinker.config.configs import EvalConfig
from kernelbench_tinker.envs.kernelbench_client import (
KernelBenchProblem,
KernelEvalResult,
ParsedResponse,
evaluate_kernel_async,
parse_structured_response,
)
from kernelbench_tinker.training.reward import (
RewardConfig,
compute_reward,
)

logger = logging.getLogger(__name__)


@dataclass
class EvalStepResult:
"""Result from evaluate_step(), shared by single-turn and multi-turn envs."""

parsed: ParsedResponse
eval_result: KernelEvalResult
format_ok: bool
kernel_code: str
reward: float
metrics: Metrics
response_text: str # Raw response content from renderer (before structured parsing)


def build_system_prompt(backend: str) -> str:
"""Build a backend-specific system prompt for kernel generation.

Used by both single-turn and multi-turn environments.
"""
return (
f"You are an expert GPU kernel developer. Your task is to optimize PyTorch "
f"operations by writing efficient custom {backend.upper()} kernels.\n"
f"\n"
f"When given a PyTorch model, write an optimized kernel implementation.\n"
f"\n"
f"Your solution must:\n"
f"- Be a drop-in replacement as a class named `ModelNew`\n"
f"- Use custom {backend.upper()} kernels, not just PyTorch operations\n"
f"- Be correct and produce the same results as the reference\n"
f"\n"
f"You MUST respond in exactly this format:\n"
f"\n"
f"<KERNEL>\n"
f"```python\n"
f"# Your complete optimized implementation here\n"
f"class ModelNew(nn.Module):\n"
f" ...\n"
f"```\n"
f"</KERNEL>"
)


async def evaluate_step(
problem: KernelBenchProblem,
renderer: renderers.Renderer,
action: Action,
eval_config: EvalConfig,
reward_config: RewardConfig,
step_start: float,
) -> EvalStepResult:
"""Parse, evaluate, and compute reward for a single action.

Shared by KernelBenchEnv.step() and MultiTurnKernelBenchEnv.step().
"""
message, _ = renderer.parse_response(action)
response_text = message.get("content", "")

parsed = parse_structured_response(response_text)
kernel_code = parsed.kernel
format_ok = parsed.format_ok

eval_start = time.perf_counter()
cfg = eval_config
eval_result = await evaluate_kernel_async(
level=problem.level,
problem_id=problem.problem_id,
backend=problem.backend,
kernel_code=kernel_code,
dataset_src=problem.dataset_src,
num_correct_trials=cfg.num_correct_trials,
measure_performance=cfg.measure_performance,
num_perf_trials=cfg.num_perf_trials,
timing_method=cfg.timing_method,
precision=cfg.precision,
check_for_excessive_speedup=cfg.check_for_excessive_speedup,
excessive_speedup_threshold=cfg.excessive_speedup_threshold,
timeout=cfg.modal_timeout,
)
eval_time = time.perf_counter() - eval_start

reward = compute_reward(
eval_result,
reward_config,
kernel_code=kernel_code,
backend=problem.backend,
)

metrics: Metrics = {
"level": problem.level,
"problem_id": problem.problem_id,
"format_ok": float(format_ok),
"compiled": float(eval_result["compiled"]),
"correctness": float(eval_result["correctness"]),
"tests_passed": eval_result["tests_passed"],
"tests_total": eval_result["tests_total"],
}
if eval_result.get("speedup") is not None:
metrics["speedup"] = eval_result["speedup"]
if eval_result.get("runtime_ms") is not None:
metrics["runtime_ms"] = eval_result["runtime_ms"]
metrics["time/eval"] = eval_time
timing_metadata = (eval_result.get("metadata") or {}).get("timings", {})
if "reference_load_s" in timing_metadata:
metrics["time/ref_load"] = timing_metadata["reference_load_s"]
if "modal_eval_s" in timing_metadata:
metrics["time/modal_eval"] = timing_metadata["modal_eval_s"]
metrics["time/step_total"] = time.perf_counter() - step_start

return EvalStepResult(
parsed=parsed,
eval_result=eval_result,
format_ok=format_ok,
kernel_code=kernel_code,
reward=reward,
metrics=metrics,
response_text=response_text,
)