Skip to content

Commit 550939f

Browse files
authored
Allow caching of restart data in ASE calculator (#18)
1 parent b58c180 commit 550939f

File tree

3 files changed

+110
-37
lines changed

3 files changed

+110
-37
lines changed

xtb/ase/calculator.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
electronic_temperature 300.0 Electronic temperatur for TB methods
5252
max_iterations 250 Iterations for self-consistent evaluation
5353
solvent "none" GBSA implicit solvent model
54+
cache_api True Reuse generate API objects (recommended)
5455
======================== ============ ============================================
5556
"""
5657

@@ -84,6 +85,7 @@ class XTB(ase_calc.Calculator):
8485
"max_iterations": 250,
8586
"electronic_temperature": 300.0,
8687
"solvent": "None",
88+
"cache_api": True,
8789
}
8890

8991
_res = None
@@ -96,49 +98,85 @@ def __init__(
9698

9799
ase_calc.Calculator.__init__(self, atoms=atoms, **kwargs)
98100

99-
# loads the default parameters and updates with actual values
100-
self.parameters = self.get_default_parameters()
101-
# now set all parameters
102-
self.set(**kwargs)
103-
104-
def set(self, **kwargs):
101+
def set(self, **kwargs) -> dict:
105102
"""Set new parameters to xtb"""
106103

107104
changed_parameters = ase_calc.Calculator.set(self, **kwargs)
108105

109-
self._check(changed_parameters)
106+
self._check_parameters(changed_parameters)
110107

111-
# Always reset the xtb calculator for now
108+
# Always reset the calculation if parameters change
112109
if changed_parameters:
113110
self.reset()
114111

112+
# If the method is changed, invalidate the cached calculator as well
113+
if "method" in changed_parameters:
114+
self._xtb = None
115+
self._res = None
116+
117+
# Minor changes can be updated in the API calculator directly
118+
if self._xtb is not None:
119+
if "accuracy" in changed_parameters:
120+
self._xtb.set_accuracy(self.parameters.accuracy)
121+
122+
if "electronic_temperature" in changed_parameters:
123+
self._xtb.set_electronic_temperature(
124+
self.parameters.electronic_temperature
125+
)
126+
127+
if "max_iterations" in changed_parameters:
128+
self._xtb.set_max_iterations(self.parameters.max_iterations)
129+
130+
if "solvent" in changed_parameters:
131+
self._xtb.set_solvent(get_solvent(self.parameters.solvent))
132+
115133
return changed_parameters
116134

117-
def _check(self, parameters):
135+
def _check_parameters(self, parameters: dict) -> None:
118136
"""Verifiy provided parameters are valid"""
119137

120138
if "method" in parameters and get_method(parameters["method"]) is None:
121139
raise ase_calc.InputError(
122140
"Invalid method {} provided".format(parameters["method"])
123141
)
124142

125-
def reset(self):
143+
def reset(self) -> None:
126144
"""Clear all information from old calculation"""
127145
ase_calc.Calculator.reset(self)
128146

129-
self._res = None
130-
131-
def calculate(
132-
self,
133-
atoms: Optional[Atoms] = None,
134-
properties: List[str] = None,
135-
system_changes: List[str] = ase_calc.all_changes,
136-
):
137-
"""Perform actual calculation with by calling the xtb API"""
138-
139-
if not properties:
140-
properties = ["energy"]
141-
ase_calc.Calculator.calculate(self, atoms, properties, system_changes)
147+
if not self.parameters.cache_api:
148+
self._xtb = None
149+
self._res = None
150+
151+
def _check_api_calculator(self, system_changes: List[str]) -> None:
152+
"""Check state of API calculator and reset if necessary"""
153+
154+
# Changes in positions and cell parameters can use a normal update
155+
_reset = system_changes.copy()
156+
if "positions" in _reset:
157+
_reset.remove("positions")
158+
if "cell" in _reset:
159+
_reset.remove("cell")
160+
161+
# Invalidate cached calculator and results object
162+
if _reset:
163+
self._xtb = None
164+
self._res = None
165+
else:
166+
if system_changes and self._xtb is not None:
167+
try:
168+
_cell = self.atoms.cell
169+
self._xtb.update(
170+
self.atoms.positions / Bohr, _cell / Bohr,
171+
)
172+
# An exception in this part means the geometry is bad,
173+
# still we will give a complete reset a try as well
174+
except XTBException:
175+
self._xtb = None
176+
self._res = None
177+
178+
def _create_api_calculator(self) -> Calculator:
179+
"""Create a new API calculator object"""
142180

143181
_method = get_method(self.parameters.method)
144182
if _method is None:
@@ -152,7 +190,7 @@ def calculate(
152190
_charge = self.atoms.get_initial_charges().sum()
153191
_uhf = int(self.atoms.get_initial_magnetic_moments().sum().round())
154192

155-
self._xtb = Calculator(
193+
calc = Calculator(
156194
_method,
157195
self.atoms.numbers,
158196
self.atoms.positions / Bohr,
@@ -161,15 +199,34 @@ def calculate(
161199
_cell / Bohr,
162200
_periodic,
163201
)
164-
self._xtb.set_verbosity(VERBOSITY_MUTED)
165-
self._xtb.set_accuracy(self.parameters.accuracy)
166-
self._xtb.set_electronic_temperature(self.parameters.electronic_temperature)
167-
self._xtb.set_max_iterations(self.parameters.max_iterations)
168-
self._xtb.set_solvent(get_solvent(self.parameters.solvent))
202+
calc.set_verbosity(VERBOSITY_MUTED)
203+
calc.set_accuracy(self.parameters.accuracy)
204+
calc.set_electronic_temperature(self.parameters.electronic_temperature)
205+
calc.set_max_iterations(self.parameters.max_iterations)
206+
calc.set_solvent(get_solvent(self.parameters.solvent))
169207

170208
except XTBException:
171209
raise ase_calc.InputError("Cannot construct calculator for xtb")
172210

211+
return calc
212+
213+
def calculate(
214+
self,
215+
atoms: Optional[Atoms] = None,
216+
properties: List[str] = None,
217+
system_changes: List[str] = ase_calc.all_changes,
218+
) -> None:
219+
"""Perform actual calculation with by calling the xtb API"""
220+
221+
if not properties:
222+
properties = ["energy"]
223+
ase_calc.Calculator.calculate(self, atoms, properties, system_changes)
224+
225+
self._check_api_calculator(system_changes)
226+
227+
if self._xtb is None:
228+
self._xtb = self._create_api_calculator()
229+
173230
try:
174231
self._res = self._xtb.singlepoint(self._res)
175232
except XTBException:

xtb/ase/test_calculator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,22 @@ def test_gfn2_xtb_0d():
8080
])
8181
dipole_moment = np.array([0.62120710, 0.28006659, 0.04465985])
8282

83-
calc = XTB(method="GFN2-xTB")
84-
atoms.set_calculator(calc)
83+
calc = XTB(method="GFN2-xTB", atoms=atoms)
8584

8685
assert approx(atoms.get_potential_energy(), thr) == -592.6794366990786
8786
assert approx(atoms.get_forces(), thr) == forces
8887
assert approx(atoms.get_charges(), thr) == charges
8988
assert approx(atoms.get_dipole_moment(), thr) == dipole_moment
9089

90+
atoms.calc.set(
91+
accuracy=0.1,
92+
electronic_temperature=500.0,
93+
max_iterations=20,
94+
solvent="ch2cl2",
95+
)
96+
97+
assert approx(atoms.get_potential_energy(), thr) == -592.9940608761889
98+
9199

92100
def test_gfn1_xtb_0d():
93101
"""Test ASE interface to GFN1-xTB"""
@@ -140,8 +148,7 @@ def test_gfn1_xtb_0d():
140148
])
141149
dipole_moment = np.array([0.76943477, 0.33021928, 0.05670150])
142150

143-
calc = XTB(method="GFN1-xTB")
144-
atoms.set_calculator(calc)
151+
atoms.calc = XTB(method="GFN1-xTB")
145152

146153
assert approx(atoms.get_potential_energy(), thr) == -632.7363734598027
147154
assert approx(atoms.get_forces(), thr) == forces
@@ -235,6 +242,10 @@ def test_gfn2_xtb_3d():
235242
atoms.set_pbc(False)
236243
assert approx(atoms.get_potential_energy(), thr) == -1121.9196707084955
237244

245+
with raises(InputError):
246+
atoms.positions = np.zeros((len(atoms), 3))
247+
calc.calculate(atoms=atoms, system_changes=["positions"])
248+
238249

239250
def test_invalid_method():
240251
"""GFN-xTB without method number is invalid, should raise an input error"""

xtb/ase/test_optimize.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def test_gfn1xtb_bfgs():
5959
]),
6060
)
6161

62-
calc = XTB(method="GFN1-xTB", accuracy=2.0)
63-
atoms.set_calculator(calc)
62+
atoms.calc = XTB(method="GFN1-xTB", accuracy=2.0, cache_api=False)
6463
opt = BFGS(atoms)
6564
opt.run(fmax=0.1)
6665

@@ -107,7 +106,7 @@ def test_gfn2xtb_lbfgs():
107106
opt.run(fmax=0.1)
108107

109108
assert approx(atoms.get_potential_energy(), thr) == -897.4533662470938
110-
assert approx(np.linalg.norm(atoms.get_forces(), ord=2), thr) == 0.1939329480683042
109+
assert approx(np.linalg.norm(atoms.get_forces(), ord=2), thr) == 0.19359647527783497
111110

112111

113112
def test_gfn2xtb_velocityverlet():
@@ -143,11 +142,17 @@ def test_gfn2xtb_velocityverlet():
143142
]),
144143
)
145144

146-
calc = XTB(method="GFN2-xTB")
145+
calc = XTB(method="GFN2-xTB", cache_api=False)
147146
atoms.set_calculator(calc)
148147

149148
dyn = VelocityVerlet(atoms, timestep=1.0*fs)
150149
dyn.run(20)
151150

152151
assert approx(atoms.get_potential_energy(), thr) == -896.9772346260584
153152
assert approx(atoms.get_kinetic_energy(), thr) == 0.022411127028842362
153+
154+
atoms.calc.set(cache_api=True)
155+
dyn.run(20)
156+
157+
assert approx(atoms.get_potential_energy(), thr) == -896.9913862530841
158+
assert approx(atoms.get_kinetic_energy(), thr) == 0.036580471363852810

0 commit comments

Comments
 (0)