Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions lean_py/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Data model
# ============================================================================


@dataclass
class StructField:
name: str
Expand Down Expand Up @@ -64,6 +65,7 @@ class HeaderModel:
# Header location
# ============================================================================


def find_lean_header() -> Path:
"""Locate lean.h via the active Lean toolchain.

Expand All @@ -76,9 +78,7 @@ def find_lean_header() -> Path:
`lean-toolchain` file, for callers who don't have `lean` on PATH.
"""
try:
prefix = subprocess.check_output(
["lean", "--print-prefix"], text=True
).strip()
prefix = subprocess.check_output(["lean", "--print-prefix"], text=True).strip()
header = Path(prefix) / "include" / "lean" / "lean.h"
if header.exists():
return header
Expand All @@ -89,7 +89,15 @@ def find_lean_header() -> Path:
if toolchain_file.exists():
toolchain = toolchain_file.read_text().strip()
toolchain_dir = toolchain.replace("/", "--").replace(":", "---")
header = Path.home() / ".elan" / "toolchains" / toolchain_dir / "include" / "lean" / "lean.h"
header = (
Path.home()
/ ".elan"
/ "toolchains"
/ toolchain_dir
/ "include"
/ "lean"
/ "lean.h"
)
if header.exists():
return header

Expand All @@ -102,6 +110,7 @@ def find_lean_header() -> Path:
# Extraction
# ============================================================================


def extract_defines(header_path: Path) -> dict[str, int]:
"""Extract integer #define constants from raw header."""
text = header_path.read_text()
Expand All @@ -122,7 +131,8 @@ def _preprocess(header_path: Path) -> Path:
# -D flags strip GCC/Clang constructs that pycparser can't handle
result = subprocess.run(
[
"cc", "-E",
"cc",
"-E",
"-D__STDC_VERSION__=201112L",
"-DNDEBUG",
"-D__attribute__(x)=",
Expand All @@ -143,7 +153,8 @@ def _preprocess(header_path: Path) -> Path:
f"-I{include_dir}",
str(header_path),
],
capture_output=True, text=True,
capture_output=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"C preprocessor failed: {result.stderr}")
Expand Down Expand Up @@ -182,6 +193,7 @@ def _preprocess(header_path: Path) -> Path:
# AST helpers
# ============================================================================


def _type_to_str(node) -> str:
if node is None:
return "void"
Expand Down Expand Up @@ -223,7 +235,9 @@ def _extract_struct(node: c_ast.Struct) -> StructDef | None:
bitfield = None
if decl.bitsize and isinstance(decl.bitsize, c_ast.Constant):
bitfield = int(decl.bitsize.value)
fields.append(StructField(name=fname, c_type=ftype, is_pointer=is_ptr, bitfield=bitfield))
fields.append(
StructField(name=fname, c_type=ftype, is_pointer=is_ptr, bitfield=bitfield)
)
return StructDef(name=node.name or "", fields=fields)


Expand Down Expand Up @@ -258,15 +272,20 @@ def _extract_export_names(header_path: Path) -> set[str]:

def _extract_inline_names(header_path: Path) -> set[str]:
text = header_path.read_text()
pattern = re.compile(r"static\s+inline\s+(?:LEAN_ALWAYS_INLINE\s+)?[\w\s*]+\s+(\w+)\s*\(")
pattern = re.compile(
r"static\s+inline\s+(?:LEAN_ALWAYS_INLINE\s+)?[\w\s*]+\s+(\w+)\s*\("
)
return {m.group(1) for m in pattern.finditer(text)}


# ============================================================================
# Classification
# ============================================================================

def _classify(ast: c_ast.FileAST, defines: dict[str, int], header_path: Path) -> HeaderModel:

def _classify(
ast: c_ast.FileAST, defines: dict[str, int], header_path: Path
) -> HeaderModel:
model = HeaderModel()
model.constants = defines

Expand All @@ -275,22 +294,31 @@ def _classify(ast: c_ast.FileAST, defines: dict[str, int], header_path: Path) ->

for node in ast.ext:
if isinstance(node, c_ast.Typedef):
if isinstance(node.type, c_ast.TypeDecl) and isinstance(node.type.type, c_ast.Struct):
if isinstance(node.type, c_ast.TypeDecl) and isinstance(
node.type.type, c_ast.Struct
):
struct = _extract_struct(node.type.type)
if struct:
struct.name = node.name
model.structs.append(struct)
continue
typedef_type = _decl_type_to_str(node.type)
model.typedefs.append(TypedefDef(name=node.name, underlying_type=typedef_type))
model.typedefs.append(
TypedefDef(name=node.name, underlying_type=typedef_type)
)

elif isinstance(node, c_ast.Decl):
if isinstance(node.type, c_ast.FuncDecl):
func_decl = node.type
ret_type = _decl_type_to_str(func_decl.type)
params, is_variadic = _extract_func_params(func_decl)
fname = node.name or ""
func = FuncDecl(name=fname, return_type=ret_type, params=params, is_variadic=is_variadic)
func = FuncDecl(
name=fname,
return_type=ret_type,
params=params,
is_variadic=is_variadic,
)
if fname in export_names:
model.exported_functions.append(func)
elif fname in inline_names:
Expand All @@ -303,7 +331,12 @@ def _classify(ast: c_ast.FileAST, defines: dict[str, int], header_path: Path) ->
ret_type = _decl_type_to_str(func_decl.type)
params, is_variadic = _extract_func_params(func_decl)
fname = decl.name or ""
func = FuncDecl(name=fname, return_type=ret_type, params=params, is_variadic=is_variadic)
func = FuncDecl(
name=fname,
return_type=ret_type,
params=params,
is_variadic=is_variadic,
)
if fname in inline_names:
model.inline_functions.append(func)
elif fname in export_names:
Expand Down
91 changes: 68 additions & 23 deletions lean_py/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,31 @@
import ctypes
import functools
from ctypes import (
POINTER, Structure, c_bool, c_char, c_char_p, c_double, c_float,
c_int, c_int8, c_int16, c_int32, c_int64, c_long, c_size_t,
c_ssize_t, c_uint, c_uint8, c_uint16, c_uint32, c_uint64, c_void_p,
POINTER,
Structure,
c_bool,
c_char,
c_char_p,
c_double,
c_float,
c_int,
c_int8,
c_int16,
c_int32,
c_int64,
c_long,
c_size_t,
c_ssize_t,
c_uint,
c_uint8,
c_uint16,
c_uint32,
c_uint64,
c_void_p,
)
from typing import Any

from lean_py._parse import HeaderModel, StructDef, FuncDecl, get_header_model
from lean_py._parse import HeaderModel, StructDef, get_header_model
from lean_py.utils import all_lean_runtime_libs, find_lean_dynlib


Expand Down Expand Up @@ -82,17 +100,27 @@ def _resolve_type(c_type: str, structs: dict[str, type]) -> Any:
return _TYPE_MAP[c_type]

# lean_object pointer types
if c_type in ("lean_object *", "lean_obj_arg", "b_lean_obj_arg",
"u_lean_obj_arg", "lean_obj_res", "b_lean_obj_res"):
if c_type in (
"lean_object *",
"lean_obj_arg",
"b_lean_obj_arg",
"u_lean_obj_arg",
"lean_obj_res",
"b_lean_obj_res",
):
return structs.get("_LeanObjectPtr", c_void_p)

if c_type == "lean_object * *":
obj_ptr = structs.get("_LeanObjectPtr", c_void_p)
return POINTER(obj_ptr)

# Known opaque pointer types
if c_type in ("lean_external_class *", "lean_external_finalize_proc",
"lean_external_foreach_proc", "lean_task_imp *"):
if c_type in (
"lean_external_class *",
"lean_external_finalize_proc",
"lean_external_foreach_proc",
"lean_task_imp *",
):
return c_void_p

# Pointer types
Expand Down Expand Up @@ -121,6 +149,7 @@ def _resolve_type(c_type: str, structs: dict[str, type]) -> Any:
# Dynamic struct creation
# ============================================================================


def _build_structs(model: HeaderModel) -> dict[str, type]:
"""Dynamically create ctypes Structure classes from the model."""
structs: dict[str, type] = {}
Expand Down Expand Up @@ -168,9 +197,9 @@ def _make_struct(sdef: StructDef, known: dict[str, type]) -> type:
# Dynamic FFI class creation
# ============================================================================


def _build_ffi_class(model: HeaderModel, structs: dict[str, type]) -> type:
"""Dynamically create the LeanFFI class with all bindings."""
LeanObjectPtr = structs["_LeanObjectPtr"]
constants = model.constants

def __init__(self):
Expand Down Expand Up @@ -216,8 +245,9 @@ def __init__(self):
# init code can run with the flag still true (which is required
# by Lean's `initialize` blocks), then flip it after.
try:
self.lean_io_mark_end_initialization = \
self.lean_io_mark_end_initialization = (
self.lib.lean_io_mark_end_initialization
)
self.lean_io_mark_end_initialization.argtypes = []
self.lean_io_mark_end_initialization.restype = None
except AttributeError:
Expand Down Expand Up @@ -261,7 +291,9 @@ def _bind_exported(self, lib):
try:
cfunc = getattr(lib, func.name)
if func.params:
cfunc.argtypes = [_resolve_type(p.c_type, structs) for p in func.params]
cfunc.argtypes = [
_resolve_type(p.c_type, structs) for p in func.params
]
restype = _resolve_type(func.return_type, structs)
if restype is not None:
cfunc.restype = restype
Expand Down Expand Up @@ -446,7 +478,9 @@ def lean_ctor_get(self, o, i):
# pointer that aliases into the Lean ctor's m_objs memory. If the
# ctor is later freed (lean_dec), an aliased pointer would become
# stale when the Lean allocator reuses the memory.
elem_addr = ctypes.addressof(ctor.contents) + offset + i * ctypes.sizeof(LeanObjectPtr)
elem_addr = (
ctypes.addressof(ctor.contents) + offset + i * ctypes.sizeof(LeanObjectPtr)
)
raw_val = c_void_p.from_address(elem_addr).value or 0
return ctypes.cast(c_void_p(raw_val), LeanObjectPtr)

Expand Down Expand Up @@ -538,7 +572,9 @@ def lean_alloc_ctor(self, tag, num_objs, scalar_sz):
def lean_alloc_array(self, size, capacity):
fn = self._find_leanpy_helper("leanpy_alloc_array")
if fn is None:
raise RuntimeError("leanpy_alloc_array not found — leanpy_native not linked")
raise RuntimeError(
"leanpy_alloc_array not found — leanpy_native not linked"
)
fn.argtypes = [c_size_t, c_size_t]
fn.restype = LeanObjectPtr
return fn(size, capacity)
Expand All @@ -557,15 +593,17 @@ def lean_box_uint64(self, v):
# Use the exported lean_box_uint64 if present.
fn = getattr(self.lib, "lean_box_uint64", None)
if fn is not None:
fn.argtypes = [c_uint64]; fn.restype = LeanObjectPtr
fn.argtypes = [c_uint64]
fn.restype = LeanObjectPtr
return fn(v)
# Fallback: scalar tagged pointer (only valid for small values).
return self.lean_box(v)

def lean_unbox_uint64(self, o):
fn = getattr(self.lib, "lean_unbox_uint64", None)
if fn is not None:
fn.argtypes = [LeanObjectPtr]; fn.restype = c_uint64
fn.argtypes = [LeanObjectPtr]
fn.restype = c_uint64
return fn(o)
return self.lean_unbox(o)

Expand All @@ -591,7 +629,8 @@ def lean_unsigned_to_nat(self, n):
fn = getattr(self.lib, "lean_unsigned_to_nat", None)
if fn is None:
return self.lean_box(n)
fn.argtypes = [c_uint]; fn.restype = LeanObjectPtr
fn.argtypes = [c_uint]
fn.restype = LeanObjectPtr
return fn(n)

def lean_uint64_to_nat(self, n):
Expand All @@ -602,9 +641,12 @@ def lean_uint64_to_nat(self, n):
return self.lean_box(n)
big = getattr(self.lib, "lean_big_uint64_to_nat", None)
if big is not None:
big.argtypes = [c_uint64]; big.restype = LeanObjectPtr
big.argtypes = [c_uint64]
big.restype = LeanObjectPtr
return big(c_uint64(n).value)
raise RuntimeError(f"Cannot convert uint64 {n} to Nat: lean_big_uint64_to_nat not found")
raise RuntimeError(
f"Cannot convert uint64 {n} to Nat: lean_big_uint64_to_nat not found"
)

def lean_uint64_of_nat(self, p):
# Inline: small scalar fast-path; large path via lean_uint64_of_big_nat.
Expand All @@ -613,29 +655,33 @@ def lean_uint64_of_nat(self, p):
fn = getattr(self.lib, "lean_uint64_of_big_nat", None)
if fn is None:
raise RuntimeError("lean_uint64_of_big_nat not found and Nat is not scalar")
fn.argtypes = [LeanObjectPtr]; fn.restype = c_uint64
fn.argtypes = [LeanObjectPtr]
fn.restype = c_uint64
return int(fn(p))

def lean_int64_to_int(self, n):
# Prefer the C-side helper from leanpy_native, which delegates
# to the static-inline `lean_int64_to_int` exactly.
fn = self._find_leanpy_helper("leanpy_int64_to_int")
if fn is not None:
fn.argtypes = [c_int64]; fn.restype = LeanObjectPtr
fn.argtypes = [c_int64]
fn.restype = LeanObjectPtr
return fn(n)
# Fallback (mirrors `lean.h`): encode int32-range as a scalar.
if -(1 << 31) <= n <= (1 << 31) - 1:
return self.lean_box(n & 0xFFFFFFFF)
big = getattr(self.lib, "lean_big_int64_to_int", None)
if big is None:
return self.lean_box(n & ((1 << 63) - 1))
big.argtypes = [c_int64]; big.restype = LeanObjectPtr
big.argtypes = [c_int64]
big.restype = LeanObjectPtr
return big(n)

def lean_int64_of_int(self, p):
fn = self._find_leanpy_helper("leanpy_int64_of_int")
if fn is not None:
fn.argtypes = [LeanObjectPtr]; fn.restype = c_int64
fn.argtypes = [LeanObjectPtr]
fn.restype = c_int64
return int(fn(p))
if self.lean_is_scalar(p):
return int(self.lean_scalar_to_int(p))
Expand All @@ -658,7 +704,6 @@ def lean_scalar_to_int(self, p):

def _add_helper_methods(class_dict: dict, structs: dict):
"""Add convenience helper methods."""
LeanObjectPtr = structs["_LeanObjectPtr"]

def mk_string(self, s):
"""Create a Lean string from a Python string."""
Expand Down
2 changes: 1 addition & 1 deletion lean_py/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from ctypes import POINTER, _Pointer
from ctypes import _Pointer

from lean_py._runtime import get_structs, get_constants

Expand Down
Loading
Loading