Skip to content
56 changes: 48 additions & 8 deletions kmir/src/kmir/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pyk.cli.args import KCLIArgs
from pyk.cterm.show import CTermShow
from pyk.kast.pretty import PrettyPrinter
from pyk.kdist import kdist
from pyk.proof.reachability import APRProof
from pyk.proof.show import APRProofShow
from pyk.proof.tui import APRProofViewer

from .build import HASKELL_DEF_DIR, LLVM_LIB_DIR
from .cargo import CargoProject
from .kmir import KMIR, KMIRAPRNodePrinter
from .linker import link
Expand Down Expand Up @@ -54,7 +54,14 @@ def _kmir_run(opts: RunOpts) -> None:
smir_info = cargo.smir_for_project(clean=False)

def run(target_dir: Path):
kmir = KMIR.from_kompiled_kore(smir_info, symbolic=opts.haskell_backend, target_dir=target_dir)
kmir = KMIR.from_kompiled_kore(
smir_info,
target_dir=target_dir,
symbolic=opts.symbolic,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
llvm_target=opts.llvm_target,
)
result = kmir.run_smir(smir_info, start_symbol=opts.start_symbol, depth=opts.depth)
print(kmir.kore_to_pretty(result))

Expand All @@ -73,7 +80,10 @@ def _kmir_prove_rs(opts: ProveRSOpts) -> None:


def _kmir_view(opts: ViewOpts) -> None:
kmir = KMIR(HASKELL_DEF_DIR, LLVM_LIB_DIR)
kmir = KMIR(
definition_dir=kdist.which(opts.haskell_target or 'mir-semantics.haskell'),
llvm_library_dir=kdist.which(opts.llvm_lib_target or 'mir-semantics.llvm-library'),
)
proof = APRProof.read_proof_data(opts.proof_dir, opts.id)
printer = PrettyPrinter(kmir.definition)
omit_labels = ('<currentBody>',) if opts.omit_current_body else ()
Expand Down Expand Up @@ -118,7 +128,10 @@ def _kmir_show(opts: ShowOpts) -> None:

from .kprint import KMIRPrettyPrinter

kmir = KMIR(HASKELL_DEF_DIR, LLVM_LIB_DIR)
kmir = KMIR(
definition_dir=kdist.which(opts.haskell_target or 'mir-semantics.haskell'),
llvm_library_dir=kdist.which(opts.llvm_lib_target or 'mir-semantics.llvm-library'),
)
proof = APRProof.read_proof_data(opts.proof_dir, opts.id)

# Minimize proof KCFG if requested
Expand Down Expand Up @@ -198,7 +211,14 @@ def _kmir_section_edge(opts: SectionEdgeOpts) -> None:

smir_info = SMIRInfo.from_file(target_path / 'smir.json')

kmir = KMIR.from_kompiled_kore(smir_info, symbolic=True, bug_report=opts.bug_report, target_dir=target_path)
kmir = KMIR.from_kompiled_kore(
smir_info,
target_dir=target_path,
bug_report=opts.bug_report,
symbolic=True,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
)

source_id, target_id = opts.edge
_LOGGER.info(f'Attempting to add {opts.sections} sections from node {source_id} to node {target_id}')
Expand Down Expand Up @@ -271,7 +291,10 @@ def _arg_parser() -> ArgumentParser:
run_parser.add_argument(
'--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from'
)
run_parser.add_argument('--haskell-backend', action='store_true', help='Run with the haskell backend')
run_parser.add_argument('--symbolic', action='store_true', help='Run with the symbolic backend')
run_parser.add_argument('--haskell-target', metavar='TARGET', help='Haskell target to use')
run_parser.add_argument('--llvm-lib-target', metavar='TARGET', help='LLVM lib target to use')
run_parser.add_argument('--llvm-target', metavar='TARGET', help='LLVM target to use')

info_parser = command_parser.add_parser(
'info', help='Show information about a SMIR JSON file', parents=[kcli_args.logging_args]
Expand All @@ -281,6 +304,8 @@ def _arg_parser() -> ArgumentParser:

prove_args = ArgumentParser(add_help=False)
prove_args.add_argument('--proof-dir', metavar='DIR', help='Proof directory')
prove_args.add_argument('--haskell-target', metavar='TARGET', help='Haskell target to use')
prove_args.add_argument('--llvm-lib-target', metavar='TARGET', help='LLVM lib target to use')
prove_args.add_argument('--bug-report', metavar='PATH', help='path to optional bug report')
prove_args.add_argument('--max-depth', metavar='DEPTH', type=int, help='max steps to take between nodes in kcfg')
prove_args.add_argument(
Expand Down Expand Up @@ -412,6 +437,8 @@ def _arg_parser() -> ArgumentParser:
action='store_false',
help='Display the <currentBody> cell completely.',
)
display_args.add_argument('--haskell-target', metavar='TARGET', help='Haskell target to use')
display_args.add_argument('--llvm-lib-target', metavar='TARGET', help='LLVM lib target to use')

show_parser = command_parser.add_parser(
'show', help='Show proof information', parents=[kcli_args.logging_args, proof_args, display_args]
Expand Down Expand Up @@ -482,6 +509,8 @@ def _arg_parser() -> ArgumentParser:
section_edge_parser.add_argument(
'--sections', type=int, default=2, help='Number of sections to make from edge (>= 2, default: 2)'
)
section_edge_parser.add_argument('--haskell-target', metavar='TARGET', help='Haskell target to use')
section_edge_parser.add_argument('--llvm-lib-target', metavar='TARGET', help='LLVM lib target to use')

prove_rs_parser = command_parser.add_parser(
'prove-rs', help='Prove a rust program', parents=[kcli_args.logging_args, prove_args]
Expand Down Expand Up @@ -523,7 +552,7 @@ def _parse_args(ns: Namespace) -> KMirOpts:
target_dir=ns.target_dir,
depth=ns.depth,
start_symbol=ns.start_symbol,
haskell_backend=ns.haskell_backend,
symbolic=ns.symbolic,
)
case 'info':
return InfoOpts(smir_file=Path(ns.smir_file), types=ns.types)
Expand All @@ -533,6 +562,8 @@ def _parse_args(ns: Namespace) -> KMirOpts:
id=ns.id,
full_printer=ns.full_printer,
smir_info=Path(ns.smir_info) if ns.smir_info else None,
haskell_target=ns.haskell_target,
llvm_lib_target=ns.llvm_lib_target,
omit_current_body=ns.omit_current_body,
nodes=ns.nodes,
node_deltas=ns.node_deltas,
Expand All @@ -551,6 +582,8 @@ def _parse_args(ns: Namespace) -> KMirOpts:
ns.id,
full_printer=ns.full_printer,
smir_info=ns.smir_info,
haskell_target=ns.haskell_target,
llvm_lib_target=ns.llvm_lib_target,
omit_current_body=ns.omit_current_body,
)
case 'prune':
Expand All @@ -560,7 +593,14 @@ def _parse_args(ns: Namespace) -> KMirOpts:
if ns.proof_dir is None:
raise ValueError('Must pass --proof-dir to section-edge command')
proof_dir = Path(ns.proof_dir)
return SectionEdgeOpts(proof_dir, ns.id, ns.edge, ns.sections)
return SectionEdgeOpts(
proof_dir,
ns.id,
ns.edge,
sections=ns.sections,
haskell_target=ns.haskell_target,
llvm_lib_target=ns.llvm_lib_target,
)
case 'prove-rs':
return ProveRSOpts(
rs_file=Path(ns.rs_file),
Expand Down
23 changes: 17 additions & 6 deletions kmir/src/kmir/kmir.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,25 @@ def cut_point_rules(
def from_kompiled_kore(
smir_info: SMIRInfo,
target_dir: Path,
*,
extra_module: Path | None = None,
bug_report: Path | None = None,
symbolic: bool = True,
extra_module: Path | None = None,
llvm_target: str | None = None,
llvm_lib_target: str | None = None,
haskell_target: str | None = None,
) -> KMIR:
from .kompile import kompile_smir

kompiled_smir = kompile_smir(
smir_info=smir_info,
target_dir=target_dir,
extra_module=extra_module,
bug_report=bug_report,
symbolic=symbolic,
extra_module=extra_module,
llvm_target=llvm_target,
llvm_lib_target=llvm_lib_target,
haskell_target=haskell_target,
)
return kompiled_smir.create_kmir(bug_report_file=bug_report)

Expand Down Expand Up @@ -219,10 +226,12 @@ def prove_rs(opts: ProveRSOpts) -> APRProof:
smir_info = SMIRInfo.from_file(target_path / 'smir.json')
kmir = KMIR.from_kompiled_kore(
smir_info,
symbolic=True,
bug_report=opts.bug_report,
target_dir=target_path,
extra_module=opts.add_module,
bug_report=opts.bug_report,
symbolic=True,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
)
else:
_LOGGER.info(f'Constructing initial proof: {label}')
Expand All @@ -247,10 +256,12 @@ def prove_rs(opts: ProveRSOpts) -> APRProof:

kmir = KMIR.from_kompiled_kore(
smir_info,
symbolic=True,
bug_report=opts.bug_report,
target_dir=target_path,
extra_module=opts.add_module,
bug_report=opts.bug_report,
symbolic=True,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
)

apr_proof = kmir.apr_proof_from_smir(
Expand Down
50 changes: 39 additions & 11 deletions kmir/src/kmir/kompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from pyk.kast.inner import KApply, KSort, KToken, KVariable
from pyk.kast.prelude.kint import intToken
from pyk.kast.prelude.string import stringToken
from pyk.kdist import kdist
from pyk.kore.syntax import App, EVar, SortApp, String, Symbol, SymbolDecl

from .build import HASKELL_DEF_DIR, LLVM_DEF_DIR, LLVM_LIB_DIR
from .kmir import KMIR

if TYPE_CHECKING:
Expand Down Expand Up @@ -62,6 +62,9 @@ def create_kmir(self, *, bug_report_file: Path | None = None) -> KMIR:
class KompileDigest:
digest: str
symbolic: bool
llvm_target: str
llvm_lib_target: str
haskell_target: str

@staticmethod
def load(target_dir: Path) -> KompileDigest:
Expand All @@ -74,6 +77,9 @@ def load(target_dir: Path) -> KompileDigest:
return KompileDigest(
digest=data['digest'],
symbolic=data['symbolic'],
llvm_target=data['llvm-target'],
llvm_lib_target=data['llvm-lib-target'],
haskell_target=data['haskell-target'],
)

def write(self, target_dir: Path) -> None:
Expand All @@ -82,6 +88,9 @@ def write(self, target_dir: Path) -> None:
{
'digest': self.digest,
'symbolic': self.symbolic,
'llvm-target': self.llvm_target,
'llvm-lib-target': self.llvm_lib_target,
'haskell-target': self.haskell_target,
},
),
)
Expand Down Expand Up @@ -128,21 +137,37 @@ def _load_extra_module_rules(kmir: KMIR, module_path: Path) -> list[Sentence]:
def kompile_smir(
smir_info: SMIRInfo,
target_dir: Path,
*,
bug_report: Path | None = None,
symbolic: bool = True,
extra_module: Path | None = None,
symbolic: bool = True,
llvm_target: str | None = None,
llvm_lib_target: str | None = None,
haskell_target: str | None = None,
) -> KompiledSMIR:
kompile_digest: KompileDigest | None = None
try:
kompile_digest = KompileDigest.load(target_dir)
except Exception:
pass

llvm_target = llvm_target or 'mir-semantics.llvm'
llvm_lib_target = llvm_lib_target or 'mir-semantics.llvm-library'
haskell_target = haskell_target or 'mir-semantics.haskell'

expected_digest = KompileDigest(
digest=smir_info.digest,
symbolic=symbolic,
llvm_target=llvm_target,
llvm_lib_target=llvm_lib_target,
haskell_target=haskell_target,
)

target_hs_path = target_dir / 'haskell'
target_llvm_lib_path = target_dir / 'llvm-library'
target_llvm_path = target_dir / 'llvm'

if kompile_digest is not None and kompile_digest.digest == smir_info.digest and kompile_digest.symbolic == symbolic:
if kompile_digest == expected_digest:
_LOGGER.info(f'Kompiled SMIR up-to-date, no kompilation necessary: {target_dir}')
if symbolic:
return KompiledSymbolic(haskell_dir=target_hs_path, llvm_lib_dir=target_llvm_lib_path)
Expand All @@ -151,10 +176,11 @@ def kompile_smir(

_LOGGER.info(f'Kompiling SMIR program: {target_dir}')

kompile_digest = KompileDigest(digest=smir_info.digest, symbolic=symbolic)
kompile_digest = expected_digest
target_dir.mkdir(parents=True, exist_ok=True)

kmir = KMIR(HASKELL_DEF_DIR)
haskell_def_dir = kdist.which(haskell_target)
kmir = KMIR(haskell_def_dir)
smir_rules: list[Sentence] = list(make_kore_rules(kmir, smir_info))
_LOGGER.info(f'Generated {len(smir_rules)} function equations to add to `definition.kore')

Expand All @@ -179,7 +205,8 @@ def kompile_smir(
# Process LLVM definition (only SMIR rules, not extra module rules)
# Extra module rules are configuration rewrites that LLVM backend doesn't support
_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_lib_dir = kdist.which(llvm_lib_target)
llvm_def_file = llvm_lib_dir / 'definition.kore'
llvm_def_output = target_llvm_lib_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, smir_rules, llvm_def_output)

Expand Down Expand Up @@ -209,12 +236,12 @@ def kompile_smir(

# Process Haskell definition (includes both SMIR rules and extra module rules)
_LOGGER.info('Writing Haskell definition file')
hs_def_file = HASKELL_DEF_DIR / 'definition.kore'
hs_def_file = haskell_def_dir / 'definition.kore'
_insert_rules_and_write(hs_def_file, all_rules, target_hs_path / 'definition.kore')

# Copy all files except definition.kore and binary from HASKELL_DEF_DIR to out/hs
_LOGGER.info('Copying other artefacts into HS output directory')
for file_path in HASKELL_DEF_DIR.iterdir():
for file_path in haskell_def_dir.iterdir():
if file_path.name != 'definition.kore' and file_path.name != 'haskellDefinition.bin':
if file_path.is_file():
shutil.copy2(file_path, target_hs_path / file_path.name)
Expand All @@ -231,7 +258,8 @@ def kompile_smir(

# Process LLVM definition (only SMIR rules for concrete execution)
_LOGGER.info('Writing LLVM definition file')
llvm_def_file = LLVM_LIB_DIR / 'definition.kore'
llvm_def_dir = kdist.which(llvm_target)
llvm_def_file = llvm_def_dir / 'definition.kore'
llvm_def_output = target_llvm_path / 'definition.kore'
_insert_rules_and_write(llvm_def_file, smir_rules, llvm_def_output)

Expand All @@ -256,10 +284,10 @@ def kompile_smir(
check=True,
)
blacklist = ['definition.kore', 'interpreter', 'dt']
to_copy = [file.name for file in LLVM_DEF_DIR.iterdir() if file.name not in blacklist]
to_copy = [file.name for file in llvm_def_dir.iterdir() if file.name not in blacklist]
for file in to_copy:
_LOGGER.info(f'Copying file {file}')
shutil.copy2(LLVM_DEF_DIR / file, target_llvm_path / file)
shutil.copy2(llvm_def_dir / file, target_llvm_path / file)

kompile_digest.write(target_dir)
return KompiledConcrete(llvm_dir=target_llvm_path)
Expand Down
Loading