diff --git a/src/gemdat/plots/__init__.py b/src/gemdat/plots/__init__.py index 5374af0..edb86f9 100644 --- a/src/gemdat/plots/__init__.py +++ b/src/gemdat/plots/__init__.py @@ -21,6 +21,7 @@ msd_per_element, plot_3d, radial_distribution, + radial_distribution_between_species, rectilinear, shape, vibrational_amplitudes, @@ -44,6 +45,7 @@ 'msd_per_element', 'plot_3d', 'radial_distribution', + 'radial_distribution_between_species', 'rectilinear', 'shape', 'vibrational_amplitudes', diff --git a/src/gemdat/plots/_shared.py b/src/gemdat/plots/_shared.py index ec3f427..f97e02b 100644 --- a/src/gemdat/plots/_shared.py +++ b/src/gemdat/plots/_shared.py @@ -10,6 +10,8 @@ from scipy.stats import skewnorm if TYPE_CHECKING: + from typing import Collection + from gemdat.orientations import Orientations from gemdat.trajectory import Trajectory @@ -147,3 +149,45 @@ def _get_vibrational_amplitudes_hist( std = np.std(data, axis=0) return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std) + + +def _get_radial_distribution_between_species( + *, + trajectory: Trajectory, + specie_1: str | Collection[str], + specie_2: str | Collection[str], + max_dist: float = 5.0, + resolution: float = 0.1, +) -> tuple[np.ndarray, np.ndarray]: + coords_1 = trajectory.filter(specie_1).coords + coords_2 = trajectory.filter(specie_2).coords + lattice = trajectory.get_lattice() + + if coords_2.ndim == 2: + num_time_steps = 1 + num_atoms, num_dimensions = coords_2.shape + else: + num_time_steps, num_atoms, num_dimensions = coords_2.shape + + particle_vol = num_atoms / lattice.volume + + all_dists = np.concatenate( + [ + lattice.get_all_distances(coords_1[t, :, :], coords_2[t, :, :]) + for t in range(num_time_steps) + ] + ) + distances = all_dists.flatten() + + bins = np.arange(0, max_dist + resolution, resolution) + rdf, _ = np.histogram(distances, bins=bins, density=False) + + def normalize(radius: np.ndarray) -> np.ndarray: + """Normalize bin to volume.""" + shell = (radius + resolution) ** 3 - radius**3 + return particle_vol * (4 / 3) * np.pi * shell + + norm = normalize(bins)[:-1] + rdf = rdf / norm + + return bins, rdf diff --git a/src/gemdat/plots/matplotlib/__init__.py b/src/gemdat/plots/matplotlib/__init__.py index d364361..f66e507 100644 --- a/src/gemdat/plots/matplotlib/__init__.py +++ b/src/gemdat/plots/matplotlib/__init__.py @@ -17,6 +17,7 @@ from ._jumps_vs_time import jumps_vs_time from ._msd_per_element import msd_per_element from ._radial_distribution import radial_distribution +from ._radial_distribution_between_species import radial_distribution_between_species from ._rectilinear import rectilinear from ._shape import shape from ._vibrational_amplitudes import vibrational_amplitudes @@ -36,6 +37,7 @@ 'jumps_vs_time', 'msd_per_element', 'radial_distribution', + 'radial_distribution_between_species', 'rectilinear', 'shape', 'vibrational_amplitudes', diff --git a/src/gemdat/plots/matplotlib/_radial_distribution_between_species.py b/src/gemdat/plots/matplotlib/_radial_distribution_between_species.py new file mode 100644 index 0000000..ea6b220 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_radial_distribution_between_species.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt + +from .._shared import _get_radial_distribution_between_species + +if TYPE_CHECKING: + from typing import Collection + + import matplotlib.figure + + from gemdat import Trajectory + + +def radial_distribution_between_species( + *, + trajectory: Trajectory, + specie_1: str | Collection[str], + specie_2: str | Collection[str], + max_dist: float = 5.0, + resolution: float = 0.1, +) -> matplotlib.figure.Figure: + """Calculate RDFs from specie_1 to specie_2. + + Parameters + ---------- + trajectory: Trajectory + Input trajectory. + specie_1: str | list[str] + Name of specie or list of species + specie_2: str | list[str] + Name of specie or list of species + max_dist: float, optional + Max distance for rdf calculation + resolution: float, optional + Width of the bins + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + bins, rdf = _get_radial_distribution_between_species( + trajectory=trajectory, + specie_1=specie_1, + specie_2=specie_2, + max_dist=max_dist, + resolution=resolution, + ) + + fig, ax = plt.subplots() + ax.plot(bins[:-1], rdf) + str1 = specie_1 if isinstance(specie_1, str) else ' / '.join(specie_1) + str2 = specie_1 if isinstance(specie_2, str) else ' / '.join(specie_2) + ax.set( + title=f'RDF between {str1} and {str2}', + xlabel='Radius (Å)', + ylabel='Nr. of atoms', + ) + return fig diff --git a/src/gemdat/plots/plotly/__init__.py b/src/gemdat/plots/plotly/__init__.py index 6cafdf2..afc852f 100644 --- a/src/gemdat/plots/plotly/__init__.py +++ b/src/gemdat/plots/plotly/__init__.py @@ -17,6 +17,7 @@ from ._msd_per_element import msd_per_element from ._plot3d import plot_3d from ._radial_distribution import radial_distribution +from ._radial_distribution_between_species import radial_distribution_between_species from ._rectilinear import rectilinear from ._shape import shape from ._vibrational_amplitudes import vibrational_amplitudes @@ -32,12 +33,12 @@ 'energy_along_path', 'frequency_vs_occurence', 'jumps_3d', - 'jumps_3d_animation', 'jumps_vs_distance', 'jumps_vs_time', 'msd_per_element', 'plot_3d', 'radial_distribution', + 'radial_distribution_between_species', 'rectilinear', 'shape', 'vibrational_amplitudes', diff --git a/src/gemdat/plots/plotly/_radial_distribution_between_species.py b/src/gemdat/plots/plotly/_radial_distribution_between_species.py new file mode 100644 index 0000000..f4f6816 --- /dev/null +++ b/src/gemdat/plots/plotly/_radial_distribution_between_species.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import plotly.graph_objects as go + +from .._shared import _get_radial_distribution_between_species + +if TYPE_CHECKING: + from typing import Collection + + from gemdat import Trajectory + + +def radial_distribution_between_species( + trajectory: Trajectory, + specie_1: str | Collection[str], + specie_2: str | Collection[str], + max_dist: float = 5.0, + resolution: float = 0.1, +) -> go.Figure: + """Calculate RDFs from specie_1 to specie_2. + + Parameters + ---------- + trajectory: Trajectory + Input trajectory. + specie_1: str | list[str] + Name of specie or list of species + specie_2: str | list[str] + Name of specie or list of species + max_dist: float, optional + Max distance for rdf calculation + resolution: float, optional + Width of the bins + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + bins, rdf = _get_radial_distribution_between_species( + trajectory=trajectory, + specie_1=specie_1, + specie_2=specie_2, + max_dist=max_dist, + resolution=resolution, + ) + + fig = go.Figure() + fig.add_trace( + go.Scatter( + x=bins, + y=rdf, + name='Radial distribution', + mode='lines', + ) + ) + str1 = specie_1 if isinstance(specie_1, str) else ' / '.join(specie_1) + str2 = specie_1 if isinstance(specie_2, str) else ' / '.join(specie_2) + fig.update_layout( + title=f'RDF between {str1} and {str2}', + xaxis_title='Radius (Å)', + yaxis_title='Nr. of atoms', + ) + return fig diff --git a/src/gemdat/trajectory.py b/src/gemdat/trajectory.py index 5d6cf41..71ba5fd 100644 --- a/src/gemdat/trajectory.py +++ b/src/gemdat/trajectory.py @@ -758,6 +758,12 @@ def plot_frequency_vs_occurence(self, *, module, **kwargs): """See [gemdat.plots.frequency_vs_occurence][] for more info.""" return module.frequency_vs_occurence(trajectory=self, **kwargs) + @plot_backend + def plot_radial_distribution_between_species(self, *, module, **kwargs): + """See [gemdat.plots.radial_distribution_between_species][] for more + info.""" + return module.radial_distribution_between_species(trajectory=self, **kwargs) + @plot_backend def plot_vibrational_amplitudes(self, *, module, **kwargs): """See [gemdat.plots.vibrational_amplitudes][] for more info.""" diff --git a/tests/integration/baseline_images/plot_mpl_test/radial_distribution_between_species.png b/tests/integration/baseline_images/plot_mpl_test/radial_distribution_between_species.png new file mode 100644 index 0000000..c177e65 Binary files /dev/null and b/tests/integration/baseline_images/plot_mpl_test/radial_distribution_between_species.png differ diff --git a/tests/integration/baseline_images/plot_plotly_test/radial_distribution_between_species.png b/tests/integration/baseline_images/plot_plotly_test/radial_distribution_between_species.png new file mode 100644 index 0000000..6b4c698 Binary files /dev/null and b/tests/integration/baseline_images/plot_plotly_test/radial_distribution_between_species.png differ diff --git a/tests/integration/plot_mpl_test.py b/tests/integration/plot_mpl_test.py index d3caf61..5eee168 100644 --- a/tests/integration/plot_mpl_test.py +++ b/tests/integration/plot_mpl_test.py @@ -75,6 +75,14 @@ def test_radial_distribution(vasp_rdf_data): rdfs.plot(backend=BACKEND) +@image_comparison2(baseline_images=['radial_distribution_between_species']) +def test_radial_distribution_between_species(vasp_traj): + traj = vasp_traj[-500:] + traj.plot_radial_distribution_between_species( + backend=BACKEND, specie_1='Li', specie_2=('S', 'CL') + ) + + @image_comparison2(baseline_images=['shape']) def test_shape(vasp_shape_data): assert len(vasp_shape_data) == 1 diff --git a/tests/integration/plot_plotly_test.py b/tests/integration/plot_plotly_test.py index a37d277..64723b7 100644 --- a/tests/integration/plot_plotly_test.py +++ b/tests/integration/plot_plotly_test.py @@ -74,6 +74,17 @@ def test_radial_distribution(vasp_rdf_data): assert_figures_similar(fig, name=f'radial_distribution_{i}', rms=0.5) +def test_radial_distribution_between_species(vasp_traj): + traj = vasp_traj[-500:] + fig = traj.plot_radial_distribution_between_species( + backend=BACKEND, + specie_1='Li', + specie_2=('S', 'CL'), + ) + + assert_figures_similar(fig, name='radial_distribution_between_species', rms=0.5) + + def test_shape(vasp_shape_data): assert len(vasp_shape_data) == 1 for shape in vasp_shape_data: