Skip to content

Commit eb7d0b0

Browse files
Merge pull request #33 from learningmatter-mit/ase_dyn_update
Ase dynamics update
2 parents 18f8fed + 4e907a6 commit eb7d0b0

2 files changed

Lines changed: 178 additions & 6 deletions

File tree

nff/md/nvt.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,56 @@
22
import math
33
import os
44
import pickle
5+
import warnings
56
from typing import Optional
67

8+
import ase
79
import numpy as np
810
from ase import units
911
from ase.md.logger import MDLogger
1012
from ase.md.md import MolecularDynamics
1113
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
1214
from ase.optimize.optimize import Dynamics
15+
from packaging.version import Version, parse
1316
from tqdm import tqdm
1417

1518
from nff.io.ase import AtomsBatch
1619

20+
ASE_VERSION = parse(ase.__version__)
21+
ASE_CUTOFF_VERSION = parse("3.23.0")
22+
23+
24+
def run_with_ase_check(
25+
integrator: MolecularDynamics,
26+
steps_per_epoch: int,
27+
ase_ver: Version = ASE_VERSION,
28+
ase_cut: Version = ASE_CUTOFF_VERSION,
29+
) -> None:
30+
"""Run the ASE dynamics with a check for the ASE version. ASE v3.23 has updated
31+
the `run` method in the `Dynamics` class, so we need to check for the version
32+
and run the appropriate method. This function will be deprecated in the future,
33+
as ASE v3.23 will be the minimum version required for nff, and contains a warning
34+
to that effect.
35+
Args:
36+
integrator (MolecularDynamics): ASE integrator object or thermostat like NoseHoover
37+
steps_per_epoch (int): number of steps per epoch
38+
ase_ver (Version): ASE version
39+
ase_cut (Version): ASE cutoff version where Dynamics approach was changed
40+
Raises:
41+
DeprecationWarning: if the ASE version is less than 3.23
42+
"""
43+
if ase_ver < ase_cut:
44+
warnings.warn(
45+
f"ASE version {ase_ver} uses outdated `run` method in"
46+
" its `Dynamics` class. Please update to a newer version of ASE as this"
47+
" method will be deprecated in nff in the future.",
48+
DeprecationWarning,
49+
stacklevel=2,
50+
)
51+
Dynamics.run(integrator)
52+
else:
53+
Dynamics.run(integrator, steps=steps_per_epoch)
54+
1755

1856
class NoseHoover(MolecularDynamics):
1957
def __init__(
@@ -154,7 +192,7 @@ def run(self, steps=None):
154192

155193
for _ in tqdm(range(epochs)):
156194
self.max_steps += steps_per_epoch
157-
Dynamics.run(self)
195+
run_with_ase_check(self, steps_per_epoch)
158196
self.atoms.update_nbr_list()
159197

160198

@@ -382,7 +420,7 @@ def run(self, steps=None):
382420

383421
for _ in tqdm(range(epochs)):
384422
self.max_steps += steps_per_epoch
385-
Dynamics.run(self)
423+
run_with_ase_check(self, steps_per_epoch)
386424

387425
x = self.atoms.get_positions(wrap=True)
388426
self.atoms.set_positions(x)
@@ -567,7 +605,7 @@ def run(self, steps=None):
567605

568606
for _ in tqdm(range(epochs)):
569607
self.max_steps += steps_per_epoch
570-
Dynamics.run(self)
608+
run_with_ase_check(self, steps_per_epoch)
571609
self.atoms.update_nbr_list()
572610

573611
momenta = []
@@ -733,7 +771,7 @@ def run(self, steps=None):
733771

734772
for _ in tqdm(range(epochs)):
735773
self.max_steps += steps_per_epoch
736-
Dynamics.run(self)
774+
run_with_ase_check(self, steps_per_epoch)
737775
self.atoms.update_nbr_list()
738776
Stationary(self.atoms)
739777
ZeroRotation(self.atoms)
@@ -821,7 +859,7 @@ def run(self, steps=None):
821859
# set hydrogen mass to 2 AMU (deuterium, following Grimme's mTD approach)
822860
self.increase_h_mass()
823861

824-
Dynamics.run(self)
862+
run_with_ase_check(self, steps_per_epoch)
825863

826864
# reset the masses
827865
self.decrease_h_mass()
@@ -965,7 +1003,7 @@ def run(self, steps=None):
9651003

9661004
for _ in range(epochs):
9671005
self.max_steps += steps_per_epoch
968-
Dynamics.run(self)
1006+
run_with_ase_check(self, steps_per_epoch)
9691007
self.atoms.update_nbr_list()
9701008

9711009

nff/tests/dynamics_test.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
import copy
2+
import os
23
import pickle
34
import random
5+
import unittest as ut
46
from datetime import datetime
7+
from pathlib import Path
58

69
import numpy as np
10+
import pytest
711
import torch
812
from ase.io.trajectory import Trajectory
913
from torch.utils.data import DataLoader
1014

1115
from nff.data import Dataset, collate_dicts
16+
from nff.io.ase import AtomsBatch
17+
from nff.io.ase_calcs import NeuralFF
18+
from nff.md.nvt import Langevin
1219
from nff.md.nvt_ax import NoseHoover, NoseHooverChain
1320
from nff.md.utils_ax import ZhuNakamuraLogger, atoms_to_nxyz, mol_dot, mol_norm
1421
from nff.train import load_model
@@ -18,11 +25,97 @@
1825
HBAR = 1
1926
OUT_FILE = "trj.csv"
2027
LOG_FILE = "trj.log"
28+
this_file = Path(__file__).resolve()
29+
ETHANOL_MODEL_PATH = (
30+
this_file.parent.parent.parent / "tutorials" / "models" / "cco_1" / "best_model"
31+
) # Simon's SchNet model
2132

2233

2334
METHOD_DIC = {"nosehoover": NoseHoover, "nosehooverchain": NoseHooverChain}
2435

2536

37+
def get_directed_ethanol():
38+
"""Returns an ethanol molecule.
39+
40+
Returns:
41+
ethanol (Atoms)
42+
"""
43+
props = {
44+
"nxyz": torch.Tensor(
45+
[
46+
[6.0000e00, 5.5206e-03, 5.9149e-01, -8.1382e-04],
47+
[6.0000e00, -1.2536e00, -2.5536e-01, -2.9801e-02],
48+
[8.0000e00, 1.0878e00, -3.0755e-01, 4.8230e-02],
49+
[1.0000e00, 6.2821e-02, 1.2838e00, -8.4279e-01],
50+
[1.0000e00, 6.0567e-03, 1.2303e00, 8.8535e-01],
51+
[1.0000e00, -2.2182e00, 1.8981e-01, -5.8160e-02],
52+
[1.0000e00, -9.1097e-01, -1.0539e00, -7.8160e-01],
53+
[1.0000e00, -1.1920e00, -7.4248e-01, 9.2197e-01],
54+
[1.0000e00, 1.8488e00, -2.8632e-02, -5.2569e-01],
55+
]
56+
),
57+
"energy": torch.tensor(-4.3701),
58+
"energy_grad": torch.Tensor(
59+
[
60+
[10.2030, -33.6563, 1.9132],
61+
[-59.5878, 42.4086, 10.0746],
62+
[-36.9785, 2.0060, 18.7998],
63+
[-1.8185, 5.6604, 4.6715],
64+
[-1.8685, 0.9660, -1.9927],
65+
[11.0286, -11.6878, 18.4956],
66+
[38.0142, -24.5804, -16.6240],
67+
[5.8505, 15.7041, -12.9981],
68+
[35.1569, 3.1794, -22.3399],
69+
]
70+
),
71+
"smiles": "CCO",
72+
"num_atoms": torch.tensor(9),
73+
"nbr_list": torch.Tensor(
74+
[
75+
[0, 1],
76+
[0, 2],
77+
[0, 3],
78+
[0, 4],
79+
[0, 5],
80+
[0, 6],
81+
[0, 7],
82+
[0, 8],
83+
[1, 2],
84+
[1, 3],
85+
[1, 4],
86+
[1, 5],
87+
[1, 6],
88+
[1, 7],
89+
[1, 8],
90+
[2, 3],
91+
[2, 4],
92+
[2, 5],
93+
[2, 6],
94+
[2, 7],
95+
[2, 8],
96+
[3, 4],
97+
[3, 5],
98+
[3, 6],
99+
[3, 7],
100+
[3, 8],
101+
[4, 5],
102+
[4, 6],
103+
[4, 7],
104+
[4, 8],
105+
[5, 6],
106+
[5, 7],
107+
[5, 8],
108+
[6, 7],
109+
[6, 8],
110+
[7, 8],
111+
]
112+
),
113+
"charge": torch.tensor(0.0),
114+
"spin": torch.tensor(0.0),
115+
}
116+
return AtomsBatch(positions=props["nxyz"][:, 1:], directed=True, numbers=props["nxyz"][:, 0], props=props)
117+
118+
26119
class ZhuNakamuraDynamics(ZhuNakamuraLogger):
27120
"""
28121
Class for running Zhu-Nakamura surface-hopping dynamics. This method follows the description in
@@ -974,3 +1067,44 @@ def run(self):
9741067
)
9751068

9761069
batched_zn.run()
1070+
1071+
1072+
# @pytest.mark.usefixtures("device")
1073+
@pytest.mark.skip("Works locally but need to update to work on remote CI")
1074+
class TestLangevin(ut.TestCase):
1075+
def setUp(self):
1076+
self.ethanol = get_directed_ethanol()
1077+
self.device = self._test_fixture_device
1078+
self.model = NeuralFF.from_file(ETHANOL_MODEL_PATH, device=self.device)
1079+
self.ethanol.set_calculator(self.model)
1080+
if os.path.exists("langevin.traj"):
1081+
os.remove("langevin.traj")
1082+
if os.path.exists("langevin.log"):
1083+
os.remove("langevin.log")
1084+
1085+
@pytest.mark.timeout(30)
1086+
def test_langevin(self):
1087+
# Set up Langevin dynamics
1088+
my_dt = 1.0 # fs
1089+
my_temp = 100 # K
1090+
my_friction = 1.0
1091+
logfile = "langevin.log"
1092+
1093+
dyn = Langevin(
1094+
self.ethanol,
1095+
timestep=my_dt,
1096+
temperature=my_temp,
1097+
friction=my_friction,
1098+
maxwell_temp=my_temp,
1099+
logfile=logfile,
1100+
trajectory="langevin.traj",
1101+
)
1102+
dyn.run(steps=40)
1103+
1104+
# Check that the trajectory file was created
1105+
assert os.path.exists("langevin.traj")
1106+
assert os.path.exists("langevin.log")
1107+
1108+
1109+
if __name__ == "__main__":
1110+
ut.main()

0 commit comments

Comments
 (0)