Skip to content

Conversation

@GiggleLiu
Copy link
Member

Summary

This PR addresses issue #63 by refactoring the tropical tensor contraction system to:

  • Follow OMEinsum design patterns with rule-based dispatch for unary and binary contractions
  • Integrate tropical-gemm for accelerated maxplus matrix multiplication
  • Use omeco tree structure directly for optimized contraction order (no more elimination order conversion)

Changes

New Module: tropical_einsum.py

  • Rule Types: Identity, TropicalSum, Permutedims, Diag, Tr, SimpleBinaryRule, DefaultRule
  • Rule Matching: match_rule() dispatches to optimized rules based on contraction pattern
  • tropical-gemm Integration: Binary contractions use tropical_gemm.maxplus_matmul_with_argmax_f64() for acceleration
  • Backpointer Tracking: Full argmax tracking for MPE recovery

Updated Files

  • contraction.py: New get_omeco_tree() and contract_omeco_tree() functions
  • mpe.py: Simplified API using omeco tree directly (removed manual order parameter)
  • primitives.py: Added tropical_contract_binary() function

Test Coverage

  • 40 new tests for tropical_einsum (unary/binary rules, rule matching, argmax tracing)
  • 5 benchmark tests using the classic Asia Bayesian network
  • 83 total tests passing

Test plan

  • All existing tests pass
  • New tropical_einsum tests cover unary and binary rules
  • MPE verified against brute-force enumeration on Asia benchmark
  • tropical-gemm integration tested

Fixes #63

🤖 Generated with Claude Code

GiggleLiu and others added 2 commits January 25, 2026 17:52
…cal-gemm acceleration

This PR addresses issue #63 by:
- Creating a new tropical_einsum module following OMEinsum design patterns
- Implementing rule-based dispatch for unary and binary contractions
- Integrating tropical-gemm for accelerated maxplus matrix multiplication
- Using omeco tree structure directly for optimized contraction order

Key changes:
- New tropical_einsum.py with Identity, TropicalSum, Permutedims, Diag, Tr rules
- Binary contractions use tropical_gemm.maxplus_matmul_with_argmax_f64()
- contract_omeco_tree() executes contractions following omeco's optimized tree
- Simplified mpe_tropical() API (removed manual order parameter)
- 40 new tests for tropical einsum, 78 total tests passing

Fixes #63

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add the classic Asia (lung cancer) Bayesian network benchmark and tests
that verify our MPE implementation against brute-force enumeration.

- Asia network: 8 binary variables, standard benchmark from Lauritzen & Spiegelhalter (1988)
- Brute-force verification ensures correctness
- Tests with and without evidence

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@codecov
Copy link

codecov bot commented Jan 25, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 95.18%. Comparing base (13cb07b) to head (0063a62).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #65   +/-   ##
=======================================
  Coverage   95.18%   95.18%           
=======================================
  Files          10       10           
  Lines         747      747           
=======================================
  Hits          711      711           
  Misses         36       36           
Flag Coverage Δ
unittests 95.18% <ø> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the tropical tensor contraction system to address issue #63 by adopting OMEinsum-style design patterns, integrating tropical-gemm acceleration, and using omeco's tree structure directly for optimized contraction order.

Changes:

  • Introduces new tropical_einsum.py module with rule-based dispatch for unary and binary tropical contractions
  • Integrates tropical-gemm library for accelerated maxplus matrix multiplication with argmax tracking
  • Updates contraction.py with new get_omeco_tree() and contract_omeco_tree() functions that work directly with omeco's tree structure
  • Simplifies mpe.py API by removing manual order parameter in favor of automatic omeco optimization
  • Adds 45+ new tests including comprehensive tropical_einsum tests and Asia Bayesian network benchmark tests

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 15 comments.

Show a summary per file
File Description
tropical_in_new/src/tropical_einsum.py New 550-line module implementing rule-based tropical einsum with tropical-gemm integration
tropical_in_new/src/contraction.py Adds omeco tree extraction and contraction functions; legacy API preserved for backward compatibility
tropical_in_new/src/primitives.py Adds unused tropical_contract_binary() function; original functions retained
tropical_in_new/src/mpe.py Simplifies API to use omeco directly; removes manual order parameter
tropical_in_new/src/init.py Updates exports to include new tropical_einsum functions
tropical_in_new/tests/test_tropical_einsum.py 425 lines of new tests for unary/binary rules, argmax tracing, and tropical-gemm integration
tropical_in_new/tests/test_benchmarks.py 193 lines of new benchmark tests using Asia Bayesian network
tropical_in_new/tests/test_mpe.py Updates tests to use new automatic omeco-optimized API
tropical_in_new/tests/test_contraction.py Adds tests for new omeco tree functions
tropical_in_new/tests/benchmarks/asia.uai New benchmark data file for Asia Bayesian network (44 lines)

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@GiggleLiu
Copy link
Member Author

Benchmark and Verification Results

Methodology

1. Asia Network Benchmark (Classic Bayesian Network)

  • Used the Asia network (aka "lung cancer" network) from Lauritzen & Spiegelhalter (1988)
  • 8 binary variables, 8 factors
  • Standard UAI format benchmark file: tropical_in_new/tests/benchmarks/asia.uai

2. Verification Approach

  • Compared tropical MPE against exhaustive brute-force enumeration
  • Brute force computes argmax Σ log(factor_values) over all 2^8 = 256 configurations
  • Verified both assignment AND score match

Results

Tropical MPE:
  assignment: {1: 1, 2: 1, 3: 1, 4: 1, 5: 0, 6: 0, 7: 0, 8: 0}
  score: -0.8589117469492968

Brute force:
  assignment: {1: 1, 2: 1, 3: 1, 4: 1, 5: 0, 6: 0, 7: 0, 8: 0}
  score: -0.8589117210358381

Difference: 2.59e-08 (floating point precision only)

Test Coverage

83 tests passed in 0.97s:
├── test_benchmarks.py (5 tests)
│   ├── test_asia_mpe_matches_brute_force ✅
│   ├── test_asia_mpe_with_evidence ✅
│   ├── test_asia_mpe_result_reasonable ✅
│   ├── test_mpe_simple_chain ✅
│   └── test_mpe_with_strong_evidence ✅
├── test_tropical_einsum.py (40 tests)
│   ├── Unary rules (Identity, Permutedims, TropicalSum, Trace)
│   ├── Binary rules (outer product, dot product, matmul variants)
│   ├── Rule matching
│   ├── Argmax tracing
│   └── tropical-gemm integration
├── test_mpe.py (6 tests)
├── test_contraction.py (17 tests)
├── test_primitives.py (11 tests)
└── test_utils.py (4 tests)

Verification Script

import torch, itertools
from tropical_in_new.src.mpe import mpe_tropical
from tropical_in_new.src.utils import read_model_file

model = read_model_file('tropical_in_new/tests/benchmarks/asia.uai', factor_eltype=torch.float64)
assignment, score, _ = mpe_tropical(model)

# Brute force verification
best_score, best_assignment = float('-inf'), None
for combo in itertools.product(*(range(c) for c in model.cards)):
    log_score = sum(torch.log(f.values[tuple(combo[v-1] for v in f.vars)]).item() 
                    for f in model.factors)
    if log_score > best_score:
        best_score, best_assignment = log_score, {i+1: combo[i] for i in range(model.nvars)}

assert assignment == best_assignment  # ✅ Exact match
assert abs(score - best_score) < 1e-6  # ✅ Score within tolerance

@GiggleLiu
Copy link
Member Author

Updated Benchmark Results with UAI Competition Data

Added 10 benchmark files from UAI 2008/2014 competitions to tropical_in_new/tests/benchmarks/:

Benchmark Results

Benchmark Vars Factors MPE Score Time (s)
asia.uai 8 8 -0.8589 0.0009
pdb1etl.uai 9 14 -4.6014 0.0002
pdb1akg.uai 14 25 -4.0556 0.0004
Grids_12.uai 100 280 696.78 0.0186

Known Issue Found

Disconnected graphs: When a graphical model has disconnected components (isolated variables or separate subgraphs), the current implementation only returns assignments for variables that are part of the main connected component.

Example: pdb1etn.uai has 9 variables but 4 are isolated - the MPE assignment only contains 4 variables.

This is a limitation to address in a future PR. For now, the implementation works correctly for connected graphs.

Files Added

  • 2bitcomp_5.cnf.uai - SAT benchmark
  • CSP_12.uai - CSP benchmark
  • Grids_11.uai, Grids_12.uai - Grid models (100 vars)
  • Pedigree_11.uai - Pedigree benchmark
  • grid10x10.f10.uai - 10x10 grid
  • pdb1akg.uai, pdb1etl.uai, pdb1etn.uai - Protein folding
  • sat-grid-pbl-0010.cnf.uai - SAT grid

Source: UAI Inference Competition, /tmp/uai-competitions/

GiggleLiu and others added 2 commits January 25, 2026 18:27
Add 10 benchmark files from UAI 2008/2014 inference competitions:
- Protein folding: pdb1akg, pdb1etl, pdb1etn
- Grid models: Grids_11, Grids_12, grid10x10.f10
- SAT: 2bitcomp_5.cnf, sat-grid-pbl-0010.cnf
- CSP: CSP_12
- Pedigree: Pedigree_11

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Remove unused imports (Enum, auto) from tropical_einsum.py
- Add validation to tropical_reduce_max for invalid elim_vars
- Add validation to argmax_trace for missing assignment values
- Remove unused variable (ixs) in contraction.py
- Remove dead code (tropical_contract_binary) from primitives.py
- Remove unused imports from test_tropical_einsum.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
GiggleLiu and others added 2 commits January 25, 2026 20:24
- Add n-ary contraction support in DefaultRule (chains as binary ops)
- Consolidate Backpointer: import from primitives.py instead of duplicate
- Add Diag test coverage (test_diag_extraction, test_diag_with_extra_dim)
- Improve test_contract_omeco_tree_matches_legacy to verify correctness
- Add performance comment about sequential batched GEMM processing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Users should use contraction order optimization (omeco) to decompose
n-ary contractions into binary contractions first.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@ChanceSiyuan ChanceSiyuan self-requested a review January 25, 2026 12:36
@ChanceSiyuan ChanceSiyuan merged commit 9bdd0f7 into main Jan 25, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Incorrect elimination order extraction from omeco tree causes memory explosion

3 participants