Skip to content

Commit c1888d8

Browse files
authored
Refactor radial_distribution_between_species() (#340)
* Refactor radial_distribution_between_species() * Add shortcut to radial distribution on trajectory * Fix test fail
1 parent ea6b841 commit c1888d8

15 files changed

+102
-199
lines changed

src/gemdat/plots/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
msd_per_element,
2222
plot_3d,
2323
radial_distribution,
24-
radial_distribution_between_species,
2524
rectilinear,
2625
shape,
2726
vibrational_amplitudes,
@@ -45,7 +44,6 @@
4544
'msd_per_element',
4645
'plot_3d',
4746
'radial_distribution',
48-
'radial_distribution_between_species',
4947
'rectilinear',
5048
'shape',
5149
'vibrational_amplitudes',

src/gemdat/plots/_shared.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from scipy.stats import skewnorm
1111

1212
if TYPE_CHECKING:
13-
from typing import Collection
14-
1513
from gemdat.orientations import Orientations
1614
from gemdat.trajectory import Trajectory
1715

@@ -149,45 +147,3 @@ def _get_vibrational_amplitudes_hist(
149147
std = np.std(data, axis=0)
150148

151149
return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std)
152-
153-
154-
def _get_radial_distribution_between_species(
155-
*,
156-
trajectory: Trajectory,
157-
specie_1: str | Collection[str],
158-
specie_2: str | Collection[str],
159-
max_dist: float = 5.0,
160-
resolution: float = 0.1,
161-
) -> tuple[np.ndarray, np.ndarray]:
162-
coords_1 = trajectory.filter(specie_1).coords
163-
coords_2 = trajectory.filter(specie_2).coords
164-
lattice = trajectory.get_lattice()
165-
166-
if coords_2.ndim == 2:
167-
num_time_steps = 1
168-
num_atoms, num_dimensions = coords_2.shape
169-
else:
170-
num_time_steps, num_atoms, num_dimensions = coords_2.shape
171-
172-
particle_vol = num_atoms / lattice.volume
173-
174-
all_dists = np.concatenate(
175-
[
176-
lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
177-
for t in range(num_time_steps)
178-
]
179-
)
180-
distances = all_dists.flatten()
181-
182-
bins = np.arange(0, max_dist + resolution, resolution)
183-
rdf, _ = np.histogram(distances, bins=bins, density=False)
184-
185-
def normalize(radius: np.ndarray) -> np.ndarray:
186-
"""Normalize bin to volume."""
187-
shell = (radius + resolution) ** 3 - radius**3
188-
return particle_vol * (4 / 3) * np.pi * shell
189-
190-
norm = normalize(bins)[:-1]
191-
rdf = rdf / norm
192-
193-
return bins, rdf

src/gemdat/plots/matplotlib/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ._jumps_vs_time import jumps_vs_time
1818
from ._msd_per_element import msd_per_element
1919
from ._radial_distribution import radial_distribution
20-
from ._radial_distribution_between_species import radial_distribution_between_species
2120
from ._rectilinear import rectilinear
2221
from ._shape import shape
2322
from ._vibrational_amplitudes import vibrational_amplitudes
@@ -37,7 +36,6 @@
3736
'jumps_vs_time',
3837
'msd_per_element',
3938
'radial_distribution',
40-
'radial_distribution_between_species',
4139
'rectilinear',
4240
'shape',
4341
'vibrational_amplitudes',

src/gemdat/plots/matplotlib/_radial_distribution.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> matplotlib.figure.Figure:
2626
fig, ax = plt.subplots()
2727

2828
for rdf in rdfs:
29-
ax.plot(rdf.x, rdf.y, label=rdf.symbol)
29+
ax.plot(rdf.x, rdf.y, label=rdf.label)
3030

31-
states = ', '.join({rdf.state for rdf in rdfs})
31+
states = ', '.join({rdf.state for rdf in rdfs if rdf.state})
32+
state_suffix = f' ({states})' if states else ''
3233

3334
ax.legend()
3435
ax.set(
35-
title=f'Radial distribution function ({states})',
36+
title=f'Radial distribution function{state_suffix}',
3637
xlabel='Distance (Å)',
3738
ylabel='Counts',
3839
)

src/gemdat/plots/matplotlib/_radial_distribution_between_species.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

src/gemdat/plots/plotly/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ._msd_per_element import msd_per_element
1818
from ._plot3d import plot_3d
1919
from ._radial_distribution import radial_distribution
20-
from ._radial_distribution_between_species import radial_distribution_between_species
2120
from ._rectilinear import rectilinear
2221
from ._shape import shape
2322
from ._vibrational_amplitudes import vibrational_amplitudes
@@ -38,7 +37,6 @@
3837
'msd_per_element',
3938
'plot_3d',
4039
'radial_distribution',
41-
'radial_distribution_between_species',
4240
'rectilinear',
4341
'shape',
4442
'vibrational_amplitudes',

src/gemdat/plots/plotly/_radial_distribution.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> go.Figure:
2828
go.Scatter(
2929
x=rdf.x,
3030
y=rdf.y,
31-
name=rdf.symbol,
31+
name=rdf.label,
3232
mode='lines',
3333
# line={'width': 0.25}
3434
)
3535
)
3636

37-
states = ', '.join({rdf.state for rdf in rdfs})
37+
states = ', '.join({rdf.state for rdf in rdfs if rdf.state})
38+
state_suffix = f' ({states})' if states else ''
39+
3840
fig.update_layout(
39-
title=f'Radial distribution function ({states})',
41+
title=f'Radial distribution function{state_suffix}',
4042
xaxis_title='Distance (Å)',
4143
yaxis_title='Counts',
4244
)

src/gemdat/plots/plotly/_radial_distribution_between_species.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

src/gemdat/rdf.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
from ._plot_backend import plot_backend
1111

1212
if TYPE_CHECKING:
13+
from typing import Collection
14+
1315
from pymatgen.core import Structure
1416

17+
from gemdat import Trajectory
1518
from gemdat.transitions import Transitions
1619

1720

@@ -79,16 +82,16 @@ class RDFData:
7982
1D array with x data (bins)
8083
y : np.ndarray
8184
1D array with y data (counts)
82-
symbol : str
83-
Distance to species with this symbol
85+
label : str
86+
Distance to species with this symbol label
8487
state : str
8588
State that the floating species is in, e.g.
8689
the jump that it is making.
8790
"""
8891

8992
x: np.ndarray
9093
y: np.ndarray
91-
symbol: str
94+
label: str
9295
state: str
9396

9497
@plot_backend
@@ -178,10 +181,75 @@ def radial_distribution(
178181
x=bins,
179182
# Drop last element with distance > max_dist
180183
y=values[:-1],
181-
symbol=symbol,
184+
label=symbol,
182185
state=state,
183186
)
184187
ret.setdefault(state, RDFCollection())
185188
ret[state].append(rdf_data)
186189

187190
return ret
191+
192+
193+
def radial_distribution_between_species(
194+
*,
195+
trajectory: Trajectory,
196+
specie_1: str | Collection[str],
197+
specie_2: str | Collection[str],
198+
max_dist: float = 5.0,
199+
resolution: float = 0.1,
200+
) -> RDFData:
201+
"""Calculate RDFs from specie_1 to specie_2.
202+
203+
Parameters
204+
----------
205+
trajectory: Trajectory
206+
Input trajectory.
207+
specie_1: str | list[str]
208+
Name of specie or list of species
209+
specie_2: str | list[str]
210+
Name of specie or list of species
211+
max_dist: float, optional
212+
Max distance for rdf calculation
213+
resolution: float, optional
214+
Width of the bins
215+
216+
Returns
217+
-------
218+
rdf : RDFData
219+
RDF data for the given species.
220+
"""
221+
coords_1 = trajectory.filter(specie_1).coords
222+
coords_2 = trajectory.filter(specie_2).coords
223+
lattice = trajectory.get_lattice()
224+
225+
if coords_2.ndim == 2:
226+
num_time_steps = 1
227+
num_atoms, num_dimensions = coords_2.shape
228+
else:
229+
num_time_steps, num_atoms, num_dimensions = coords_2.shape
230+
231+
particle_vol = num_atoms / lattice.volume
232+
233+
all_dists = np.concatenate(
234+
[
235+
lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :])
236+
for t in range(num_time_steps)
237+
]
238+
)
239+
distances = all_dists.flatten()
240+
241+
bins = np.arange(0, max_dist + resolution, resolution)
242+
rdf, _ = np.histogram(distances, bins=bins, density=False)
243+
244+
def normalize(radius: np.ndarray) -> np.ndarray:
245+
"""Normalize bin to volume."""
246+
shell = (radius + resolution) ** 3 - radius**3
247+
return particle_vol * (4 / 3) * np.pi * shell
248+
249+
norm = normalize(bins)[:-1]
250+
counts = rdf / norm
251+
252+
str1 = specie_1 if isinstance(specie_1, str) else '/'.join(specie_1)
253+
str2 = specie_1 if isinstance(specie_2, str) else '/'.join(specie_2)
254+
255+
return RDFData(x=bins[:-1], y=counts, label=f'{str1}-{str2}', state='')

0 commit comments

Comments
 (0)