Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 127 additions & 15 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt
from .inner import (
KApply,
KAs,
KInner,
KLabel,
KRewrite,
Expand Down Expand Up @@ -1327,6 +1328,8 @@ def sort(self, kast: KInner) -> KSort | None:
match kast:
case KToken(_, sort) | KVariable(_, sort):
return sort
case KAs(alias=KVariable(sort=sort)):
return sort
case KRewrite(lhs, rhs):
lhs_sort = self.sort(lhs)
rhs_sort = self.sort(rhs)
Expand All @@ -1336,8 +1339,11 @@ def sort(self, kast: KInner) -> KSort | None:
case KSequence(_):
return KSort('K')
case KApply(label, _):
sort, _ = self.resolve_sorts(label)
return sort
try:
sort, _ = self.resolve_sorts(label)
return sort
except (KeyError, ValueError):
return None
case _:
return None

Expand All @@ -1354,7 +1360,13 @@ def resolve_sorts(self, label: KLabel) -> tuple[KSort, tuple[KSort, ...]]:
sorts = dict(zip(prod.params, label.params, strict=True))

def resolve(sort: KSort) -> KSort:
return sorts.get(sort, sort)
# Direct match: sort IS one of the sort parameters.
if sort in sorts:
return sorts[sort]
# Recursive substitution: sort params may appear nested (e.g. MInt{Width} → MInt{8}).
if sort.params:
return KSort(sort.name, tuple(resolve(p) for p in sort.params))
return sort

return resolve(prod.sort), tuple(resolve(sort) for sort in prod.argument_sorts)

Expand Down Expand Up @@ -1483,28 +1495,108 @@ def transform(
# Best-effort addition of sort parameters to klabels, context insensitive
def add_sort_params(self, kast: KInner) -> KInner:
"""Return a given term with the sort parameters on the `KLabel` filled in (which may be missing because of how the frontend works), best effort."""
# ML predicate labels whose result sort (Sort2) is context-dependent and not inferable
# from the arguments alone. When Sort1 can be determined but Sort2 cannot, we fill Sort2
# with the sentinel KSort('#SortParam') so that downstream Kore emission can introduce a
# universally-quantified sort variable (Q0) in the axiom.
_ML_PRED_RESULT_SORT_PARAM = KSort('#SortParam') # noqa: N806
_ML_PRED_LABELS = frozenset({'#Equals', '#Ceil', '#Floor', '#In'}) # noqa: N806

Comment thread
ehildenb marked this conversation as resolved.
def _unify_sort_params(parametric: KSort, actual: KSort, params: frozenset[KSort]) -> dict[KSort, KSort]:
"""Match parametric sort against actual, extracting bindings for sort params.

Handles both direct (parametric IS a sort param) and nested
(parametric = MInt{Width}, actual = MInt{8}) cases.
Returns empty dict when no bindings could be extracted (no match).
"""
if parametric in params:
return {parametric: actual}
if parametric.name != actual.name or len(parametric.params) != len(actual.params):
return {}
result: dict[KSort, KSort] = {}
for p_sub, a_sub in zip(parametric.params, actual.params, strict=True):
sub_bindings = _unify_sort_params(p_sub, a_sub, params)
for k, v in sub_bindings.items():
if k in result and result[k] != v:
return {} # Conflicting bindings
result[k] = v
return result

def _merge_binding(sort_dict: dict[KSort, KSort], k: KSort, v: KSort) -> bool:
"""Merge one binding into sort_dict in place. Returns False on irreconcilable conflict."""
if k in sort_dict:
existing = sort_dict[k]
if existing == _ML_PRED_RESULT_SORT_PARAM:
sort_dict[k] = v # Concrete sort overrides sentinel.
elif existing != v:
lub = self.least_common_supersort(existing, v)
if lub is None:
_LOGGER.warning(f'Failed to add sort parameter, sort mismatch: {(k, existing, v)}')
return False
sort_dict[k] = lub
else:
sort_dict[k] = v
return True

def _add_sort_params(_k: KInner) -> KInner:
if type(_k) is KApply:
prod = self.symbols[_k.label.name]
if len(_k.label.params) == 0 and len(prod.params) > 0:
param_set = frozenset(prod.params)
sort_dict: dict[KSort, KSort] = {}
for psort, asort in zip(prod.argument_sorts, map(self.sort, _k.args), strict=True):
if asort == _ML_PRED_RESULT_SORT_PARAM:
# #SortParam is the sentinel for an ML pred result sort that cannot be
# inferred bottom-up (e.g. #Equals result sort depends on outer context).
# It propagates upward into ML connectives (#And, #Or, #Not) as a
# placeholder for the axiom sort variable Q0, but a concrete sort takes
# precedence when one is available.
bindings = _unify_sort_params(psort, asort, param_set)
for k, v in bindings.items():
if k not in sort_dict: # sentinel fills only empty slots
sort_dict[k] = v
continue
if asort is None:
_LOGGER.warning(
f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}'
)
return _k
if psort in prod.params:
if psort in sort_dict and sort_dict[psort] != asort:
_LOGGER.warning(
f'Failed to add sort parameter, sort mismatch between different occurances of sort parameter: {(prod, psort, sort_dict[psort], asort)}'
)
# Unify psort with asort to extract bindings for sort params.
# Handles both direct (psort=Width) and nested (psort=MInt{Width}) cases.
bindings = _unify_sort_params(psort, asort, param_set)
for k, v in bindings.items():
if not _merge_binding(sort_dict, k, v):
return _k
elif psort not in sort_dict:
sort_dict[psort] = asort
if all(p in sort_dict for p in prod.params):
return _k.let(label=KLabel(_k.label.name, [sort_dict[p] for p in prod.params]))
# ML predicates have a context-dependent result sort (Sort2) that cannot be
# inferred from arguments. Fill it with the sentinel so that krule_to_kore can
# introduce a universally-quantified sort variable for the axiom.
if _k.label.name in _ML_PRED_LABELS:
unbound = [p for p in prod.params if p not in sort_dict]
# The single sentinel KSort('#SortParam') is only unambiguous when at most
# one parameter is unresolvable bottom-up. All current ML predicates
# (#Equals, #Ceil, #Floor, #In) have exactly two sort params {Sort1,
# Sort2}: Sort1 is always determined by the arguments, Sort2 (the result
# sort) is the one remaining unbound param. If more than one param is
# unbound, the sentinel scheme must be replaced with unique fresh params
# (e.g. KSort('#SortParam', (KSort('Q0'),)), KSort('#SortParam', (KSort('Q1'),)), ...)
# analogously to how Java's AddSortInjections generates #SortParam{Q0},
# #SortParam{Q1}, etc. _ksort_to_kore would also need updating to emit
# these as sort variables rather than sort applications.
if len(unbound) > 1:
raise NotImplementedError(
f'ML predicate {_k.label.name!r} has {len(unbound)} unbound sort parameters '
f'({unbound}); the single-sentinel scheme only handles at most one. '
f'Implement unique fresh sentinels analogous to Java #SortParam{{Q0}}, '
f'#SortParam{{Q1}}, ... and update _ksort_to_kore to emit them as sort variables.'
)
filled = {p: sort_dict.get(p, _ML_PRED_RESULT_SORT_PARAM) for p in prod.params}
return _k.let(label=KLabel(_k.label.name, [filled[p] for p in prod.params]))
unbound = [p for p in prod.params if p not in sort_dict]
_LOGGER.warning(
f'Failed to add sort parameter, could not infer sort params from arguments: {(prod, unbound)}'
)
return _k

return bottom_up(_add_sort_params, kast)
Expand All @@ -1515,15 +1607,35 @@ def add_cell_map_items(self, kast: KInner) -> KInner:
# syntax AccountCellMap [cellCollection, hook(MAP.Map)]
# syntax AccountCellMap ::= AccountCellMap AccountCellMap [assoc, avoid, cellCollection, comm, element(AccountCellMapItem), function, hook(MAP.concat), unit(.AccountCellMap), wrapElement(<account>)]

cell_wrappers = {}
# Maps cell label -> (element_constructor, cell_map_sort).
# Wrapping is correct only when the parent production expects the cell MAP sort (e.g.
# EntryCellMap), not when it expects the individual cell element sort (e.g. EntryCell).
# For example, EntryCellMapKey(<entry>(...)) takes EntryCell — the <entry> must NOT be
# wrapped, whereas _EntryCellMap_(<entry>(...), ...) expects EntryCellMap — wrapping is needed.
cell_wrappers: dict[str, tuple[str, KSort]] = {}
for ccp in self.cell_collection_productions:
if Atts.ELEMENT in ccp.att and Atts.WRAP_ELEMENT in ccp.att:
cell_wrappers[ccp.att[Atts.WRAP_ELEMENT]] = ccp.att[Atts.ELEMENT]
cell_label = ccp.att[Atts.WRAP_ELEMENT]
element_ctor = ccp.att[Atts.ELEMENT]
if element_ctor in self.symbols:
cell_wrappers[cell_label] = (element_ctor, self.symbols[element_ctor].sort)

def _wrap_elements(_k: KInner) -> KInner:
if type(_k) is KApply and _k.label.name in cell_wrappers:
return KApply(cell_wrappers[_k.label.name], [_k.args[0], _k])
return _k
if not isinstance(_k, KApply) or _k.label.name not in self.symbols:
return _k
prod = self.symbols[_k.label.name]
arg_sorts = prod.argument_sorts
if not arg_sorts or len(arg_sorts) != _k.arity:
return _k
new_args: list[KInner] = list(_k.args)
changed = False
for i, (arg_sort, arg) in enumerate(zip(arg_sorts, _k.args, strict=True)):
if isinstance(arg, KApply) and arg.label.name in cell_wrappers:
element_ctor, cell_map_sort = cell_wrappers[arg.label.name]
if arg_sort == cell_map_sort:
new_args[i] = KApply(element_ctor, [arg.args[0], arg])
changed = True
return _k.let(args=new_args) if changed else _k

# To ensure we don't get duplicate wrappers.
_kast = self.remove_cell_map_items(kast)
Expand Down
Loading
Loading