perf: vmap over molecules for batched Hessian computation (#180) #189

Merged
ericchansen merged 2 commits into master from perf/issue-180-vmap-molecules on Mar 30, 2026

@ericchansen
Owner

Summary

Enable jax.vmap over molecules for batched Hessian computation, replacing the sequential Python loop in the objective function. Closes #180.

Changes

| File | Change |
| --- | --- |
| q2mm/backends/mm/batched.py | New module: TopologyGroup, group_by_topology(), batched_hessians() (vmap), batched_frequencies() |
| q2mm/backends/mm/jax_engine.py | _batched_coord_hess_fn field on JaxHandle for caching the vmapped Hessian |
| q2mm/optimizers/objective.py | _can_batch_hessians(), _precompute_batched_hessians(), batched path in _compute_residuals() with fallback |

How It Works

  1. Topology grouping: Molecules sharing identical connectivity (bonds, angles, torsions, vdW pairs) are grouped together
  2. Coordinate batching: Within each group, coordinates are stacked into a batch array
  3. vmap Hessian: jax.vmap(jax.hessian(energy_fn)) computes all Hessians in one vectorized call
  4. Transparent fallback: Non-JAX engines or single-molecule groups use the existing sequential path
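The steps above can be illustrated with a minimal JAX sketch. It is not the code from q2mm/backends/mm/batched.py: the toy `energy_fn` (two harmonic bonds on a 3-atom chain), the force constant, and the conformer coordinates are all assumptions made for illustration. It only shows that `jax.vmap(jax.hessian(...))` over a stacked coordinate array gives the same Hessians as a Python loop.

```python
import jax
import jax.numpy as jnp


def energy_fn(coords):
    """Toy energy (assumption): two harmonic bonds on a 3-atom chain.

    coords has shape (3, 3). The bond list plays the role of the shared
    topology; only the coordinates differ between batch members.
    """
    k, r0 = 1.0, 1.0  # illustrative force constant and equilibrium length
    bonds = jnp.array([[0, 1], [1, 2]])
    diff = coords[bonds[:, 0]] - coords[bonds[:, 1]]
    r = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    return 0.5 * k * jnp.sum((r - r0) ** 2)


def flat_energy(x):
    """Same energy on flattened coordinates, so the Hessian is a (9, 9) matrix."""
    return energy_fn(x.reshape(3, 3))


# Hessian for one structure, and its vmapped (batched) counterpart.
hess_one = jax.hessian(flat_energy)
hess_batch = jax.jit(jax.vmap(hess_one))

# Two conformers with identical connectivity, stacked into one batch array.
conformers = jnp.array([
    [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [2.0, 0.3, 0.0]],
    [[0.0, 0.0, 0.0], [0.9, 0.2, 0.0], [1.8, 0.0, 0.4]],
])
batch = conformers.reshape(2, 9)

batched = hess_batch(batch)                            # one vectorized call
sequential = jnp.stack([hess_one(x) for x in batch])   # Python-loop reference
```

Stacking only works when every member of the batch has the same array shapes and the same index arrays inside the energy function, which is why grouping by topology comes first.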

Testing

  • 14 new tests covering topology signatures, grouping, Hessian/frequency parity vs sequential, ObjectiveFunction integration, and fallback behavior
  • 766 total tests pass with no regressions

Add jax.vmap-based batched Hessian evaluation for molecules sharing
the same topology (bond/angle connectivity). This replaces sequential
Python-loop Hessian computation with a single vectorized call per
topology group, significantly reducing evaluation time for multi-
conformer workflows (GS + TS of the same molecule).

New module: q2mm/backends/mm/batched.py
- TopologyGroup dataclass for grouping compatible molecules
- group_by_topology() groups molecules by connectivity signature
- batched_hessians() computes all Hessians via jax.vmap
- batched_frequencies() wraps batched Hessians → frequencies

JaxHandle: add _batched_coord_hess_fn cached callable field

ObjectiveFunction integration:
- _can_batch_hessians() detects when batching is beneficial
- _precompute_batched_hessians() pre-computes Hessians per group
- _compute_residuals() uses batched path when available
- _evaluate_molecule() accepts precomputed_hessian kwarg
- Graceful fallback to sequential on any failure

The batched path is fully backward-compatible: it activates only when
the engine is JaxEngine, there are 2+ molecules, and references
include Hessian-derived data (frequencies or eigenmatrix). Sequential
evaluation remains the default for all other cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
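The activation conditions and the fallback described in this commit message could look roughly like the sketch below. All names here (`can_batch`, `compute_hessians`, `HESSIAN_KINDS`, the `engine_is_jax` flag) are hypothetical and chosen for illustration; the actual checks live in _can_batch_hessians() and _compute_residuals() and may differ.

```python
import logging

logger = logging.getLogger("batching_sketch")

# Assumed labels for reference data that requires a Hessian.
HESSIAN_KINDS = {"frequency", "eigenmatrix"}


def can_batch(engine_is_jax, n_molecules, reference_kinds):
    """Return True only when all three conditions from the commit message hold."""
    needs_hessian = bool(HESSIAN_KINDS & set(reference_kinds))
    return bool(engine_is_jax) and n_molecules >= 2 and needs_hessian


def compute_hessians(molecules, sequential_fn, batched_fn, *,
                     engine_is_jax, reference_kinds):
    """Try the batched path when eligible; otherwise evaluate one by one."""
    if can_batch(engine_is_jax, len(molecules), reference_kinds):
        try:
            return batched_fn(molecules)
        except Exception as exc:  # any failure falls back to sequential
            logger.warning("Batched Hessian path failed (%s); "
                           "falling back to sequential", exc)
    return [sequential_fn(m) for m in molecules]
```

Keeping the sequential call as the final statement means the ineligible case and the failure case share one code path, so the fallback cannot drift from the default behavior.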
Copilot AI review requested due to automatic review settings March 29, 2026 19:34

Copilot AI left a comment

Pull request overview

This PR introduces batched Hessian evaluation for the JAX MM backend by grouping molecules with identical topology and using jax.vmap to compute multiple coordinate Hessians per group, reducing Python-loop overhead in ObjectiveFunction.

Changes:

  • Add new q2mm/backends/mm/batched.py module with topology grouping and vmapped Hessian/frequency helpers.
  • Extend JaxHandle with a cached _batched_coord_hess_fn for vmapped Hessian computation.
  • Add an ObjectiveFunction batched Hessian precompute path with sequential fallback, plus new JAX-marked tests.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| q2mm/backends/mm/batched.py | New topology-signature/grouping logic and vmapped Hessian/frequency computation. |
| q2mm/backends/mm/jax_engine.py | Adds cached field for vmapped Hessian function on JaxHandle. |
| q2mm/optimizers/objective.py | Adds batched Hessian eligibility + precompute, and uses precomputed Hessians in per-molecule evaluation. |
| test/test_batched_eval.py | Adds JAX tests for topology grouping and parity vs sequential paths. |

Comment thread q2mm/backends/mm/batched.py Outdated
Comment thread q2mm/optimizers/objective.py
Comment thread q2mm/optimizers/objective.py Outdated
- Include torsion/vdw indices, n_atoms, functional_form in topology signature
- Use SHA-256 hash instead of giant string key
- Cache handles from group_by_topology to avoid duplicate compilation
- Change batching fallback log from debug to warning

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
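As a rough illustration of the signature change in this commit, the sketch below hashes the fields named above (n_atoms, functional_form, and the bond, angle, torsion, and vdW index lists) with SHA-256 and groups molecules by the resulting digest. The `Topology` dataclass, its field layout, and the helper names are assumptions for this example, not the structures used in batched.py.

```python
import hashlib
import json
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    """Hypothetical container for the fields that enter the signature."""
    n_atoms: int
    functional_form: str
    bonds: tuple
    angles: tuple
    torsions: tuple = ()
    vdw_pairs: tuple = ()


def topology_signature(topo):
    """Fixed-length SHA-256 hex digest instead of a long string key."""
    payload = json.dumps([
        topo.n_atoms, topo.functional_form,
        topo.bonds, topo.angles, topo.torsions, topo.vdw_pairs,
    ])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def group_indices_by_topology(topologies):
    """Return lists of molecule indices that share a signature, in first-seen order."""
    groups = defaultdict(list)
    for index, topo in enumerate(topologies):
        groups[topology_signature(topo)].append(index)
    return list(groups.values())


# Ground state and transition state of the same molecule share a topology;
# a third molecule with an extra torsion lands in its own group.
gs = Topology(3, "harmonic", ((0, 1), (1, 2)), ((0, 1, 2),))
ts = Topology(3, "harmonic", ((0, 1), (1, 2)), ((0, 1, 2),))
other = Topology(4, "harmonic", ((0, 1), (1, 2), (2, 3)),
                 ((0, 1, 2), (1, 2, 3)), torsions=((0, 1, 2, 3),))
groups = group_indices_by_topology([gs, ts, other])
```

Including torsion and vdW indices matters because two molecules with the same bonds and angles but different torsion terms would otherwise be stacked into one batch and evaluated with the wrong energy expression.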
@ericchansen ericchansen merged commit d260608 into master Mar 30, 2026
11 checks passed
@ericchansen ericchansen deleted the perf/issue-180-vmap-molecules branch March 30, 2026 17:37

Development

Successfully merging this pull request may close these issues.

perf: vmap over molecules in objective function
