
Potential Index error when subsampling correlated data #787

@schuhmc

I believe openmmtools is affected by the pymbar issue choderalab/pymbar#552.

I ran into at least two cases where I got the following error when trying to calculate free energies with a ReplicaExchangeAnalyzer:

Error message
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:413, in CachedProperty.__get__(self, instance, owner_class)
    412 try:
--> 413     value = instance._cache[self.name]
    414 except KeyError:

KeyError: 'mbar'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:413, in CachedProperty.__get__(self, instance, owner_class)
    412 try:
--> 413     value = instance._cache[self.name]
    414 except KeyError:

KeyError: 'unbiased_decorrelated_u_ln'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:413, in CachedProperty.__get__(self, instance, owner_class)
    412 try:
--> 413     value = instance._cache[self.name]
    414 except KeyError:

KeyError: 'decorrelated_u_ln'

During handling of the above exception, another exception occurred:

IndexError                                Traceback (most recent call last)
Cell In[44], line 9
      1 reporter = multistate.MultiStateReporter(storage_file, 'r')
      2 analyzer = multistate.ReplicaExchangeAnalyzer(reporter, analysis_kwargs={'solver_protocol':'robust'})
      4 print(f"""
      5 Loaded {storage_file} with {analyzer.n_states} states and {analyzer.n_replicas} replicas.
      6 Aggregate simulation time {analyzer.n_states * analyzer.n_iterations * fs_per_step*1e-6 * steps_per_iteration} ns assuming {steps_per_iteration} steps * {fs_per_step} fs/step = {steps_per_iteration * fs_per_step * 1e-6} ns per iteration
      7 Iterations: {analyzer.n_iterations}
      8 Statistical inefficiency: {analyzer.statistical_inefficiency}
----> 9 Free energy estimate {(analyzer.get_free_energy()[0][0][-1] * analyzer.kT).value_in_unit(unit.kilojoule / unit.mole):.2f} +/- {(analyzer.get_free_energy()[1][0][-1] * analyzer.kT).value_in_unit(unit.kilojoule / unit.mole):.2f} kJ/mol""")

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:1972, in MultiStateSamplerAnalyzer.get_free_energy(self)
   1959 """
   1960 Compute the free energy and error in free energy from the MBAR object
   1961 
   (...)
   1969     Error in the difference in free energy from each state relative to each other state
   1970 """
   1971 if self._computed_observables['free_energy'] is None:
-> 1972     self._compute_free_energy()
   1973 free_energy_dict = self._computed_observables['free_energy']
   1974 return free_energy_dict['value'], free_energy_dict['error']

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:1923, in MultiStateSamplerAnalyzer._compute_free_energy(self)
   1919 def _compute_free_energy(self):
   1920     """
   1921     Estimate free energies of all alchemical states.
   1922     """
-> 1923     nstates = self.mbar.N_k.size
   1925     # Get matrix of dimensionless free energy differences and uncertainty estimate.
   1926     logger.debug("Computing covariance matrix...")

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:415, in CachedProperty.__get__(self, instance, owner_class)
    413     value = instance._cache[self.name]
    414 except KeyError:
--> 415     value = self._get_default(instance)
    416     # Cache default value for next use.
    417     instance._update_cache(self.name, value, self._check_changes)

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:436, in CachedProperty._get_default(self, instance)
    434     raise AttributeError(err_msg)
    435 elif callable(self._default):
--> 436     value = self._default(self, instance)
    437 else:
    438     value = self._default

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:2185, in MultiStateSamplerAnalyzer.mbar(self, instance)
   2179 @mbar.default
   2180 def mbar(self, instance):
   2181     # u_ln[l,n] is the reduced potential for concatenated decorrelated snapshot n evaluated at thermodynamic state l
   2182     # Shape is (n_states + n_unsampled_states, n_samples)
   2183     # N_l[l] is the number of decorrelated samples sampled from thermodynamic state l, some can be 0.
   2184     # Shape is (n_states + n_unsampled_states, )
-> 2185     return instance._create_mbar(instance._unbiased_decorrelated_u_ln,
   2186                                  instance._unbiased_decorrelated_N_l)

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:415, in CachedProperty.__get__(self, instance, owner_class)
    413     value = instance._cache[self.name]
    414 except KeyError:
--> 415     value = self._get_default(instance)
    416     # Cache default value for next use.
    417     instance._update_cache(self.name, value, self._check_changes)

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:436, in CachedProperty._get_default(self, instance)
    434     raise AttributeError(err_msg)
    435 elif callable(self._default):
--> 436     value = self._default(self, instance)
    437 else:
    438     value = self._default

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:2161, in MultiStateSamplerAnalyzer._unbiased_decorrelated_u_ln(self, instance)
   2159 @_unbiased_decorrelated_u_ln.default
   2160 def _unbiased_decorrelated_u_ln(self, instance):
-> 2161     return instance._compute_mbar_unbiased_energies()[0]

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:1586, in MultiStateSamplerAnalyzer._compute_mbar_unbiased_energies(self)
   1584         unbias_restraint = False
   1585 if not unbias_restraint:
-> 1586     self._unbiased_decorrelated_u_ln = self._decorrelated_u_ln
   1587     self._unbiased_decorrelated_N_l = self._decorrelated_N_l
   1588     return self._unbiased_decorrelated_u_ln, self._unbiased_decorrelated_N_l

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:415, in CachedProperty.__get__(self, instance, owner_class)
    413     value = instance._cache[self.name]
    414 except KeyError:
--> 415     value = self._get_default(instance)
    416     # Cache default value for next use.
    417     instance._update_cache(self.name, value, self._check_changes)

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:436, in CachedProperty._get_default(self, instance)
    434     raise AttributeError(err_msg)
    435 elif callable(self._default):
--> 436     value = self._default(self, instance)
    437 else:
    438     value = self._default

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:2142, in MultiStateSamplerAnalyzer._decorrelated_u_ln(self, instance)
   2140 @_decorrelated_u_ln.default
   2141 def _decorrelated_u_ln(self, instance):
-> 2142     return instance._compute_mbar_decorrelated_energies()[0]

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/multistateanalyzer.py:1509, in MultiStateSamplerAnalyzer._compute_mbar_decorrelated_energies(self)
   1507         energies = multistate.utils.remove_unequilibrated_data(energies, number_equilibrated, -1)
   1508         # Subsample along the decorrelation data.
-> 1509         energy_data[i] = multistate.utils.subsample_data_along_axis(energies, g_t, -1)
   1510 sampled_energy_matrix, unsampled_energy_matrix, neighborhood, replicas_state_indices = energy_data
   1512 # Initialize the MBAR matrices in ln form.

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/openmmtools/multistate/utils.py:293, in subsample_data_along_axis(data, subsample_rate, axis)
    291 data_shape = cast_data.shape
    292 # Since we already have g, we can just pass any appropriate shape to the subsample function
--> 293 indices = subsample_correlated_data(np.zeros(data_shape[axis]), g=subsample_rate)
    294 subsampled_data = np.take(cast_data, indices, axis=axis)
    295 return subsampled_data

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/pymbar/timeseries.py:752, in subsample_correlated_data(A_t, g, fast, conservative, verbose)
    750 t = int(round(n * g))
    751 # ensure we don't sample the same point twice
--> 752 if (n == 0) or (t != indices[n - 1]):
    753     indices.append(t)
    754 n += 1

IndexError: list index out of range

The error is easy to reproduce:

import numpy as np
from openmmtools.multistate import utils

g_t = np.float32(1.0005244)
utils.subsample_correlated_data(np.zeros(9500), g=g_t)

Output:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[54], line 5
      2 from openmmtools.multistate import utils
      4 g_t = np.float32(1.0005244)
----> 5 utils.subsample_correlated_data(np.zeros(9500), g=g_t)

File ~/miniforge3/envs/openmm/lib/python3.10/site-packages/pymbar/timeseries.py:752, in subsample_correlated_data(A_t, g, fast, conservative, verbose)
    750 t = int(round(n * g))
    751 # ensure we don't sample the same point twice
--> 752 if (n == 0) or (t != indices[n - 1]):
    753     indices.append(t)
    754 n += 1

IndexError: list index out of range
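The failing check `t != indices[n - 1]` assumes that exactly one index has been appended per loop iteration. As soon as one iteration is skipped because `round(n * g)` repeats the previous index, `indices` is shorter than `n`, and a later lookup of `indices[n - 1]` runs off the end of the list. Below is a minimal sketch that reconstructs the loop from the traceback lines above (the `while` condition is my assumption, not copied from pymbar) and contrasts it with a hypothetical variant that compares against `indices[-1]`; it is not necessarily the change made in choderalab/pymbar#558. To make the duplicate deterministic it uses `g = 0.5`, which is not a realistic statistical inefficiency. In the reported case I assume the duplicate comes from rounding of `n * g` with the `np.float32` value of `g_t`, but I have not verified that.

```python
def subsample_indices_buggy(T, g):
    """Reconstruction of the non-conservative loop from the traceback (while condition assumed)."""
    indices = []
    n = 0
    while int(round(n * g)) < T:
        t = int(round(n * g))
        # Compares against position n - 1, which is only valid if every iteration appended.
        if (n == 0) or (t != indices[n - 1]):
            indices.append(t)
        n += 1
    return indices


def subsample_indices_fixed(T, g):
    """Hypothetical variant: compare against the last index actually stored."""
    indices = []
    n = 0
    while int(round(n * g)) < T:
        t = int(round(n * g))
        # indices[-1] is always the most recently stored index, even after skipped iterations.
        if not indices or t != indices[-1]:
            indices.append(t)
        n += 1
    return indices


if __name__ == "__main__":
    # g = 0.5 forces round(0.5) == round(0.0) == 0, i.e. a skipped iteration at n = 1.
    try:
        subsample_indices_buggy(10, 0.5)
    except IndexError as err:
        print("buggy loop:", err)  # prints: buggy loop: list index out of range
    print("fixed loop:", subsample_indices_fixed(10, 0.5))
```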

pymbar has already fixed this bug in choderalab/pymbar#558, but the fix is not yet part of a released version.

One possible fix would be to pass conservative=True when calling subsample_correlated_data() in subsample_data_along_axis() (openmmtools/multistate/utils.py):

indices = subsample_correlated_data(np.zeros(data_shape[axis]), g=subsample_rate)

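But that would substantially reduce the number of retained samples when the statistical inefficiency is just above 1, because the conservative method rounds g up to an integer stride (g_t = 1.0005 becomes a stride of 2, discarding about half of the data). An alternative would be to catch the exception and then either skip subsampling entirely when g_t ≈ 1, or fall back to the conservative method.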

Note that this also breaks running replica-exchange/multistate simulations when online analysis is performed during the run.

Setting analyzer.use_full_trajectory = True is not a suitable workaround either: it no longer discards the equilibration frames, and it does not fix the problem for running simulations.
