Accelerated ts.static + added scaling scripts #427
base: main
Conversation
torch_sim/autobatching.py
Outdated
bbox[i] += 2.0
volume = bbox.prod() / 1000  # convert A^3 to nm^3
number_density = state.n_atoms / volume.item()
# Use cell volume (O(1)); SimState always has a cell. Avoids O(N) position scan.
non-periodic systems don't have a sensible cell, see #412
I now minimized the differences compared to the initial code
In addition, I added explicit tests for the memory scaler values and verified that the changes in this PR do not affect the test’s success
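For reference, a hedged sketch of what such a consistency test could look like (fixture names follow the test script later in this thread; calculate_memory_scaler is assumed to be the existing per-state helper in torch_sim.autobatching, so the exact API may differ):

```python
import pytest

import torch_sim as ts
from torch_sim.autobatching import (
    calculate_batched_memory_scalers,  # added by this PR
    calculate_memory_scaler,  # assumed existing per-state helper
)


def test_batched_memory_scalers_match_per_state(si_sim_state, fe_supercell_sim_state):
    """Batched scalers should agree with the per-state calculation."""
    batched = ts.concatenate_states([si_sim_state, fe_supercell_sim_state])
    batched_scalers = calculate_batched_memory_scalers(batched, "n_atoms_x_density")
    per_state = [
        calculate_memory_scaler(state, "n_atoms_x_density") for state in batched.split()
    ]
    assert batched_scalers == pytest.approx(per_state)
```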
torch_sim/autobatching.py
Outdated
self.memory_scalers = calculate_batched_memory_scalers(
    states, self.memory_scales_with
)
self.state_slices = states.split()
batching makes sense here
torch_sim/autobatching.py
Outdated
if isinstance(states, SimState):
    self.batched_states = [[states[index_bin]] for index_bin in self.index_bins]
state.split() is identical to this and faster
Reusing self.state_slices instead of calling states.split() again makes the code 5% faster, so I'd keep it
torch_sim/autobatching.py
Outdated
    )
    self.state_slices = states.split()
else:
    self.state_slices = states
why not concat and then call the batched logic?
In the if branch, the input is already a single batched SimState, so we call calculate_batched_memory_scalers and then split() once to get state_slices. No concatenation is needed.
In the else branch, the input is a list of states, so we keep state_slices = states and compute scalers per state. We avoid concatenating and using the batched path, since that would require a concat followed by a split(), resulting in extra passes and higher peak memory for the same outcome.
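For illustration, a rough sketch of the two branches described above (calculate_batched_memory_scalers comes from this PR's diff; calculate_memory_scaler is assumed to be the existing per-state helper, so the exact signatures may differ from the merged code):

```python
from torch_sim.autobatching import (
    calculate_batched_memory_scalers,  # added by this PR
    calculate_memory_scaler,  # assumed existing per-state helper
)
from torch_sim.state import SimState


def _prepare_slices_and_scalers(states, memory_scales_with: str):
    """Sketch of the branching: batched input is scored in one pass, lists per state."""
    if isinstance(states, SimState):
        # Already a single batched SimState: score every system in one
        # vectorized pass, then split once and reuse the slices.
        scalers = calculate_batched_memory_scalers(states, memory_scales_with)
        slices = states.split()
    else:
        # A list of states: keep them as-is and score each one individually,
        # avoiding a concat followed by another split.
        slices = list(states)
        scalers = [calculate_memory_scaler(s, memory_scales_with) for s in slices]
    return slices, scalers
```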
torch_sim/state.py
Outdated
def split(self) -> list[Self]:
    """Split the SimState into a list of single-system SimStates.
def split(self) -> Sequence[Self]:  # noqa: C901
this whole code block looks hard to understand/maintain. why did the external _split_state functional programming method need to be removed?
This looks like it just creates an efficient slicing iter, can we break those parts out more cleanly in the same functional pattern as before? i.e. _get_system_slice(sim_state: SimState, i: int) as a function
Ok I got rid of the intermediate class and made the code more functional
I’ve implemented the revisions and added an additional optimization to the state creation. Please let me know if you need any further edits.
I will give another review in the next few days, thanks @Fallett!
Thank you @orionarcher! I made further edits to follow the original logic more closely. Among the changes, I generalized the existing
orionarcher left a comment
Thanks @falletta! I am happy to have the batched memory scaler calculations and associated speedup!
A few thoughts:
- The batched memory scaler calculation seems quite reasonable and is a great addition to this PR.
- The changes to state manipulation don't actually seem to save any time, they just shift from eager to lazy evaluation. In the absence of concrete benchmarking data about those changes specifically, I'd favor removing the changes.
- There are several changes that seem arbitrary and unrelated to the main purpose of the PR. I generally support vibe-coded contributions, but it shouldn't be just the reviewer's responsibility to catch random unrelated changes. For example, you introduce a breaking change to the API of initialize_states for no apparent reason. This seems to get past the tests but could easily disrupt downstream behavior.
examples/scaling/scaling_nve.py
Outdated
@@ -0,0 +1,76 @@
"""Scaling for TorchSim NVE."""
These will all run in CI, so maybe we could put them all in a single script instead of four to speed up testing?
Actually, these tests should not run in CI given the changes in .github/workflows/test.yml. In their current form, these are not really intended as tests (there are no assert statements), but rather as performance checks. I think keeping them separate makes it easier to isolate the scaling behavior we want to examine for potential optimizations.
That said, we could turn these into actual tests and include them in the test suite, for example by checking that the timings stay below the current thresholds. Would you prefer that approach?
I do worry about these decaying over time if they aren't tested. The reason to run all scripts in CI is that it makes sure they don't break over time. We shouldn't include them as tests, but they should be runnable scripts like the other scripts we have.
Right, good point. I merged all the tests into examples/scripts/8.scaling.py and set the maximum number of structures to relatively small values so the test runs quickly.
torch_sim/autobatching.py
Outdated
vol = torch.abs(state.volume) / 1000  # A^3 -> nm^3
return torch.where(vol > 0, n * n / vol, n).tolist()
why is the variable volume elsewhere but vol here?
would also change n -> n_atoms
sure!
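For context, a hedged sketch of how the renamed helper might read after the suggestion (the function name is hypothetical; the attributes and the torch.where fallback are taken from the diff above, and the merged code may differ):

```python
import torch

from torch_sim.state import SimState


def n_atoms_x_density_scalers(state: SimState) -> list[float]:
    """Per-system memory scalers: n_atoms times number density, with a fallback."""
    n_atoms = state.n_atoms_per_system.to(torch.float64)
    volume = torch.abs(state.volume) / 1000  # A^3 -> nm^3
    # Fall back to n_atoms alone for systems without a meaningful cell volume
    # (e.g. non-periodic systems), matching the torch.where in the diff above.
    return torch.where(volume > 0, n_atoms * n_atoms / volume, n_atoms).tolist()
```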
torch_sim/state.py
Outdated
def _split_state[T: SimState](state: T) -> Sequence[T]:
    """Return a lazy Sequence view of state split into single-system states.

    Each single-system state is created on first access, so the call is O(1).
I guess I don't really follow how this is faster. It just shifts the cost to later down the line by making the evaluation lazy instead of eager, which isn't necessarily better given the code is harder to follow. Could you explain how this is saving time?
The eager implementation creates all N SimState objects upfront by running torch.split() across every attribute, which is O(N) work regardless of how many states are actually used. The lazy implementation only computes a cumsum (effectively O(1)) and defers creating states until __getitem__ is called.
This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.
Benchmarking with n=10000 systems shows the lazy version finishes in 9.89 s vs 11.37 s for eager (~15% faster). If all states are accessed, the total work is basically the same, but the common autobatching paths don’t require all states to be materialized at once.
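For illustration, a minimal sketch of the only eager work described above, assuming the lazy view records cumulative atom counts up front and slices per-atom tensors on access (the names here are illustrative, not the merged code):

```python
import torch

# Example per-system atom counts; the real values come from n_atoms_per_system.
n_atoms_per_system = torch.tensor([8, 8, 16])
boundaries = torch.cat(
    (torch.zeros(1, dtype=torch.int64), torch.cumsum(n_atoms_per_system, dim=0))
)
print(boundaries)  # tensor([ 0,  8, 16, 32])


def atom_range(sys_idx: int) -> tuple[int, int]:
    # Computed per __getitem__ call instead of torch.split over every attribute.
    return int(boundaries[sys_idx]), int(boundaries[sys_idx + 1])


print(atom_range(2))  # (16, 32)
```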
> This matters because estimate_max_memory_scaler only accesses 2 states (argmin/argmax) out of potentially hundreds—so lazy evaluation builds 2 states, while eager still builds all N.
I don't follow? How can you get the argmin over values that haven't been evaluated? Don't you need to evaluate to know the minimum?
Could you point me to where the eager and lazy are benchmarked against eachother?
Please find the benchmark script below. The script defines an OldState(SimState) class using the old _split_state method and produces the results shown. Alternatively, you can run 8.scaling.py with the current version of the code and with the split method reverted, and you’ll get the same results.
=== Comparison ===
n= 1: eager=2.387301s, lazy=0.026688s, speedup=89.45x
n= 1: eager=0.275248s, lazy=0.022556s, speedup=12.20x
n= 1: eager=0.677568s, lazy=0.022692s, speedup=29.86x
n= 1: eager=0.023758s, lazy=0.021832s, speedup=1.09x
n= 10: eager=0.283295s, lazy=0.025066s, speedup=11.30x
n= 100: eager=0.700301s, lazy=0.108458s, speedup=6.46x
n= 250: eager=1.020541s, lazy=0.252591s, speedup=4.04x
n= 500: eager=0.572549s, lazy=0.506986s, speedup=1.13x
n=1000: eager=1.115220s, lazy=0.989014s, speedup=1.13x
n=1500: eager=1.665096s, lazy=1.469927s, speedup=1.13x
n=5000: eager=5.755523s, lazy=5.062751s, speedup=1.14x
n=10000: eager=10.948717s, lazy=9.643398s, speedup=1.14x
Benchmark Script
"""Test comparing OldState (eager split) vs State (lazy split) performance."""
import time
import typing
from typing import Self
from unittest.mock import patch
import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp
import torch_sim as ts
from torch_sim.models.mace import MaceModel, MaceUrls
from torch_sim.state import SimState, get_attrs_for_scope
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float64
N_STRUCTURES = [1, 1, 1, 1, 10, 100, 250, 500, 1000, 1500, 5000, 10000]
MAX_MEMORY_SCALER = 400_000
MEMORY_SCALES_WITH = "n_atoms_x_density"
class OldState(SimState):
"""Old state representation that uses eager splitting.
This class inherits from SimState but overrides split() to use the old eager
approach where all sub-states are created upfront when split() is called.
"""
def split(self) -> list[Self]:
"""Split the OldState into a list of single-system OldStates (EAGER).
This is the OLD approach that creates ALL states upfront, which is O(n)
where n is the number of systems.
Returns:
list[OldState]: A list of OldState objects, one per system
"""
return self._split_state(self)
@staticmethod
def _split_state[T: SimState](state: T) -> list[T]:
"""Split a SimState into a list of states, each containing a single system.
Divides a multi-system state into individual single-system states, preserving
appropriate properties for each system.
Args:
state (SimState): The SimState to split
Returns:
list[SimState]: A list of SimState objects, each containing a single
system
"""
system_sizes = state.n_atoms_per_system.tolist()
split_per_atom = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
if attr_name != "system_idx":
split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)
split_per_system = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
if isinstance(attr_value, torch.Tensor):
split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)
else: # Non-tensor attributes are replicated for each split
split_per_system[attr_name] = [attr_value] * state.n_systems
global_attrs = dict(get_attrs_for_scope(state, "global"))
# Create a state for each system
states: list[T] = []
n_systems = len(system_sizes)
zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
for sys_idx in range(n_systems):
# Build per-system attributes (padded attributes stay padded for consistency)
per_system_dict = {
attr_name: split_per_system[attr_name][sys_idx]
for attr_name in split_per_system
}
system_attrs = {
# Create a system tensor with all zeros for this system
"system_idx": torch.zeros(
system_sizes[sys_idx], device=state.device, dtype=torch.int64
),
# Add the split per-atom attributes
**{
attr_name: split_per_atom[attr_name][sys_idx]
for attr_name in split_per_atom
},
# Add the split per-system attributes (with unpadding applied)
**per_system_dict,
# Add the global attributes
**global_attrs,
}
atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1])
new_constraints = [
new_constraint
for constraint in state.constraints
if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx))
]
system_attrs["_constraints"] = new_constraints
states.append(type(state)(**system_attrs)) # type: ignore[invalid-argument-type]
return states
@classmethod
def from_sim_state(cls, state: SimState) -> "OldState":
"""Create OldState from a SimState."""
return cls(
positions=state.positions.clone(),
masses=state.masses.clone(),
cell=state.cell.clone(),
pbc=state.pbc.clone() if isinstance(state.pbc, torch.Tensor) else state.pbc,
atomic_numbers=state.atomic_numbers.clone(),
charge=state.charge.clone() if state.charge is not None else None,
spin=state.spin.clone() if state.spin is not None else None,
system_idx=state.system_idx.clone() if state.system_idx is not None else None,
_constraints=state.constraints.copy(),
)
def run_torchsim_static(
n_structures_list: list[int],
base_structure,
model,
device: torch.device,
) -> list[float]:
"""Run static calculations for each n using batched path, return timings."""
autobatcher = ts.BinningAutoBatcher(
model=model,
max_memory_scaler=MAX_MEMORY_SCALER,
memory_scales_with=MEMORY_SCALES_WITH,
)
times: list[float] = []
for n in n_structures_list:
structures = [base_structure] * n
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
ts.static(structures, model, autobatcher=autobatcher)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
times.append(elapsed)
print(f" n={n} static_time={elapsed:.6f}s")
return times
if __name__ == "__main__":
# Setup
mgo_atoms = bulk("MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
print("Loading MACE model...")
loaded_model = mace_mp(
model=MaceUrls.mace_mpa_medium,
return_raw_model=True,
default_dtype="float64",
device=str(DEVICE),
)
mace_model = MaceModel(
model=typing.cast("torch.nn.Module", loaded_model),
device=DEVICE,
compute_forces=True,
compute_stress=True,
dtype=DTYPE,
enable_cueq=False,
)
print("\n=== Static Benchmark Comparison: Eager vs Lazy Split ===\n")
# Run with eager split (OldState behavior)
print("Running with eager split (OldState):")
with patch.object(SimState, "split", lambda self: OldState._split_state(self)):
eager_times = run_torchsim_static(
N_STRUCTURES, mgo_atoms, mace_model, DEVICE
)
# Run with lazy split (default State behavior)
print("\nRunning with lazy split (State):")
lazy_times = run_torchsim_static(
N_STRUCTURES, mgo_atoms, mace_model, DEVICE
)
# Print comparison
print("\n=== Comparison ===")
for i, n in enumerate(N_STRUCTURES):
speedup = eager_times[i] / max(lazy_times[i], 1e-9)
print(
f" n={n:4d}: eager={eager_times[i]:.6f}s, "
f"lazy={lazy_times[i]:.6f}s, speedup={speedup:.2f}x"
)
The argmin runs over metric_values, which is just a list of floats computed from the batched state tensors before we split anything. This gives us integer indices, and only when we call state_list[idx] do we actually create the 2 needed states.
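For illustration, the access pattern described above might look roughly like this (names follow the comment, with stand-in data; this is not the actual torch_sim code):

```python
# metric_values are plain floats, so argmin/argmax need no state materialization;
# only the two indexing calls at the end would build single-system states.
metric_values = [3.0, 1.0, 4.0, 1.5]   # stand-in for the real memory scalers
state_list = ["s0", "s1", "s2", "s3"]  # stand-in for the lazy split view

min_idx = min(range(len(metric_values)), key=metric_values.__getitem__)
max_idx = max(range(len(metric_values)), key=metric_values.__getitem__)

smallest = state_list[min_idx]  # only these two accesses touch the states
largest = state_list[max_idx]
print(min_idx, max_idx, smallest, largest)  # 1 2 s1 s2
```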
Thanks for the detailed benchmark and explanation, that speedup makes sense then!
That said, the type("SplitSeq", (Sequence,), ...) pattern for creating a runtime class is pretty unusual and hard to follow. Could we get the same lazy benefit with a simpler design?
CC suggestion: A _LazySplitView class with explicit __len__ and __getitem__ would be much more readable and debuggable while preserving the lazy semantics.
Ok, I’m now using a LazySplitView class, and I’ve confirmed that performance is preserved.
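For reference, a minimal sketch of what such a LazySplitView could look like, assuming integer indexing of SimState returns a single-system state (this illustrates the lazy semantics discussed above, not the exact class merged in the PR):

```python
from collections.abc import Sequence

from torch_sim.state import SimState


class LazySplitView(Sequence):
    """Lazy split view: nothing is materialized until an item is indexed."""

    def __init__(self, state: SimState) -> None:
        self._state = state  # keep only a reference to the batched state

    def __len__(self) -> int:
        return self._state.n_systems

    def __getitem__(self, idx: int) -> SimState:
        if idx < 0:  # support negative indexing like a list
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Build exactly one single-system state on demand, assuming SimState
        # integer indexing behaves like states[index_bin] earlier in the thread.
        return self._state[idx]
```

A view like this keeps split() effectively O(1) while indexing materializes one system at a time, which matches the benchmark discussion above.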
torch_sim/state.py
Outdated
all relevant properties.
their requested order (not natural 0,1,2 order).
what does this mean?
That refers to the fact that calling _slice_state(state, [3, 1, 4]) returns a new state where the systems appear in the exact order requested (3→0, 1→1, 4→2), instead of being sorted in ascending order. I updated the docstring to better reflect this point.
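A small self-contained illustration of the ordering difference on plain tensors (not torch_sim code):

```python
import torch

values = torch.tensor([10, 11, 12, 13, 14])
requested = torch.tensor([3, 1, 4])

# Mask-based selection (old behavior): results always come back in ascending index order.
mask = torch.isin(torch.arange(5), requested)
print(values[mask])       # tensor([11, 13, 14])

# Index-based selection (new behavior): items appear in the requested order, 3 -> 0, 1 -> 1, 4 -> 2.
print(values[requested])  # tensor([13, 11, 14])
```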
Got it. I prefer this behaviour and I am happy to leave this in this PR but I am noting that 1) this is a breaking change and 2) this is somewhat unrelated to the other changes in this PR.
I am a bit surprised that the old tests didn't catch this. Could you write a test or two that illustrates the difference and that fails with the old behavior and passes with the new?
Great idea. I added two dedicated tests to test_state.py, namely test_slice_state_permuted_order and test_slice_state_reversed_subset. They pass with the new _slice_state implementation and fail with the old one. Below is a standalone test script that confirms the behavior:
"""Test that the new _slice_state preserves system ordering while the old
mask-based implementation does not.
"""
import pytest
import torch
import torch_sim as ts
from torch_sim.state import (
SimState,
_filter_attrs_by_mask,
_slice_state,
)
def _old_slice_state(state: SimState, system_indices):
"""Old _slice_state that uses boolean masks — loses ordering."""
if isinstance(system_indices, list):
system_indices = torch.tensor(
system_indices, device=state.device, dtype=torch.int64
)
if len(system_indices) == 0:
raise ValueError("system_indices cannot be empty")
system_range = torch.arange(state.n_systems, device=state.device)
system_mask = torch.isin(system_range, system_indices)
atom_mask = torch.isin(state.system_idx, system_indices)
filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask)
return type(state)(**filtered_attrs)
@pytest.fixture
def _three_system_state(
si_sim_state: SimState,
ar_supercell_sim_state: SimState,
fe_supercell_sim_state: SimState,
) -> tuple[SimState, SimState, SimState, SimState]:
"""Concatenate Si, Ar, Fe into a 3-system state."""
concatenated = ts.concatenate_states(
[si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state]
)
return concatenated, si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state
# --- New _slice_state: these pass ---
def test_permuted_order(_three_system_state) -> None:
"""Slice with [2, 0, 1] should yield [Fe, Si, Ar]."""
concatenated, si, ar, fe = _three_system_state
result = _slice_state(concatenated, [2, 0, 1])
assert torch.allclose(result[0].positions, fe.positions)
assert torch.allclose(result[1].positions, si.positions)
assert torch.allclose(result[2].positions, ar.positions)
assert torch.allclose(result.cell[0], fe.cell[0])
assert torch.allclose(result.cell[1], si.cell[0])
assert torch.allclose(result.cell[2], ar.cell[0])
def test_reversed_subset(_three_system_state) -> None:
"""Slice with [2, 0] should match concatenate([Fe, Si])."""
concatenated, si, _ar, fe = _three_system_state
result = _slice_state(concatenated, [2, 0])
expected = ts.concatenate_states([fe, si])
assert torch.allclose(result.positions, expected.positions)
assert torch.allclose(result.cell, expected.cell)
assert torch.allclose(result.atomic_numbers, expected.atomic_numbers)
assert torch.allclose(result.masses, expected.masses)
assert result.n_systems == expected.n_systems
# --- Old _slice_state: identical tests, expected to fail ---
@pytest.mark.xfail(
raises=(AssertionError, RuntimeError),
reason="Old mask-based _slice_state does not preserve ordering",
strict=True,
)
def test_old_permuted_order(_three_system_state) -> None:
"""Slice with [2, 0, 1] should yield [Fe, Si, Ar]."""
concatenated, si, ar, fe = _three_system_state
result = _old_slice_state(concatenated, [2, 0, 1])
assert torch.allclose(result[0].positions, fe.positions)
assert torch.allclose(result[1].positions, si.positions)
assert torch.allclose(result[2].positions, ar.positions)
assert torch.allclose(result.cell[0], fe.cell[0])
assert torch.allclose(result.cell[1], si.cell[0])
assert torch.allclose(result.cell[2], ar.cell[0])
@pytest.mark.xfail(
raises=(AssertionError, RuntimeError),
reason="Old mask-based _slice_state does not preserve ordering",
strict=True,
)
def test_old_reversed_subset(_three_system_state) -> None:
"""Slice with [2, 0] should match concatenate([Fe, Si])."""
concatenated, si, _ar, fe = _three_system_state
result = _old_slice_state(concatenated, [2, 0])
expected = ts.concatenate_states([fe, si])
assert torch.allclose(result.positions, expected.positions)
assert torch.allclose(result.cell, expected.cell)
assert torch.allclose(result.atomic_numbers, expected.atomic_numbers)
assert torch.allclose(result.masses, expected.masses)
assert result.n_systems == expected.n_systems
torch_sim/state.py
Outdated
device: torch.device | None = None,
dtype: torch.dtype | None = None,
device: torch.device,
dtype: torch.dtype,
why?
This is a minor fix, but I think making device and dtype required improves type checking and forces callers to be explicit. With None defaults, .to(None, None) is a no-op, so the state silently stays where it was, potentially causing device mismatches with the model later.
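For the record, the analogous torch.Tensor.to behaviour behind this concern (plain PyTorch, not torch_sim code):

```python
import torch

x = torch.zeros(3, device="cpu", dtype=torch.float32)
y = x.to(device=None, dtype=None)
# Nothing is converted and no error is raised, so the data silently stays put.
assert y.device == x.device and y.dtype == x.dtype
```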
Let's not introduce a breaking API change here. If you feel strongly about this, let's do it elsewhere.
Sure, I reverted it to the less strict type check.
Thanks Orion for reviewing. I applied a few cosmetic changes to address your comments. Regarding your points:
I better understand the optimization now, thanks for the clarifications.
I don't see how this is true for the
Also maybe rename 8_scaling.py folder to 8_benchmarking.py? EDIT: Sorry, didn't mean to close! Misclick
Thanks for reviewing, @orionarcher! I’ve applied a few additional changes to
Summary
Changes
Results
The figure below shows the speedup achieved for static evaluations, 10-step atomic relaxation, 10-step NVE MD, and 10-step NVT MD. The test is performed for an 8-atom cubic supercell of MgO using the mace-mpa model. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as speedup (%) = (baseline_time / current_time − 1) × 100. We observe that:
- ts.static achieves a 52.6% speedup for 100,000 structures
- ts.relax achieves a 4.8% speedup for 1,500 structures
- ts.integrate (NVE) achieves a 0.9% speedup for 10,000 structures
- ts.integrate (NVT) achieves a 1.4% speedup for 10,000 structures
Profiling
The figure below shows a detailed performance profile. Additional optimization can be achieved by disabling the trajectory reporter when not needed, which will be addressed in a separate PR.
Comments
From the scaling plots, we can see that the timings of ts.static and ts.integrate are all consistent with each other. Indeed:
- ts.static → 85 s for 100,000 evaluations
- ts.integrate NVE → 87 s for 10,000 structures (10 MD steps each) → 87 s for 100,000 evaluations
- ts.integrate NVT → 89 s for 10,000 structures (10 MD steps each) → 89 s for 100,000 evaluations
However, when looking at the relaxation:
- ts.relax → 63 s for 1,000 structures (10 relax steps each) → 63 s for 10,000 evaluations → ~630 s for 100,000 evaluations
So ts.relax is about 7x slower than ts.static or ts.integrate. The unbatched FrechetCellFilter clearly contributes to that, and will be the focus of a separate PR.