Skip to content

Update captions for plots #308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import numpy as np
from typing import TYPE_CHECKING
from collections import defaultdict
if TYPE_CHECKING:
from gemdat.trajectory import Trajectory


def _mean_displacements_per_element(
trajectory: Trajectory) -> dict[str, tuple[np.ndarray, np.ndarray]]:
"""Calculate mean displacements per element type.

Helper for `displacement_per_atom`.
"""
species = trajectory.species

grouped = defaultdict(list)
for sp, distances in zip(species,
trajectory.distances_from_base_position()):
grouped[sp.symbol].append(distances)

means = {}
for sp, dists in grouped.items():
mean = np.mean(dists, axis=0)
std = np.std(dists, axis=0)

means[sp] = (mean, std)

return means
49 changes: 22 additions & 27 deletions src/gemdat/plots/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,29 @@
matplotlib."""
from __future__ import annotations

from ._displacements import (
displacement_histogram,
displacement_per_atom,
displacement_per_element,
msd_per_element,
)
from ._jumps import (
collective_jumps,
jumps_3d,
jumps_3d_animation,
jumps_vs_distance,
jumps_vs_time,
)
from ._paths import (
energy_along_path,
path_on_grid,
)
from ._rdf import radial_distribution
from ._orientations import (
autocorrelation,
bond_length_distribution,
rectilinear,
)
from ._displacement_histogram import displacement_histogram
from ._displacement_per_atom import displacement_per_atom
from ._displacement_per_element import displacement_per_element
from ._msd_per_element import msd_per_element

from ._collective_jumps import collective_jumps
from ._jumps_3d import jumps_3d
from ._jumps_3d_animation import jumps_3d_animation
from ._jumps_vs_distance import jumps_vs_distance
from ._jumps_vs_time import jumps_vs_time

from ._energy_along_path import energy_along_path
from ._path_on_grid import path_on_grid
from ._radial_distribution import radial_distribution

from ._autocorrelation import autocorrelation
from ._bond_length_distribution import bond_length_distribution
from ._rectilinear import rectilinear

from ._shape import shape
from ._vibration import (
frequency_vs_occurence,
vibrational_amplitudes,
)

from ._frequency_vs_occurence import frequency_vs_occurence
from ._vibrational_amplitudes import vibrational_amplitudes

__all__ = [
'bond_length_distribution',
Expand Down
57 changes: 57 additions & 0 deletions src/gemdat/plots/matplotlib/_autocorrelation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np

from gemdat.orientations import (
Orientations, )


def autocorrelation(
*,
orientations: Orientations,
show_traces: bool = True,
) -> plt.Figure:
"""Plot the autocorrelation function of the unit vectors series.

Parameters
----------
orientations : Orientations
The unit vector trajectories
show_traces : bool
If True, show traces of individual trajectories

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
ac = orientations.autocorrelation()
ac_std = ac.std(axis=0)
ac_mean = ac.mean(axis=0)

# Since we want to plot in picosecond, we convert the time units
time_ps = orientations._time_step * 1e12
tgrid = np.arange(ac_mean.shape[0]) * time_ps

fig, ax = plt.subplots()

ax.plot(tgrid, ac_mean, label='FFT autocorrelation')

last_color = ax.lines[-1].get_color()

if show_traces:
for i, ac_i in enumerate(ac):
label = 'Trajectories' if (i == 0) else None
ax.plot(tgrid, ac_i, lw=0.1, c=last_color, label=label)

ax.fill_between(tgrid,
ac_mean - ac_std,
ac_mean + ac_std,
alpha=0.2,
label='Standard deviation')
ax.set_xlabel('Time lag (ps)')
ax.set_ylabel('Autocorrelation')
ax.legend()

return fig
65 changes: 65 additions & 0 deletions src/gemdat/plots/matplotlib/_bond_length_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import skewnorm

from gemdat.orientations import Orientations


def bond_length_distribution(*,
orientations: Orientations,
bins: int = 1000) -> plt.Figure:
"""Plot the bond length probability distribution.

Parameters
----------
orientations : Orientations
The unit vector trajectories
bins : int, optional
The number of bins, by default 1000

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
*_, bond_lengths = orientations.vectors_spherical.T
bond_lengths = bond_lengths.flatten()

fig, ax = plt.subplots()

# Plot the normalized histogram
hist, edges = np.histogram(bond_lengths, bins=bins, density=True)
bin_centers = (edges[:-1] + edges[1:]) / 2

# Fit a skewed Gaussian distribution to the orientations
params, covariance = curve_fit(skewnorm.pdf,
bin_centers,
hist,
p0=[1.5, 1, 1.5])

# Create a new function using the fitted parameters
def _skewnorm_fit(x):
return skewnorm.pdf(x, *params)

# Plot the histogram
ax.hist(bond_lengths,
bins=bins,
density=True,
color='blue',
alpha=0.7,
label='Data')

# Plot the fitted skewed Gaussian distribution
x_fit = np.linspace(min(bin_centers), max(bin_centers), 1000)
ax.plot(x_fit, _skewnorm_fit(x_fit), 'r-', label='Skewed Gaussian Fit')

ax.set_xlabel('Bond length (Å)')
ax.set_ylabel(r'Probability density (Å$^{-1}$)')
ax.set_title('Bond Length Probability Distribution')
ax.legend()
ax.grid(True)

return fig
42 changes: 42 additions & 0 deletions src/gemdat/plots/matplotlib/_collective_jumps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt

if TYPE_CHECKING:
from gemdat import Jumps


def collective_jumps(*, jumps: Jumps) -> plt.Figure:
"""Plot collective jumps per jump-type combination.

Parameters
----------
jumps : Jumps
Input data

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
fig, ax = plt.subplots()

collective = jumps.collective()

matrix = collective.site_pair_count_matrix()

im = ax.imshow(matrix)

labels = collective.site_pair_count_matrix_labels()
ticks = range(len(labels))

ax.set_xticks(ticks, labels=labels, rotation=90)
ax.set_yticks(ticks, labels=labels)

fig.colorbar(im, ax=ax)

ax.set(title='Cooperative jumps per jump-type combination')

return fig
27 changes: 27 additions & 0 deletions src/gemdat/plots/matplotlib/_displacement_histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory


def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
"""Plot histogram of total displacement at final timestep.

Parameters
----------
trajectory : Trajectory
Input trajectory, i.e. for the diffusing atom

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
fig, ax = plt.subplots()
ax.hist(trajectory.distances_from_base_position()[:, -1])
ax.set(title='Displacement per element',
xlabel='Displacement (Å)',
ylabel='Nr. of atoms')

return fig
30 changes: 30 additions & 0 deletions src/gemdat/plots/matplotlib/_displacement_per_atom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory


def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
"""Plot displacement per atom.

Parameters
----------
trajectory : Trajectory
Input trajectory, i.e. for the diffusing atom

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
fig, ax = plt.subplots()

for distances in trajectory.distances_from_base_position():
ax.plot(distances, lw=0.3)

ax.set(title='Displacement per site',
xlabel='Time step',
ylabel='Displacement (Å)')

return fig
35 changes: 35 additions & 0 deletions src/gemdat/plots/matplotlib/_displacement_per_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import matplotlib.pyplot as plt

from gemdat.trajectory import Trajectory

from gemdat.plots._shared import _mean_displacements_per_element


def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
"""Plot displacement per element.

Parameters
----------
trajectory : Trajectory
Input trajectory

Returns
-------
fig : matplotlib.figure.Figure
Output figure
"""
displacements = _mean_displacements_per_element(trajectory)

fig, ax = plt.subplots()

for symbol, (mean, _) in displacements.items():
ax.plot(mean, lw=0.3, label=symbol)

ax.legend()
ax.set(title='Displacement per element',
xlabel='Time step',
ylabel='Displacement (Å)')

return fig
Loading
Loading