Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions tropical_in_new/src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
"""Tropical tensor network tools for MPE (independent package)."""

from .contraction import build_contraction_tree, choose_order, contract_tree
from .contraction import (
build_contraction_tree,
choose_order,
contract_omeco_tree,
contract_tree,
get_omeco_tree,
)
from .mpe import mpe_tropical, recover_mpe_assignment
from .network import TensorNode, build_network
from .primitives import argmax_trace, safe_log, tropical_einsum
from .primitives import safe_log
from .tropical_einsum import (
Backpointer,
argmax_trace,
match_rule,
tropical_einsum,
tropical_reduce_max,
)
from .utils import (
Factor,
UAIModel,
Expand All @@ -14,6 +27,7 @@
)

__all__ = [
"Backpointer",
"Factor",
"TensorNode",
"UAIModel",
Expand All @@ -22,12 +36,16 @@
"build_network",
"build_tropical_factors",
"choose_order",
"contract_omeco_tree",
"contract_tree",
"get_omeco_tree",
"match_rule",
"mpe_tropical",
"read_evidence_file",
"read_model_file",
"read_model_from_string",
"recover_mpe_assignment",
"safe_log",
"tropical_einsum",
"tropical_reduce_max",
]
195 changes: 148 additions & 47 deletions tropical_in_new/src/contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import omeco

from .network import TensorNode
from .primitives import Backpointer, tropical_reduce_max
from .utils import build_index_map
from .tropical_einsum import tropical_einsum, tropical_reduce_max, Backpointer


@dataclass
Expand All @@ -36,12 +35,6 @@ class ReduceNode:
TreeNode = TensorNode | ContractNode | ReduceNode


@dataclass(frozen=True)
class ContractionTree:
order: Tuple[int, ...]
nodes: Tuple[TensorNode, ...]


def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
sizes: dict[int, int] = {}
for node in nodes:
Expand All @@ -54,16 +47,128 @@ def _infer_var_sizes(nodes: Iterable[TensorNode]) -> dict[int, int]:
return sizes


def _extract_leaf_index(node_dict: dict) -> int | None:
for key in ("leaf", "leaf_index", "index", "tensor"):
if key in node_dict:
value = node_dict[key]
if isinstance(value, int):
return value
return None
def get_omeco_tree(nodes: list[TensorNode]) -> dict:
"""Get the optimized contraction tree from omeco.

Args:
nodes: List of tensor nodes to contract.

Returns:
The omeco tree as a dictionary with structure:
- Leaf: {"tensor_index": int}
- Node: {"args": [...], "eins": {"ixs": [[...], ...], "iy": [...]}}
"""
ixs = [list(node.vars) for node in nodes]
sizes = _infer_var_sizes(nodes)
method = omeco.GreedyMethod()
tree = omeco.optimize_code(ixs, [], sizes, method)
return tree.to_dict()


def contract_omeco_tree(
tree_dict: dict,
nodes: list[TensorNode],
track_argmax: bool = True,
) -> TreeNode:
"""Contract tensors following omeco's optimized tree structure.

Uses tropical-gemm for accelerated binary contractions when available.

Args:
tree_dict: The omeco tree dictionary from get_omeco_tree().
nodes: List of input tensor nodes.
track_argmax: Whether to track argmax for MPE backtracing.

Returns:
Root TreeNode with contracted result and backpointers.
"""

def recurse(node: dict) -> TreeNode:
# Leaf node - return the input tensor
if "tensor_index" in node:
return nodes[node["tensor_index"]]

# Internal node - contract children
args = node["args"]
eins = node["eins"]
out_vars = tuple(eins["iy"])

# Recursively contract children
children = [recurse(arg) for arg in args]

# Use tropical_einsum for the contraction
tensors = [c.values for c in children]
child_ixs = [c.vars for c in children]

values, backpointer = tropical_einsum(
tensors, list(child_ixs), out_vars, track_argmax=track_argmax
)

# Build result node (for binary, use ContractNode)
if len(children) == 2:
all_input = set(children[0].vars) | set(children[1].vars)
elim_vars = tuple(v for v in all_input if v not in out_vars)

return ContractNode(
vars=out_vars,
values=values,
left=children[0],
right=children[1],
elim_vars=elim_vars,
backpointer=backpointer,
)
else:
# For n-ary, chain as binary
result = children[0]
for i, child in enumerate(children[1:], 1):
is_final = (i == len(children) - 1)
target_out = out_vars if is_final else tuple(dict.fromkeys(result.vars + child.vars))

step_tensors = [result.values, child.values]
step_ixs = [result.vars, child.vars]

step_values, step_bp = tropical_einsum(
step_tensors, list(step_ixs), target_out, track_argmax=track_argmax
)

all_input = set(result.vars) | set(child.vars)
elim_vars = tuple(v for v in all_input if v not in target_out)

result = ContractNode(
vars=target_out,
values=step_values,
left=result,
right=child,
elim_vars=elim_vars,
backpointer=step_bp,
)
return result

return recurse(tree_dict)


# =============================================================================
# Legacy API for backward compatibility
# =============================================================================

@dataclass(frozen=True)
class ContractionTree:
"""Legacy contraction tree structure."""
order: Tuple[int, ...]
nodes: Tuple[TensorNode, ...]


def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]:
"""Legacy: Select elimination order. Use get_omeco_tree() instead."""
if heuristic != "omeco":
raise ValueError("Only the 'omeco' heuristic is supported.")
tree_dict = get_omeco_tree(nodes)
ixs = [list(node.vars) for node in nodes]
return _elim_order_from_tree_dict(tree_dict, ixs)


def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[int]:
"""Extract elimination order from omeco tree (legacy support)."""
total_counts: dict[int, int] = {}
for vars in ixs:
for var in vars:
Expand All @@ -72,14 +177,13 @@ def _elim_order_from_tree_dict(tree_dict: dict, ixs: list[list[int]]) -> list[in
eliminated: set[int] = set()

def visit(node: dict) -> tuple[dict[int, int], list[int]]:
leaf_index = _extract_leaf_index(node)
if leaf_index is not None:
if "tensor_index" in node:
counts: dict[int, int] = {}
for var in ixs[leaf_index]:
for var in ixs[node["tensor_index"]]:
counts[var] = counts.get(var, 0) + 1
return counts, []

children = node.get("children", [])
children = node.get("args") or node.get("children", [])
if not isinstance(children, list) or not children:
return {}, []

Expand All @@ -106,59 +210,49 @@ def visit(node: dict) -> tuple[dict[int, int], list[int]]:
return order + remaining


def choose_order(nodes: list[TensorNode], heuristic: str = "omeco") -> list[int]:
"""Select elimination order over variable indices using omeco."""
if heuristic != "omeco":
raise ValueError("Only the 'omeco' heuristic is supported.")
ixs = [list(node.vars) for node in nodes]
sizes = _infer_var_sizes(nodes)
method = omeco.GreedyMethod() if hasattr(omeco, "GreedyMethod") else None
tree = (
omeco.optimize_code(ixs, [], sizes, method)
if method is not None
else omeco.optimize_code(ixs, [], sizes)
)
tree_dict = tree.to_dict() if hasattr(tree, "to_dict") else tree
if not isinstance(tree_dict, dict):
raise ValueError("omeco.optimize_code did not return a usable tree.")
return _elim_order_from_tree_dict(tree_dict, ixs)


def build_contraction_tree(order: Iterable[int], nodes: list[TensorNode]) -> ContractionTree:
"""Prepare a contraction plan from order and leaf nodes."""
"""Legacy: Prepare a contraction plan from order and leaf nodes."""
return ContractionTree(order=tuple(order), nodes=tuple(nodes))


def contract_tree(
tree: ContractionTree,
einsum_fn,
einsum_fn=None,
track_argmax: bool = True,
) -> TreeNode:
"""Contract along the tree using the tropical einsum."""
"""Legacy: Contract using elimination order. Use contract_omeco_tree() instead."""
active_nodes: list[TreeNode] = list(tree.nodes)

for var in tree.order:
bucket = [node for node in active_nodes if var in node.vars]
if not bucket:
continue
bucket_ids = {id(node) for node in bucket}
active_nodes = [node for node in active_nodes if id(node) not in bucket_ids]

combined: TreeNode = bucket[0]
for i, other in enumerate(bucket[1:]):
is_last = i == len(bucket) - 2
elim_vars = (var,) if is_last else ()
index_map = build_index_map(combined.vars, other.vars, elim_vars=elim_vars)
values, backpointer = einsum_fn(
combined.values, other.values, index_map,

# Use tropical_einsum
target_out = tuple(v for v in dict.fromkeys(combined.vars + other.vars) if v not in elim_vars)
values, backpointer = tropical_einsum(
[combined.values, other.values],
[combined.vars, other.vars],
target_out,
track_argmax=track_argmax if is_last else False,
)

combined = ContractNode(
vars=index_map.out_vars,
vars=target_out,
values=values,
left=combined,
right=other,
elim_vars=elim_vars,
backpointer=backpointer,
)

if var in combined.vars:
# Single-node bucket: eliminate via reduce
values, backpointer = tropical_reduce_max(
Expand All @@ -172,20 +266,27 @@ def contract_tree(
backpointer=backpointer,
)
active_nodes.append(combined)

while len(active_nodes) > 1:
left = active_nodes.pop(0)
right = active_nodes.pop(0)
index_map = build_index_map(left.vars, right.vars, elim_vars=())
values, _ = einsum_fn(left.values, right.values, index_map, track_argmax=False)
target_out = tuple(dict.fromkeys(left.vars + right.vars))
values, _ = tropical_einsum(
[left.values, right.values],
[left.vars, right.vars],
target_out,
track_argmax=False,
)
combined = ContractNode(
vars=index_map.out_vars,
vars=target_out,
values=values,
left=left,
right=right,
elim_vars=(),
backpointer=None,
)
active_nodes.append(combined)

if not active_nodes:
raise ValueError("Contraction produced no nodes.")
return active_nodes[0]
31 changes: 20 additions & 11 deletions tropical_in_new/src/mpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

from typing import Dict, Iterable

from .contraction import ContractNode, ReduceNode, build_contraction_tree, choose_order
from .contraction import contract_tree as _contract_tree
from .contraction import (
ContractNode,
ReduceNode,
contract_omeco_tree,
get_omeco_tree,
)
from .network import TensorNode, build_network
from .primitives import argmax_trace, tropical_einsum, tropical_reduce_max
from .tropical_einsum import argmax_trace, tropical_reduce_max
from .utils import UAIModel, build_tropical_factors


Expand Down Expand Up @@ -70,16 +74,22 @@ def traverse(node, out_assignment: Dict[int, int]) -> None:
def mpe_tropical(
model: UAIModel,
evidence: Dict[int, int] | None = None,
order: Iterable[int] | None = None,
) -> tuple[Dict[int, int], float, Dict[str, int | tuple[int, ...]]]:
"""Return MPE assignment, score, and contraction metadata."""
"""Return MPE assignment, score, and contraction metadata.

Uses omeco for optimized contraction order and tropical-gemm for acceleration.
"""
evidence = evidence or {}
factors = build_tropical_factors(model, evidence)
nodes = build_network(factors)
if order is None:
order = choose_order(nodes, heuristic="omeco")
tree = build_contraction_tree(order, nodes)
root = _contract_tree(tree, einsum_fn=tropical_einsum)

# Get optimized contraction tree from omeco
tree_dict = get_omeco_tree(nodes)

# Contract using the optimized tree
root = contract_omeco_tree(tree_dict, nodes, track_argmax=True)

# Final reduction if there are remaining variables
if root.vars:
values, backpointer = tropical_reduce_max(
root.values, root.vars, tuple(root.vars), track_argmax=True
Expand All @@ -91,12 +101,11 @@ def mpe_tropical(
elim_vars=tuple(root.vars),
backpointer=backpointer,
)

assignment = recover_mpe_assignment(root)
assignment.update({int(k): int(v) for k, v in evidence.items()})
score = float(root.values.item())
info = {
"order": tuple(order),
"num_nodes": len(nodes),
"num_elims": len(tuple(order)),
}
return assignment, score, info
Loading