Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ xe-forge -i kernel.py -s spec.yaml --target-dtype float16
# Use a different LLM model
xe-forge -i kernel.py -s spec.yaml --model openai/gpt-4-turbo

# Allow dtype_fix to produce slower-but-correct kernels
xe-forge -i kernel.py -s spec.yaml --correctness-only-stages dtype_fix

# Multiple candidates (pick best)
xe-forge -i kernel.py -s spec.yaml --best-k 3

Expand Down Expand Up @@ -359,6 +362,14 @@ xe-forge -i kernel.py -s spec.yaml \
--stages algorithmic,dtype_fix,fusion,memory_access,persistent_kernel,xpu_specific,autotuning
```

### Correctness-Only Stages

Some stages (e.g. `dtype_fix`, `algorithmic`) may produce kernels that are temporarily slower — a cast from fp16 to fp32 is required for correctness but adds overhead. Use `--correctness-only-stages` to let these stages skip the speedup regression check; later stages (e.g. `fusion`, `xpu_specific`) are expected to recover the performance.

```bash
xe-forge -i kernel.py -s spec.yaml --correctness-only-stages dtype_fix,algorithmic
```

---

## CLI Reference
Expand Down Expand Up @@ -386,6 +397,7 @@ xe-forge --input KERNEL --spec SPEC [OPTIONS]
| Flag | Description |
|------|-------------|
| `--stages` | Comma-separated stages (e.g. `dtype_fix,xpu_specific`) |
| `--correctness-only-stages` | Comma-separated stages that skip the speedup regression check (e.g. `dtype_fix,algorithmic`) |
| `--target-dtype` | Target dtype: `float16`, `bfloat16`, `float32` |

### LLM Configuration
Expand Down
1 change: 1 addition & 0 deletions src/xe_forge/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def optimize_stage(
dtype=None,
pytorch_code: str = "",
init_args: list | None = None,
correctness_only_stages: set | None = None,
) -> StageResult:
"""Apply optimization stage to kernel code."""
...
4 changes: 2 additions & 2 deletions src/xe_forge/agents/optimizer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def optimize_stage(
init_args=None,
vtune_report="",
perf_context: dict | None = None,
correctness_only_stages: set | None = None,
):
logger.info(f"Applying optimization stage: {stage.value}")
original_code = code
Expand Down Expand Up @@ -507,8 +508,7 @@ def optimize_stage(
f" {k}: {v}" for k, v in xpu_config.items()
)

CORRECTNESS_ONLY_STAGES = {}
skip_speedup = stage in CORRECTNESS_ONLY_STAGES
skip_speedup = stage in (correctness_only_stages or set())

_baseline_ms = perf_context.get("original_ms") if perf_context else None

Expand Down
9 changes: 7 additions & 2 deletions src/xe_forge/agents/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _create_verify_tool(
input_shapes: list[tuple[int, ...]] | None,
flop: float | None,
dtype=None,
skip_speedup_check: bool = False,
) -> Callable:
"""
Create a verification tool for ReAct.
Expand Down Expand Up @@ -251,7 +252,7 @@ def compile_and_verify(optimized_code: dspy.Code["python"]) -> str: # noqa: UP0
)

# Check performance regression
if comparison.is_slower:
if not skip_speedup_check and comparison.is_slower:
slowdown = (
1.0 / comparison.speedup if comparison.speedup > 0 else float("inf")
)
Expand Down Expand Up @@ -309,6 +310,7 @@ def optimize_stage(
pytorch_code: str = "",
init_args: list | None = None,
perf_context: dict | None = None,
correctness_only_stages: set | None = None,
) -> StageResult:
"""
Apply a single optimization stage using ReAct.
Expand Down Expand Up @@ -357,13 +359,16 @@ def optimize_stage(
[f" {k}: {v}" for k, v in xpu_config.items()]
)

skip_speedup = stage in (correctness_only_stages or set())

# Create verification tool
verify_tool = self._create_verify_tool(
original_code=original_code,
kernel_name=kernel_name,
input_shapes=input_shapes,
flop=flop,
dtype=dtype,
skip_speedup_check=skip_speedup,
)

# Create ReAct agent for this optimization
Expand Down Expand Up @@ -427,7 +432,7 @@ def optimize_stage(

if not comparison.optimized_correct:
last_error = "Optimized kernel produces incorrect results"
elif comparison.is_slower:
elif not skip_speedup and comparison.is_slower:
slowdown = (
1.0 / comparison.speedup if comparison.speedup > 0 else float("inf")
)
Expand Down
20 changes: 20 additions & 0 deletions src/xe_forge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def main():
type=str,
help="Comma-separated stages to apply (e.g., dtype_fix,fusion,xpu_specific)",
)
parser.add_argument(
"--correctness-only-stages",
type=str,
help="Comma-separated stages that skip the speedup regression check "
"(only require correctness). e.g., dtype_fix,algorithmic",
)

# LLM configuration
parser.add_argument("--model", type=str, help="LLM model to use")
Expand Down Expand Up @@ -173,6 +179,17 @@ def main():
except ValueError:
print(f"Warning: Unknown stage '{name}', skipping")

# Parse correctness-only stages
correctness_only_stages = None
if args.correctness_only_stages:
cos_names = [s.strip() for s in args.correctness_only_stages.split(",")]
correctness_only_stages = set()
for name in cos_names:
try:
correctness_only_stages.add(OptimizationStage(name))
except ValueError:
print(f"Warning: Unknown stage '{name}' in --correctness-only-stages, skipping")

# Print header
print("=" * 60)
print("TRITON OPTIMIZER")
Expand All @@ -186,6 +203,8 @@ def main():
if args.target_dtype:
print(f"Target dtype: {args.target_dtype}")
print(f"Stages: {[s.value for s in stages] if stages else 'all'}")
if correctness_only_stages:
print(f"Correctness-only stages: {[s.value for s in correctness_only_stages]}")
print(f"Best@k: {config.optimization.best_k}")

# Print correctness settings
Expand Down Expand Up @@ -267,6 +286,7 @@ def main():
target_dtype=args.target_dtype,
rtol=args.rtol,
atol=args.atol,
correctness_only_stages=correctness_only_stages,
)

# Save output if requested
Expand Down
2 changes: 2 additions & 0 deletions src/xe_forge/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def optimize(
target_dtype=None,
rtol=None,
atol=None,
correctness_only_stages=None,
):
import torch

Expand Down Expand Up @@ -313,6 +314,7 @@ def optimize(
else None
),
},
correctness_only_stages=correctness_only_stages,
)
result.stages_applied.append(stage_result)

Expand Down
Loading