Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions src/xe_forge/core/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def _validate_triton(self, code: str, stage: str | None = None) -> list[Validati

# -- Xe-Forge structural checks (from optimizer_agent inline) --

# Syntax parse
# Parse AST once for reuse
try:
ast.parse(code)
tree = ast.parse(code)
except SyntaxError as e:
issues.append(
ValidationIssue(
Expand Down Expand Up @@ -201,29 +201,46 @@ def _validate_triton(self, code: str, stage: str | None = None) -> list[Validati

# 2. Grid dimensionality with swizzling
has_swizzling = "GROUP_SIZE_M" in code or "swizzle" in code.lower()
for i, line in enumerate(lines):
if "grid" in line and "=" in line:
if line.count("triton.cdiv") >= 2 and "(" in line and ")" in line:
paren_depth = 0
comma_count = 0
start = line.index("(")
for ch in line[start:]:
if ch == "(":
paren_depth += 1
elif ch == ")":
paren_depth -= 1
elif ch == "," and paren_depth == 1:
comma_count += 1
if comma_count >= 1 and has_swizzling:
issues.append(
ValidationIssue(
"grid_swizzle_conflict",
"error",
"Grid is 2D but tile swizzling (GROUP_SIZE_M) is used. "
"Grid must be 1D with swizzling.",
line=i + 1,
)
if has_swizzling:
function_defs = {
node.name: node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
}

for node in ast.walk(tree):
if not isinstance(node, ast.Assign):
continue

target_names = [
target.id for target in node.targets if isinstance(target, ast.Name)
]
if not any(name.startswith("grid") for name in target_names):
continue

grid_expr = node.value
is_2d_grid = False

if isinstance(grid_expr, ast.Tuple):
is_2d_grid = len(grid_expr.elts) > 1
elif isinstance(grid_expr, ast.Lambda) and isinstance(grid_expr.body, ast.Tuple):
is_2d_grid = len(grid_expr.body.elts) > 1
elif isinstance(grid_expr, ast.Name) and grid_expr.id in function_defs:
grid_func = function_defs[grid_expr.id]
for stmt in grid_func.body:
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Tuple):
if len(stmt.value.elts) > 1:
is_2d_grid = True
break

if is_2d_grid:
issues.append(
ValidationIssue(
"grid_swizzle_conflict",
"error",
"Grid is 2D but tile swizzling (GROUP_SIZE_M) is used. "
"Grid must be 1D with swizzling.",
line=getattr(node, "lineno", None),
)
)

# 3. boundary_check format
for i, line in enumerate(lines):
Expand Down
82 changes: 82 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests for static kernel validation."""

from xe_forge.core.validator import KernelValidator

VALID_1D_SWIZZLED_GRID = """\
import triton
import triton.language as tl

GROUP_SIZE_M = 4


@triton.jit
def kernel():
pass


class Model:
pass


def launch():
grid = lambda META: (triton.cdiv(M, META["BM"]) * triton.cdiv(N, META["BN"]),)
kernel[grid]()
"""


INVALID_2D_SWIZZLED_GRID = """\
import triton
import triton.language as tl

GROUP_SIZE_M = 4


@triton.jit
def kernel():
pass


class Model:
pass


def launch():
grid = lambda META: (triton.cdiv(M, META["BM"]), triton.cdiv(N, META["BN"]))
kernel[grid]()
"""
Comment thread
mzweilin marked this conversation as resolved.


INVALID_2D_TUPLE_SWIZZLED_GRID = """\
import triton
import triton.language as tl

GROUP_SIZE_M = 4


@triton.jit
def kernel():
pass


class Model:
pass


def launch():
grid = (triton.cdiv(M, 128), triton.cdiv(N, 256))
kernel[grid]()
"""


class TestGridSwizzleValidation:
def test_1d_grid_with_swizzle_is_allowed(self):
issues = KernelValidator().validate(VALID_1D_SWIZZLED_GRID, dsl="triton")
assert all(issue.check_name != "grid_swizzle_conflict" for issue in issues)

def test_2d_grid_with_swizzle_is_rejected(self):
issues = KernelValidator().validate(INVALID_2D_SWIZZLED_GRID, dsl="triton")
assert any(issue.check_name == "grid_swizzle_conflict" for issue in issues)

def test_2d_tuple_grid_with_swizzle_is_rejected(self):
issues = KernelValidator().validate(INVALID_2D_TUPLE_SWIZZLED_GRID, dsl="triton")
assert any(issue.check_name == "grid_swizzle_conflict" for issue in issues)
Loading