Made trajectory reporter optional in static function #441
falletta wants to merge 8 commits into TorchSim:main
Conversation
torch_sim/runners.py
```python
else:
    # Collect base properties for each system when no reporter
    for sys_idx in range(sub_state.n_systems):
        atom_mask = sub_state.system_idx == sys_idx
        base_props: dict[str, torch.Tensor] = {
            "potential_energy": static_state.energy[sys_idx],
        }
        if model.compute_forces:
            base_props["forces"] = static_state.forces[atom_mask]
        if model.compute_stress:
            base_props["stress"] = static_state.stress[sys_idx]
        all_props.append(base_props)
```
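To illustrate the pattern in the diff above: each system's atoms are selected with a boolean mask over `system_idx`, while per-system quantities like the energy are indexed directly. A minimal pure-Python sketch of that loop, with made-up toy data (the real code operates on `torch.Tensor`s and uses tensor masking):

```python
# Toy stand-ins for sub_state / static_state fields (hypothetical values).
system_idx = [0, 0, 1, 1, 1]            # maps each atom to its system
forces = [[0.1], [0.2], [0.3], [0.4], [0.5]]  # per-atom property
energy = [-1.0, -2.5]                   # per-system property
n_systems = max(system_idx) + 1

all_props = []
for sys_idx in range(n_systems):
    base_props = {
        # per-system: plain index
        "potential_energy": energy[sys_idx],
        # per-atom: boolean-mask equivalent with a comprehension
        "forces": [f for f, s in zip(forces, system_idx) if s == sys_idx],
    }
    all_props.append(base_props)

print(all_props[1]["forces"])  # → [[0.3], [0.4], [0.5]]
```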
Are you sure this is faster?
The profiling data doesn't really say anything about how long the report call takes, and this isn't benchmarked against the current code.
Yes, I confirm the speedup. Below are the results from running the run_torchsim_static function in 8.scaling.py (see the static PR) before and after the fix. We observe a speedup of up to 24.3%, which continues to grow with system size. In addition, I verified that the cost associated with the trajectory reporter disappears from the profiling analysis.
Previous results:

```
=== Static benchmark ===
n=1    static_time=1.928943s
n=1    static_time=0.272846s
n=1    static_time=0.683335s
n=1    static_time=0.026675s
n=10   static_time=0.281990s
n=100  static_time=0.705871s
n=500  static_time=1.510273s
n=1000 static_time=1.528872s
n=2500 static_time=3.809000s
n=5000 static_time=7.890238s
```
New results:

```
n=1    static_time=2.165601s
n=1    static_time=0.271961s
n=1    static_time=0.665016s
n=1    static_time=0.022899s
n=10   static_time=0.295468s
n=100  static_time=0.692651s
n=500  static_time=1.411905s
n=1000 static_time=1.291772s -> 18.3% speedup
n=2500 static_time=3.175455s -> 19.9% speedup
n=5000 static_time=6.348887s -> 24.3% speedup
```
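For reference, the quoted percentages appear to be computed as the ratio of old to new wall time minus one (modulo rounding), which can be checked directly from the timings above:

```python
# Speedup = (old_time / new_time - 1) * 100, using the n=1000, n=2500,
# and n=5000 timings quoted in the two benchmark tables.
pairs = {1000: (1.528872, 1.291772), 2500: (3.809000, 3.175455), 5000: (7.890238, 6.348887)}
speedups = {n: (old / new - 1) * 100 for n, (old, new) in pairs.items()}
for n, s in speedups.items():
    print(f"n={n}: {s:.1f}% speedup")
```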
I'm confused, is the 24% speedup here coming from just from removing the report call or all of the changes in the other PR too? I'm curious of the effect just the report call has, I just want to make sure that specific optimizations are well-founded, I'm not opposed to making this change.
Upon further review, it seems the slowdown comes from calling `report` once per system. When `self.filenames is None`, we can bypass the file-writing path entirely and extract properties in a single batched pass:

```python
def report(self, state, step, model=None):
    if self.filenames is None:
        return self._extract_props_batched(state, step, model)
    # ... existing split-based logic unchanged ...

def _extract_props_batched(self, state, step, model):
    sizes = state.n_atoms_per_system.tolist()
    n_sys = state.n_systems
    all_props: list[dict[str, torch.Tensor]] = [{} for _ in range(n_sys)]
    for frequency, calculators in self.prop_calculators.items():
        if frequency == 0:
            continue
        for prop_name, prop_fn in calculators.items():
            result = prop_fn(state, model)
            if result.dim() == 0:
                result = result.unsqueeze(0)
            # infer per-atom vs per-system from shape
            if result.shape[0] == state.n_atoms:
                splits = torch.split(result, sizes)
            else:
                splits = [result[i] for i in range(n_sys)]
            for idx in range(n_sys):
                sys_step = step[idx] if isinstance(step, list) else step
                if sys_step % frequency == 0:
                    all_props[idx][prop_name] = splits[idx]
    return all_props
```
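The key trick in the batched extraction is dispatching on the leading dimension: a result with `n_atoms` rows is split into per-system chunks by atom count, while anything else is indexed per system. A standalone sketch of that dispatch using plain lists (`torch.split` on a tensor has the same split-by-sizes semantics; the helper name here is illustrative, not from the PR):

```python
def split_per_system(result, sizes, n_atoms):
    """Split a batched property into per-system chunks.

    Per-atom results (leading dim == n_atoms) are split by the
    per-system atom counts, mirroring torch.split(result, sizes);
    per-system results are indexed directly.
    """
    if len(result) == n_atoms:
        chunks, start = [], 0
        for size in sizes:  # split by consecutive chunk sizes
            chunks.append(result[start:start + size])
            start += size
        return chunks
    return [result[i] for i in range(len(sizes))]

sizes = [2, 3]                 # atoms per system
forces = [1, 2, 3, 4, 5]       # per-atom property (n_atoms == 5)
energies = [-1.0, -2.5]        # per-system property

print(split_per_system(forces, sizes, n_atoms=5))    # → [[1, 2], [3, 4, 5]]
print(split_per_system(energies, sizes, n_atoms=5))  # → [-1.0, -2.5]
```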
By making the trajectory reporter optional, we avoid additional computational overhead (see profiling plot in this PR).