Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/flydsl/compiler/backends/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import List, Tuple

from ...runtime.device import get_rocm_arch, is_rdna_arch
from ...runtime.device import get_rocm_arch, get_rocm_toolkit_path, is_rdna_arch
from ...utils import env
from .base import BaseBackend, GPUTarget

Expand Down Expand Up @@ -90,7 +90,9 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]:
else []
),
]
binary_fragment = f'gpu-module-to-binary{{format=fatbin opts="{" ".join(bin_cli_opts)}"}}'
toolkit_path = get_rocm_toolkit_path() or ""
toolkit_opt = f" toolkit={toolkit_path}" if toolkit_path else ""
binary_fragment = f'gpu-module-to-binary{{format=fatbin opts="{" ".join(bin_cli_opts)}"{toolkit_opt}}}'
return [*pre_binary_fragments, *binary_prep_fragments], binary_fragment

def pipeline_fragments(self, *, compile_hints: dict) -> List[str]:
Expand Down
5 changes: 4 additions & 1 deletion python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,11 @@ def _dump_isa(*, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool, stage_
di_pass = (
"ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}," if env.debug.enable_debug_info else ""
)
from ..runtime.device import get_rocm_toolkit_path

toolkit_path = get_rocm_toolkit_path() or ""
pm = PassManager.parse(
f'builtin.module({di_pass}gpu-module-to-binary{{format=isa opts="{"-g" if env.debug.enable_debug_info else ""}" section= toolkit=}})',
f'builtin.module({di_pass}gpu-module-to-binary{{format=isa opts="{"-g" if env.debug.enable_debug_info else ""}" section= toolkit={toolkit_path}}})',
context=ctx,
)
pm.enable_verifier(bool(verify))
Expand Down
59 changes: 59 additions & 0 deletions python/flydsl/runtime/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,70 @@
import functools
import os
import subprocess
from pathlib import Path
from typing import Optional

_ROCM_AGENT_TIMEOUT_S = int(os.environ.get("FLYDSL_ROCM_AGENT_TIMEOUT", "300"))


@functools.lru_cache(maxsize=None)
def get_rocm_toolkit_path() -> Optional[str]:
    """Return a directory MLIR's ROCDL backend recognizes as a toolkit.

    MLIR's gpu-module-to-binary expects ``<toolkit>/llvm/bin/ld.lld`` for
    linking and ``<toolkit>/amdgcn/bitcode`` for device libraries. The
    rocm-sdk Python wheels (``_rocm_sdk_core``) ship both, but at
    ``<sdk>/lib/llvm/bin/ld.lld`` and ``<sdk>/lib/llvm/amdgcn/bitcode``, so
    the layout doesn't directly match. We synthesize a tiny symlink-based
    shim under ``~/.flydsl/toolkit`` (or ``FLYDSL_ROCM_TOOLKIT_SHIM_DIR`` if
    set) and return its path.

    Order of preference:
      1. ``FLYDSL_ROCM_TOOLKIT_PATH`` env var (explicit override)
      2. ``ROCM_PATH`` env var
      3. ``/opt/rocm`` if present and well-formed
      4. Synthesized shim pointing at the rocm-sdk Python wheel.

    An env var whose directory does not have the expected layout is skipped
    rather than returned. The result is memoized for the lifetime of the
    process, so later environment changes are not observed unless
    ``get_rocm_toolkit_path.cache_clear()`` is called.

    Returns:
        The toolkit directory as a string, or ``None`` if no toolkit can be
        located.
    """

    def _well_formed(root: Path) -> bool:
        return (root / "llvm" / "bin" / "ld.lld").exists() and (root / "amdgcn" / "bitcode").is_dir()

    for env_var in ("FLYDSL_ROCM_TOOLKIT_PATH", "ROCM_PATH"):
        val = os.environ.get(env_var, "").strip()
        if val and _well_formed(Path(val)):
            return val

    opt_rocm = Path("/opt/rocm")
    if _well_formed(opt_rocm):
        return str(opt_rocm)

    try:
        import _rocm_sdk_core  # type: ignore[import-not-found]
    except ImportError:
        return None

    sdk_root = Path(_rocm_sdk_core.__file__).parent
    llvm_dir = sdk_root / "lib" / "llvm"
    if not (llvm_dir / "bin" / "ld.lld").exists() or not (llvm_dir / "amdgcn" / "bitcode").is_dir():
        return None

    shim_root = Path(os.environ.get("FLYDSL_ROCM_TOOLKIT_SHIM_DIR") or (Path.home() / ".flydsl" / "toolkit"))
    shim_root.mkdir(parents=True, exist_ok=True)
    (shim_root / "llvm" / "bin").mkdir(parents=True, exist_ok=True)

    amdgcn_link = shim_root / "amdgcn"
    amdgcn_target = llvm_dir / "amdgcn"
    if amdgcn_link.is_symlink():
        # ``exists()`` follows symlinks, so it cannot tell a dangling link
        # from a missing one, nor notice a link left over from a different
        # SDK install. Inspect the link itself and re-point it when stale so
        # it always agrees with the ``ld.lld`` wrapper written below.
        if os.readlink(amdgcn_link) != str(amdgcn_target):
            amdgcn_link.unlink()
            amdgcn_link.symlink_to(amdgcn_target)
    elif not amdgcn_link.exists():
        amdgcn_link.symlink_to(amdgcn_target)
    # Anything else at this path (e.g. a real directory) is left untouched.

    # ``ld.lld`` in the rocm-sdk wheel is a tiny stub that needs to resolve
    # its own argv[0] to load companion libraries. Copying it elsewhere
    # breaks that lookup, so we drop a thin exec wrapper instead.
    wrapper = shim_root / "llvm" / "bin" / "ld.lld"
    wrapper_text = f'#!/bin/bash\nexec "{llvm_dir}/bin/ld.lld" "$@"\n'
    # Rewrite only when the content changed (e.g. the SDK moved).
    if not wrapper.exists() or wrapper.read_text() != wrapper_text:
        wrapper.write_text(wrapper_text)
        wrapper.chmod(0o755)
    return str(shim_root)


def _arch_from_rocm_agent_enumerator() -> Optional[str]:
"""Query rocm_agent_enumerator (standard ROCm tool) for the first GPU arch."""
try:
Expand Down
Loading