Skip to content

Fix crash in Transitions.occupancy if input sites are disordered #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions src/gemdat/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
from collections import defaultdict
from itertools import pairwise
from warnings import warn

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -67,6 +68,15 @@ def __init__(
inner_states : np.ndarray
Input states for inner sites
"""
if not (sites.is_ordered):
warn(
'Input `sites` are disordered! '
'Although the code may work, it was written under the assumption '
'that an ordered structure would be passed. '
'See https://github.com/GEMDAT-repos/GEMDAT/issues/339 for more information.',
stacklevel=2,
)

self.sites = sites
self.trajectory = trajectory
self.diff_trajectory = diff_trajectory
Expand Down Expand Up @@ -252,7 +262,10 @@ def occupancy(self) -> Structure:
counts = counts / len(states)
occupancies = dict(zip(unq, counts))

species = [{site.specie.name: occupancies.get(i, 0)} for i, site in enumerate(sites)]
species = [
{site.species.elements[0].name: occupancies.get(i, 0)}
for i, site in enumerate(sites)
]

return Structure(
lattice=sites.lattice,
Expand All @@ -262,27 +275,36 @@ def occupancy(self) -> Structure:
labels=sites.labels,
)

def atom_locations(self):
def occupancy_by_site_type(self) -> dict[str, float]:
"""Calculate average occupancy per a type of site.

Returns
-------
occupancy : dict[str, float]
Return dict with average occupancy per site type
"""
compositions_by_label = defaultdict(list)

for site in self.occupancy():
compositions_by_label[site.label].append(site.species.num_atoms)

return {k: sum(v) / len(v) for k, v in compositions_by_label.items()}

def atom_locations(self) -> dict[str, float]:
"""Calculate fraction of time atoms spent at a type of site.

Returns
-------
dict[str, float]
Return dict with the fraction of time atoms spent at a site
"""
multiplier = len(self.sites) / self.n_floating

n = self.n_floating
compositions_by_label = defaultdict(list)

for site in self.occupancy():
compositions_by_label[site.label].append(site.species.num_atoms)

ret = {}

for k, v in compositions_by_label.items():
ret[k] = (sum(v) / len(v)) * multiplier

return ret
return {k: sum(v) / n for k, v in compositions_by_label.items()}

def split(self, n_parts: int = 10) -> list[Transitions]:
"""Split data into equal parts in time for statistics.
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/transitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def test_occupancy_parts(self, vasp_transitions):
35.43733333333334,
]

def test_occupancy_by_site_type(self, vasp_transitions):
occ = vasp_transitions.occupancy_by_site_type()
assert occ == {'48h': 0.3806277777777776}

def test_atom_locations(self, vasp_transitions):
dct = vasp_transitions.atom_locations()
assert dct == {'48h': 0.7612555555555552}
Expand Down