diff --git a/src/gemdat/jumps.py b/src/gemdat/jumps.py index 81b0994..4f1837f 100644 --- a/src/gemdat/jumps.py +++ b/src/gemdat/jumps.py @@ -38,60 +38,52 @@ def _generic_transitions_to_jumps( events = events.rename(columns={'time': 'start time'}) jumps = [] - fromevent = None - candidate_jump = None - - for _, event in events.iterrows(): - # If we are jumping, but we go to the next atom index, reset - if fromevent is not None: - if fromevent['atom index'] != event['atom index']: - fromevent = None - - # If we have a candidate jump, we must make sure it remains on the site - # for minimal_residence timesteps, this is that check, add it to the jumps - # if it passes - if candidate_jump is not None: - if candidate_jump['atom index'] != event['atom index']: - jumps.append(candidate_jump) - candidate_jump = None - elif event['start time'] - candidate_jump['stop time'] >= minimal_residence: - jumps.append(candidate_jump) - candidate_jump = None - fromevent = None - elif candidate_jump['destination site'] != event['destination site']: - candidate_jump = None - - # Specify the start of a jump if we encounter one - if event['start site'] != -1: - if event['start site'] != event['destination site']: - fromevent = event - - if fromevent is not None: - # if we jump back, remove fromevent - if fromevent['start site'] == event['destination site']: - fromevent = None - candidate_jump = None - continue - - # Check if jump to the inner site, add it to the jumps immediately - if event['destination inner site'] != -1: - event['start site'] = fromevent['start site'] - event['start time'] = fromevent['start time'] - fromevent = None - candidate_jump = None - jumps.append(event) - continue - - # If we enter another site, create a candidate jump - if candidate_jump is None: - if event['destination site'] != -1: + + atom_events = [atomevents for i, atomevents in events.groupby('atom index')] + + for events in atom_events: + fromevent = None + candidate_jump = None + + for _, event in events.iterrows(): + # We have a previous jump, but must still determine + # if it stays on the site long enough + if candidate_jump is not None: + if event['start time'] - candidate_jump['start time'] >= minimal_residence: + jumps.append(candidate_jump) + candidate_jump = None + # it moves to early! dont add the jump + elif candidate_jump['destination site'] != event['destination site']: + candidate_jump = None + + # Identify a transition away from a site, name it fromevent + if event['start site'] != -1: + if event['start site'] != event['destination site']: + fromevent = event + + if fromevent is not None: + # jump to same site is no jump + if event['destination site'] == fromevent['start site']: + fromevent = None + candidate_jump = None + + # jump to another inner site is definitely jump + elif event['destination inner site'] != -1: + event['start site'] = fromevent['start site'] + event['start time'] = fromevent['start time'] + jumps.append(event) + fromevent = None + candidate_jump = None + + # jump to another site is a candidate_event + elif event['destination site'] != fromevent['destination site']: event['start site'] = fromevent['start site'] event['start time'] = fromevent['start time'] candidate_jump = event + fromevent = None - # Also add a last candidate jump (if there is one - if candidate_jump is not None: - jumps.append(candidate_jump) + if len(jumps) == 0: + raise ValueError('No jumps found') jumps = pd.DataFrame(data=jumps) @@ -134,6 +126,7 @@ def __init__( self.sites = transitions.sites self.conversion_method = conversion_method self.data = conversion_method(transitions, minimal_residence=minimal_residence) + self.minimal_residence = minimal_residence @property def n_jumps(self) -> int: @@ -408,7 +401,14 @@ def split(self, n_parts: int) -> list[Jumps]: """ parts = self.transitions.split(n_parts) - return [Jumps(part, conversion_method=self.conversion_method) for part in parts] + return [ + Jumps( + part, + conversion_method=self.conversion_method, + minimal_residence=self.minimal_residence, + ) + for part in parts + ] @weak_lru_cache() def rates(self, n_parts: int = 10) -> pd.DataFrame: diff --git a/tests/integration/jumps_test.py b/tests/integration/jumps_test.py index 1713be6..b09e0ff 100644 --- a/tests/integration/jumps_test.py +++ b/tests/integration/jumps_test.py @@ -20,14 +20,14 @@ def test_site_inner_fraction(self, vasp_traj, structure): ) jumps = Jumps(transitions=transitions, minimal_residence=100) - assert len(jumps.data) == 267 + assert len(jumps.data) == 258 assert np.all( jumps.data[::100].to_numpy() == np.array( [ - [0, 94, 0, 282, 303], - [15, 74, 8, 1271, 1286], - [34, 49, 33, 3141, 3296], + [0, 0, 94, 385, 442], + [16, 34, 42, 3268, 3298], + [36, 19, 11, 1763, 1789], ] ) )