diff --git a/src/gemdat/transitions.py b/src/gemdat/transitions.py index a68b044..a560986 100644 --- a/src/gemdat/transitions.py +++ b/src/gemdat/transitions.py @@ -504,17 +504,6 @@ def _calculate_atom_states( The value corresponds to the index in the `site_coords`. -1 indicates that atom is not at any site. """ - - def _site_radius_iterator(): - for label, radius in site_radius.items(): - if label: - grouped = ((k, site) for k, site in enumerate(sites) if site.label == label) - key, site_group = zip(*grouped) - frac_coords = np.array([site.frac_coords for site in site_group]) - yield frac_coords, np.array(key), radius - else: - yield sites.frac_coords, None, radius - lattice = trajectory.get_lattice() cutoff = max(list(site_radius.values())) @@ -531,10 +520,23 @@ def _site_radius_iterator(): atom_sites = np.full((traj_cart_coords.shape[0]), NOSITE) - for coords, key, radius in _site_radius_iterator(): - cart_coords = lattice.get_cartesian_coords(coords) + for label, radius in site_radius.items(): + if label: + grouped = ((k, site) for k, site in enumerate(sites) if site.label == label) + key, site_group = zip(*grouped) + frac_coords = np.array([site.frac_coords for site in site_group]) + key = np.array(key) + else: + frac_coords = sites.frac_coords + key = None + + cart_coords = lattice.get_cartesian_coords(frac_coords) site_index = periodic_tree.search_tree(cart_coords, radius * site_inner_fraction) + if site_index.size == 0: + warn(f'No floating species in range of {label} ({radius=})', stacklevel=2) + continue + siteno, index = site_index.T if key is not None: