-
Notifications
You must be signed in to change notification settings - Fork 0
feat: add tropical einsum module with OMEinsum-style design and tropical-gemm acceleration #65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…cal-gemm acceleration This PR addresses issue #63 by: - Creating a new tropical_einsum module following OMEinsum design patterns - Implementing rule-based dispatch for unary and binary contractions - Integrating tropical-gemm for accelerated maxplus matrix multiplication - Using omeco tree structure directly for optimized contraction order Key changes: - New tropical_einsum.py with Identity, TropicalSum, Permutedims, Diag, Tr rules - Binary contractions use tropical_gemm.maxplus_matmul_with_argmax_f64() - contract_omeco_tree() executes contractions following omeco's optimized tree - Simplified mpe_tropical() API (removed manual order parameter) - 40 new tests for tropical einsum, 78 total tests passing Fixes #63 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add the classic Asia (lung cancer) Bayesian network benchmark and tests that verify our MPE implementation against brute-force enumeration. - Asia network: 8 binary variables, standard benchmark from Lauritzen & Spiegelhalter (1988) - Brute-force verification ensures correctness - Tests with and without evidence 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #65 +/- ##
=======================================
Coverage 95.18% 95.18%
=======================================
Files 10 10
Lines 747 747
=======================================
Hits 711 711
Misses 36 36
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR refactors the tropical tensor contraction system to address issue #63 by adopting OMEinsum-style design patterns, integrating tropical-gemm acceleration, and using omeco's tree structure directly for optimized contraction order.
Changes:
- Introduces new
tropical_einsum.pymodule with rule-based dispatch for unary and binary tropical contractions - Integrates tropical-gemm library for accelerated maxplus matrix multiplication with argmax tracking
- Updates
contraction.pywith newget_omeco_tree()andcontract_omeco_tree()functions that work directly with omeco's tree structure - Simplifies
mpe.pyAPI by removing manualorderparameter in favor of automatic omeco optimization - Adds 45+ new tests including comprehensive tropical_einsum tests and Asia Bayesian network benchmark tests
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| tropical_in_new/src/tropical_einsum.py | New 550-line module implementing rule-based tropical einsum with tropical-gemm integration |
| tropical_in_new/src/contraction.py | Adds omeco tree extraction and contraction functions; legacy API preserved for backward compatibility |
| tropical_in_new/src/primitives.py | Adds unused tropical_contract_binary() function; original functions retained |
| tropical_in_new/src/mpe.py | Simplifies API to use omeco directly; removes manual order parameter |
| tropical_in_new/src/init.py | Updates exports to include new tropical_einsum functions |
| tropical_in_new/tests/test_tropical_einsum.py | 425 lines of new tests for unary/binary rules, argmax tracing, and tropical-gemm integration |
| tropical_in_new/tests/test_benchmarks.py | 193 lines of new benchmark tests using Asia Bayesian network |
| tropical_in_new/tests/test_mpe.py | Updates tests to use new automatic omeco-optimized API |
| tropical_in_new/tests/test_contraction.py | Adds tests for new omeco tree functions |
| tropical_in_new/tests/benchmarks/asia.uai | New benchmark data file for Asia Bayesian network (44 lines) |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Benchmark and Verification ResultsMethodology1. Asia Network Benchmark (Classic Bayesian Network)
2. Verification Approach
ResultsTest CoverageVerification Scriptimport torch, itertools
from tropical_in_new.src.mpe import mpe_tropical
from tropical_in_new.src.utils import read_model_file
model = read_model_file('tropical_in_new/tests/benchmarks/asia.uai', factor_eltype=torch.float64)
assignment, score, _ = mpe_tropical(model)
# Brute force verification
best_score, best_assignment = float('-inf'), None
for combo in itertools.product(*(range(c) for c in model.cards)):
log_score = sum(torch.log(f.values[tuple(combo[v-1] for v in f.vars)]).item()
for f in model.factors)
if log_score > best_score:
best_score, best_assignment = log_score, {i+1: combo[i] for i in range(model.nvars)}
assert assignment == best_assignment # ✅ Exact match
assert abs(score - best_score) < 1e-6 # ✅ Score within tolerance |
Updated Benchmark Results with UAI Competition DataAdded 10 benchmark files from UAI 2008/2014 competitions to Benchmark Results
Known Issue FoundDisconnected graphs: When a graphical model has disconnected components (isolated variables or separate subgraphs), the current implementation only returns assignments for variables that are part of the main connected component. Example: This is a limitation to address in a future PR. For now, the implementation works correctly for connected graphs. Files Added
Source: UAI Inference Competition, |
Add 10 benchmark files from UAI 2008/2014 inference competitions: - Protein folding: pdb1akg, pdb1etl, pdb1etn - Grid models: Grids_11, Grids_12, grid10x10.f10 - SAT: 2bitcomp_5.cnf, sat-grid-pbl-0010.cnf - CSP: CSP_12 - Pedigree: Pedigree_11 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove unused imports (Enum, auto) from tropical_einsum.py - Add validation to tropical_reduce_max for invalid elim_vars - Add validation to argmax_trace for missing assignment values - Remove unused variable (ixs) in contraction.py - Remove dead code (tropical_contract_binary) from primitives.py - Remove unused imports from test_tropical_einsum.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add n-ary contraction support in DefaultRule (chains as binary ops) - Consolidate Backpointer: import from primitives.py instead of duplicate - Add Diag test coverage (test_diag_extraction, test_diag_with_extra_dim) - Improve test_contract_omeco_tree_matches_legacy to verify correctness - Add performance comment about sequential batched GEMM processing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Users should use contraction order optimization (omeco) to decompose n-ary contractions into binary contractions first. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary
This PR addresses issue #63 by refactoring the tropical tensor contraction system to:
Changes
New Module:
tropical_einsum.pyIdentity,TropicalSum,Permutedims,Diag,Tr,SimpleBinaryRule,DefaultRulematch_rule()dispatches to optimized rules based on contraction patterntropical_gemm.maxplus_matmul_with_argmax_f64()for accelerationUpdated Files
contraction.py: Newget_omeco_tree()andcontract_omeco_tree()functionsmpe.py: Simplified API using omeco tree directly (removed manualorderparameter)primitives.py: Addedtropical_contract_binary()functionTest Coverage
tropical_einsum(unary/binary rules, rule matching, argmax tracing)Test plan
Fixes #63
🤖 Generated with Claude Code