perf: vmap over molecules for batched Hessian computation (#180)#189
Merged
ericchansen merged 2 commits intomasterfrom Mar 30, 2026
Merged
perf: vmap over molecules for batched Hessian computation (#180)#189ericchansen merged 2 commits intomasterfrom
ericchansen merged 2 commits intomasterfrom
Conversation
Add jax.vmap-based batched Hessian evaluation for molecules sharing the same topology (bond/angle connectivity). This replaces sequential Python-loop Hessian computation with a single vectorized call per topology group, significantly reducing evaluation time for multi- conformer workflows (GS + TS of the same molecule). New module: q2mm/backends/mm/batched.py - TopologyGroup dataclass for grouping compatible molecules - group_by_topology() groups molecules by connectivity signature - batched_hessians() computes all Hessians via jax.vmap - batched_frequencies() wraps batched Hessians → frequencies JaxHandle: add _batched_coord_hess_fn cached callable field ObjectiveFunction integration: - _can_batch_hessians() detects when batching is beneficial - _precompute_batched_hessians() pre-computes Hessians per group - _compute_residuals() uses batched path when available - _evaluate_molecule() accepts precomputed_hessian kwarg - Graceful fallback to sequential on any failure The batched path is fully backward-compatible: it activates only when the engine is JaxEngine, there are 2+ molecules, and references include Hessian-derived data (frequencies or eigenmatrix). Sequential evaluation remains the default for all other cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This PR introduces batched Hessian evaluation for the JAX MM backend by grouping molecules with identical topology and using jax.vmap to compute multiple coordinate Hessians per group, reducing Python-loop overhead in ObjectiveFunction.
Changes:
- Add new
q2mm/backends/mm/batched.pymodule with topology grouping and vmapped Hessian/frequency helpers. - Extend
JaxHandlewith a cached_batched_coord_hess_fnfor vmapped Hessian computation. - Add an ObjectiveFunction batched Hessian precompute path with sequential fallback, plus new JAX-marked tests.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
q2mm/backends/mm/batched.py |
New topology-signature/grouping logic and vmapped Hessian/frequency computation. |
q2mm/backends/mm/jax_engine.py |
Adds cached field for vmapped Hessian function on JaxHandle. |
q2mm/optimizers/objective.py |
Adds batched Hessian eligibility + precompute, and uses precomputed Hessians in per-molecule evaluation. |
test/test_batched_eval.py |
Adds JAX tests for topology grouping and parity vs sequential paths. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- Include torsion/vdw indices, n_atoms, functional_form in topology signature - Use SHA-256 hash instead of giant string key - Cache handles from group_by_topology to avoid duplicate compilation - Change batching fallback log from debug to warning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Enable
jax.vmapover molecules for batched Hessian computation, replacing the sequential Python loop in the objective function. Closes #180.Changes
q2mm/backends/mm/batched.pyTopologyGroup,group_by_topology(),batched_hessians()(vmap),batched_frequencies()q2mm/backends/mm/jax_engine.py_batched_coord_hess_fnfield onJaxHandlefor caching vmapped Hessianq2mm/optimizers/objective.py_can_batch_hessians(),_precompute_batched_hessians(), batched path in_compute_residuals()with fallbackHow It Works
jax.vmap(jax.hessian(energy_fn))computes all Hessians in one vectorized callTesting