diff --git a/src/gemdat/plots/_shared.py b/src/gemdat/plots/_shared.py
new file mode 100644
index 00000000..4da40845
--- /dev/null
+++ b/src/gemdat/plots/_shared.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import numpy as np
+from typing import TYPE_CHECKING
+from collections import defaultdict
+if TYPE_CHECKING:
+ from gemdat.trajectory import Trajectory
+
+
+def _mean_displacements_per_element(
+ trajectory: Trajectory) -> dict[str, tuple[np.ndarray, np.ndarray]]:
+ """Calculate mean displacements per element type.
+
+ Helper for `displacement_per_atom`.
+ """
+ species = trajectory.species
+
+ grouped = defaultdict(list)
+ for sp, distances in zip(species,
+ trajectory.distances_from_base_position()):
+ grouped[sp.symbol].append(distances)
+
+ means = {}
+ for sp, dists in grouped.items():
+ mean = np.mean(dists, axis=0)
+ std = np.std(dists, axis=0)
+
+ means[sp] = (mean, std)
+
+ return means
diff --git a/src/gemdat/plots/matplotlib/__init__.py b/src/gemdat/plots/matplotlib/__init__.py
index cc90af17..ba17f569 100644
--- a/src/gemdat/plots/matplotlib/__init__.py
+++ b/src/gemdat/plots/matplotlib/__init__.py
@@ -2,34 +2,29 @@
matplotlib."""
from __future__ import annotations
-from ._displacements import (
- displacement_histogram,
- displacement_per_atom,
- displacement_per_element,
- msd_per_element,
-)
-from ._jumps import (
- collective_jumps,
- jumps_3d,
- jumps_3d_animation,
- jumps_vs_distance,
- jumps_vs_time,
-)
-from ._paths import (
- energy_along_path,
- path_on_grid,
-)
-from ._rdf import radial_distribution
-from ._orientations import (
- autocorrelation,
- bond_length_distribution,
- rectilinear,
-)
+from ._displacement_histogram import displacement_histogram
+from ._displacement_per_atom import displacement_per_atom
+from ._displacement_per_element import displacement_per_element
+from ._msd_per_element import msd_per_element
+
+from ._collective_jumps import collective_jumps
+from ._jumps_3d import jumps_3d
+from ._jumps_3d_animation import jumps_3d_animation
+from ._jumps_vs_distance import jumps_vs_distance
+from ._jumps_vs_time import jumps_vs_time
+
+from ._energy_along_path import energy_along_path
+from ._path_on_grid import path_on_grid
+from ._radial_distribution import radial_distribution
+
+from ._autocorrelation import autocorrelation
+from ._bond_length_distribution import bond_length_distribution
+from ._rectilinear import rectilinear
+
from ._shape import shape
-from ._vibration import (
- frequency_vs_occurence,
- vibrational_amplitudes,
-)
+
+from ._frequency_vs_occurence import frequency_vs_occurence
+from ._vibrational_amplitudes import vibrational_amplitudes
__all__ = [
'bond_length_distribution',
diff --git a/src/gemdat/plots/matplotlib/_autocorrelation.py b/src/gemdat/plots/matplotlib/_autocorrelation.py
new file mode 100644
index 00000000..1d9770c7
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_autocorrelation.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from gemdat.orientations import (
+ Orientations, )
+
+
+def autocorrelation(
+ *,
+ orientations: Orientations,
+ show_traces: bool = True,
+) -> plt.Figure:
+ """Plot the autocorrelation function of the unit vectors series.
+
+ Parameters
+ ----------
+ orientations : Orientations
+ The unit vector trajectories
+ show_traces : bool
+ If True, show traces of individual trajectories
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ ac = orientations.autocorrelation()
+ ac_std = ac.std(axis=0)
+ ac_mean = ac.mean(axis=0)
+
+ # Since we want to plot in picosecond, we convert the time units
+ time_ps = orientations._time_step * 1e12
+ tgrid = np.arange(ac_mean.shape[0]) * time_ps
+
+ fig, ax = plt.subplots()
+
+ ax.plot(tgrid, ac_mean, label='FFT autocorrelation')
+
+ last_color = ax.lines[-1].get_color()
+
+ if show_traces:
+ for i, ac_i in enumerate(ac):
+ label = 'Trajectories' if (i == 0) else None
+ ax.plot(tgrid, ac_i, lw=0.1, c=last_color, label=label)
+
+ ax.fill_between(tgrid,
+ ac_mean - ac_std,
+ ac_mean + ac_std,
+ alpha=0.2,
+ label='Standard deviation')
+ ax.set_xlabel('Time lag (ps)')
+ ax.set_ylabel('Autocorrelation')
+ ax.legend()
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_bond_length_distribution.py b/src/gemdat/plots/matplotlib/_bond_length_distribution.py
new file mode 100644
index 00000000..94f62812
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_bond_length_distribution.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.optimize import curve_fit
+from scipy.stats import skewnorm
+
+from gemdat.orientations import Orientations
+
+
+def bond_length_distribution(*,
+ orientations: Orientations,
+ bins: int = 1000) -> plt.Figure:
+ """Plot the bond length probability distribution.
+
+ Parameters
+ ----------
+ orientations : Orientations
+ The unit vector trajectories
+ bins : int, optional
+ The number of bins, by default 1000
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ *_, bond_lengths = orientations.vectors_spherical.T
+ bond_lengths = bond_lengths.flatten()
+
+ fig, ax = plt.subplots()
+
+ # Plot the normalized histogram
+ hist, edges = np.histogram(bond_lengths, bins=bins, density=True)
+ bin_centers = (edges[:-1] + edges[1:]) / 2
+
+ # Fit a skewed Gaussian distribution to the orientations
+ params, covariance = curve_fit(skewnorm.pdf,
+ bin_centers,
+ hist,
+ p0=[1.5, 1, 1.5])
+
+ # Create a new function using the fitted parameters
+ def _skewnorm_fit(x):
+ return skewnorm.pdf(x, *params)
+
+ # Plot the histogram
+ ax.hist(bond_lengths,
+ bins=bins,
+ density=True,
+ color='blue',
+ alpha=0.7,
+ label='Data')
+
+ # Plot the fitted skewed Gaussian distribution
+ x_fit = np.linspace(min(bin_centers), max(bin_centers), 1000)
+ ax.plot(x_fit, _skewnorm_fit(x_fit), 'r-', label='Skewed Gaussian Fit')
+
+ ax.set_xlabel('Bond length (Å)')
+ ax.set_ylabel(r'Probability density (Å$^{-1}$)')
+ ax.set_title('Bond Length Probability Distribution')
+ ax.legend()
+ ax.grid(True)
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_collective_jumps.py b/src/gemdat/plots/matplotlib/_collective_jumps.py
new file mode 100644
index 00000000..cebb8418
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_collective_jumps.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import matplotlib.pyplot as plt
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def collective_jumps(*, jumps: Jumps) -> plt.Figure:
+ """Plot collective jumps per jump-type combination.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ fig, ax = plt.subplots()
+
+ collective = jumps.collective()
+
+ matrix = collective.site_pair_count_matrix()
+
+ im = ax.imshow(matrix)
+
+ labels = collective.site_pair_count_matrix_labels()
+ ticks = range(len(labels))
+
+ ax.set_xticks(ticks, labels=labels, rotation=90)
+ ax.set_yticks(ticks, labels=labels)
+
+ fig.colorbar(im, ax=ax)
+
+ ax.set(title='Cooperative jumps per jump-type combination')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_displacement_histogram.py b/src/gemdat/plots/matplotlib/_displacement_histogram.py
new file mode 100644
index 00000000..af214536
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_displacement_histogram.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+
+from gemdat.trajectory import Trajectory
+
+
+def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
+ """Plot histogram of total displacement at final timestep.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ fig, ax = plt.subplots()
+ ax.hist(trajectory.distances_from_base_position()[:, -1])
+ ax.set(title='Displacement per element',
+ xlabel='Displacement (Å)',
+ ylabel='Nr. of atoms')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_displacement_per_atom.py b/src/gemdat/plots/matplotlib/_displacement_per_atom.py
new file mode 100644
index 00000000..2b33ae76
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_displacement_per_atom.py
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+
+from gemdat.trajectory import Trajectory
+
+
+def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
+ """Plot displacement per atom.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ fig, ax = plt.subplots()
+
+ for distances in trajectory.distances_from_base_position():
+ ax.plot(distances, lw=0.3)
+
+ ax.set(title='Displacement per site',
+ xlabel='Time step',
+ ylabel='Displacement (Å)')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_displacement_per_element.py b/src/gemdat/plots/matplotlib/_displacement_per_element.py
new file mode 100644
index 00000000..626c7248
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_displacement_per_element.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+
+from gemdat.trajectory import Trajectory
+
+from gemdat.plots._shared import _mean_displacements_per_element
+
+
+def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
+ """Plot displacement per element.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ displacements = _mean_displacements_per_element(trajectory)
+
+ fig, ax = plt.subplots()
+
+ for symbol, (mean, _) in displacements.items():
+ ax.plot(mean, lw=0.3, label=symbol)
+
+ ax.legend()
+ ax.set(title='Displacement per element',
+ xlabel='Time step',
+ ylabel='Displacement (Å)')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_displacements.py b/src/gemdat/plots/matplotlib/_displacements.py
deleted file mode 100644
index 6508db4c..00000000
--- a/src/gemdat/plots/matplotlib/_displacements.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from __future__ import annotations
-
-from collections import defaultdict
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-from gemdat.trajectory import Trajectory
-
-
-def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
- """Plot displacement per atom.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- fig, ax = plt.subplots()
-
- for distances in trajectory.distances_from_base_position():
- ax.plot(distances, lw=0.3)
-
- ax.set(title='Displacement per site',
- xlabel='Time step',
- ylabel='Displacement (Angstrom)')
-
- return fig
-
-
-def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
- """Plot displacement per element.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- grouped = defaultdict(list)
-
- species = trajectory.species
-
- for sp, distances in zip(species,
- trajectory.distances_from_base_position()):
- grouped[sp.symbol].append(distances)
-
- fig, ax = plt.subplots()
-
- for symbol, distances in grouped.items():
- mean_disp = np.mean(distances, axis=0)
- ax.plot(mean_disp, lw=0.3, label=symbol)
-
- ax.legend()
- ax.set(title='Displacement per element',
- xlabel='Time step',
- ylabel='Displacement (Angstrom)')
-
- return fig
-
-
-def msd_per_element(
- *,
- trajectory: Trajectory,
-) -> plt.Figure:
- """Plot mean squared displacement per element.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- species = list(set(trajectory.species))
-
- fig, ax = plt.subplots()
-
- # Since we want to plot in picosecond, we convert the time units
- time_ps = trajectory.time_step * 1e12
-
- for sp in species:
- traj = trajectory.filter(sp.symbol)
- msd = traj.mean_squared_displacement()
- msd_mean = np.mean(msd, axis=0)
- msd_std = np.std(msd, axis=0)
- t_values = np.arange(len(msd_mean)) * time_ps
- ax.plot(t_values, msd_mean, lw=0.5, label=sp.symbol)
- last_color = ax.lines[-1].get_color()
- ax.fill_between(t_values,
- msd_mean - msd_std,
- msd_mean + msd_std,
- color=last_color,
- alpha=0.2)
-
- ax.legend()
- ax.set(title='Mean squared displacement per element',
- xlabel='Time lag [ps]',
- ylabel='MSD (Angstrom$^2$)')
-
- return fig
-
-
-def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
- """Plot histogram of total displacement at final timestep.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- fig, ax = plt.subplots()
- ax.hist(trajectory.distances_from_base_position()[:, -1])
- ax.set(title='Histogram of displacements',
- xlabel='Displacement (Angstrom)',
- ylabel='Nr. of atoms')
-
- return fig
diff --git a/src/gemdat/plots/matplotlib/_paths.py b/src/gemdat/plots/matplotlib/_energy_along_path.py
similarity index 75%
rename from src/gemdat/plots/matplotlib/_paths.py
rename to src/gemdat/plots/matplotlib/_energy_along_path.py
index bb40fa0a..370248c2 100644
--- a/src/gemdat/plots/matplotlib/_paths.py
+++ b/src/gemdat/plots/matplotlib/_energy_along_path.py
@@ -32,7 +32,7 @@ def energy_along_path(
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(path.energy, marker='o', color='r', label='Optimal path')
- ax.set(ylabel='Free energy [eV]')
+ ax.set(ylabel='Free energy (eV)')
nearest_sites = path.path_over_structure(structure)
@@ -98,43 +98,3 @@ def energy_along_path(
ax.legend(fontsize=8)
return fig
-
-
-def path_on_grid(path: Pathway) -> plt.Figure:
- """Plot the 3d coordinates of the points that define a path.
-
- Parameters
- ----------
- path : Pathway
- Pathway to plot
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- # Create a colormap to visualize the path
- colormap = plt.get_cmap()
- normalize = plt.Normalize(0, len(path.energy))
-
- fig, ax = plt.subplots()
- ax = fig.add_subplot(111, projection='3d')
-
- path_x, path_y, path_z = zip(*path.sites)
-
- for i in range(len(path.energy) - 1):
- ax.plot(path_x[i:i + 1],
- path_y[i:i + 1],
- path_z[i:i + 1],
- color=colormap(normalize(i)),
- marker='o',
- linestyle='-')
-
- ax.set_xlabel('X')
- ax.set_ylabel('Y')
- sm = plt.cm.ScalarMappable(cmap=colormap, norm=normalize)
- sm.set_array([])
- cbar = plt.colorbar(sm, ax=ax)
- cbar.set_label('Steps')
-
- return fig
diff --git a/src/gemdat/plots/matplotlib/_vibration.py b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
similarity index 66%
rename from src/gemdat/plots/matplotlib/_vibration.py
rename to src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
index e5e2bb7e..c89531f2 100644
--- a/src/gemdat/plots/matplotlib/_vibration.py
+++ b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
@@ -2,7 +2,6 @@
import matplotlib.pyplot as plt
import numpy as np
-from scipy import stats
from gemdat.simulation_metrics import SimulationMetrics
from gemdat.trajectory import Trajectory
@@ -62,32 +61,3 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
ax.set_xlim([-0.1e13, 2.5e13])
return fig
-
-
-def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
- """Plot histogram of vibrational amplitudes with fitted Gaussian.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- metrics = SimulationMetrics(trajectory)
-
- fig, ax = plt.subplots()
- ax.hist(metrics.amplitudes(), bins=100, density=True)
-
- x = np.linspace(-2, 2, 100)
- y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude())
- ax.plot(x, y_gauss, 'r')
-
- ax.set(title='Histogram of vibrational amplitudes with fitted Gaussian',
- xlabel='Amplitude (Angstrom)',
- ylabel='Occurrence (a.u.)')
-
- return fig
diff --git a/src/gemdat/plots/matplotlib/_jumps.py b/src/gemdat/plots/matplotlib/_jumps.py
deleted file mode 100644
index f5db9306..00000000
--- a/src/gemdat/plots/matplotlib/_jumps.py
+++ /dev/null
@@ -1,347 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import matplotlib.animation as animation
-import matplotlib.pyplot as plt
-import numpy as np
-from matplotlib import colormaps
-from pymatgen.electronic_structure import plotter
-
-if TYPE_CHECKING:
-
- from gemdat import Jumps
-
-
-def jumps_vs_distance(
- *,
- jumps: Jumps,
- jump_res: float = 0.1,
-) -> plt.Figure:
- """Plot jumps vs. distance histogram.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
- jump_res : float, optional
- Resolution of the bins in Angstrom
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- sites = jumps.sites
-
- trajectory = jumps.trajectory
- lattice = trajectory.get_lattice()
-
- pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
-
- bin_max = (1 + pdist.max() // jump_res) * jump_res
- n_bins = int(bin_max / jump_res) + 1
- x = np.linspace(0, bin_max, n_bins)
- counts = np.zeros_like(x)
-
- bin_idx = np.digitize(pdist, bins=x)
- for idx, n in zip(bin_idx.flatten(), jumps.matrix().flatten()):
- counts[idx] += n
-
- fig, ax = plt.subplots()
-
- ax.bar(x, counts, width=(jump_res * 0.8))
-
- ax.set(title='Jumps vs. Distance',
- xlabel='Distance (Angstrom)',
- ylabel='Number of jumps')
-
- return fig
-
-
-def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure:
- """Plot jumps vs. time histogram.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
- binsize : int, optional
- Width of each bin in number of time steps
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
-
- trajectory = jumps.trajectory
-
- n_steps = len(trajectory)
- bins = np.arange(0, n_steps + binsize, binsize)
-
- fig, ax = plt.subplots()
-
- ax.hist(jumps.data['stop time'], bins=bins, width=0.8 * binsize)
-
- ax.set(title='Jumps vs. time',
- xlabel='Time (steps)',
- ylabel='Number of jumps')
-
- return fig
-
-
-def collective_jumps(*, jumps: Jumps) -> plt.Figure:
- """Plot collective jumps per jump-type combination.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- fig, ax = plt.subplots()
- matrix = jumps.collective().site_pair_count_matrix()
- labels = jumps.collective().site_pair_count_matrix_labels()
-
- mat = ax.imshow(matrix)
-
- ticks = range(len(labels))
-
- ax.set_xticks(ticks, labels=labels, rotation=90)
- ax.set_yticks(ticks, labels=labels)
-
- fig.colorbar(mat, ax=ax)
-
- ax.set(title='Cooperative jumps per jump-type combination')
-
- return fig
-
-
-def jumps_3d(*, jumps: Jumps) -> plt.Figure:
- """Plot jumps in 3D.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- trajectory = jumps.trajectory
- sites = jumps.sites
-
- class LabelItems:
-
- def __init__(self, labels, coords):
- self.labels = labels
- self.coords = coords
-
- def items(self):
- yield from zip(self.labels, self.coords)
-
- coords = sites.frac_coords
- lattice = trajectory.get_lattice()
-
- fig = plt.figure()
- ax = fig.add_subplot(projection='3d')
-
- site_labels = LabelItems(jumps.sites.labels, coords)
-
- xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
- [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])
-
- plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)
- plotter.plot_labels(xyz_labels,
- lattice=lattice,
- ax=ax,
- color='green',
- size=12)
- plotter.plot_points(coords, lattice=lattice, ax=ax)
-
- for i, j in zip(*np.triu_indices(len(coords), k=1)):
- count = jumps.matrix()[i, j] + jumps.matrix()[j, i]
- if count == 0:
- continue
-
- coord_i = coords[i]
- coord_j = coords[j]
-
- lw = 1 + np.log(count)
-
- length, image = lattice.get_distance_and_image(coord_i, coord_j)
-
- # NOTE: might need to plot `line = [coord_i - image, coord_j]` as well
- if np.any(image != 0):
- lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)]
- else:
- lines = [(coord_i, coord_j)]
-
- for line in lines:
- plotter.plot_path(line,
- lattice=lattice,
- ax=ax,
- color='red',
- linewidth=lw)
-
- plotter.plot_labels(site_labels,
- lattice=lattice,
- ax=ax,
- color='black',
- size=8)
-
- ax.set(
- title='Jumps between sites',
- xlabel="x' (ang)",
- ylabel="y' (ang)",
- zlabel="z' (ang)",
- )
-
- ax.set_aspect('equal') # only auto is supported
-
- return fig
-
-
-def jumps_3d_animation(
- *,
- jumps: Jumps,
- t_start: int,
- t_stop: int,
- decay: float = 0.05,
- skip: int = 5,
- interval: int = 20,
-) -> animation.FuncAnimation:
- """Plot jumps in 3D as an animation over time.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
- t_start : int
- Time step to start animation (relative to equilibration time)
- t_stop : int
- Time step to stop animation (relative to equilibration time)
- decay : float, optional
- Controls the decay of the line width (higher = faster decay)
- skip : float, optional
- Skip frames (increase for faster, but less accurate rendering)
- interval : int, optional
- Delay between frames in milliseconds.
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- minwidth = 0.2
- maxwidth = 5.0
-
- trajectory = jumps.trajectory
-
- class LabelItems:
-
- def __init__(self, labels, coords):
- self.labels = labels
- self.coords = coords
-
- def items(self):
- yield from zip(self.labels, self.coords)
-
- coords = jumps.sites.frac_coords
- lattice = trajectory.get_lattice()
-
- color_from = colormaps['Set1'].colors # type: ignore
- color_to = colormaps['Pastel1'].colors # type: ignore
-
- fig = plt.figure()
- ax = fig.add_subplot(projection='3d')
-
- xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
- [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])
-
- plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)
-
- plotter.plot_labels(xyz_labels,
- lattice=lattice,
- ax=ax,
- color='green',
- size=12)
-
- assert len(ax.collections) == 0
- plotter.plot_points(coords,
- lattice=lattice,
- ax=ax,
- s=50,
- color='white',
- edgecolor='black')
- points = ax.collections
-
- events = jumps.data.sort_values('start time', ignore_index=True)
-
- for _, event in events.iterrows():
- site_i = event['start site']
- site_j = event['destination site']
-
- coord_i = coords[site_i]
- coord_j = coords[site_j]
-
- lw = 0
-
- _, image = lattice.get_distance_and_image(coord_i, coord_j)
-
- line = [coord_i, coord_j + image]
-
- plotter.plot_path(line,
- lattice=lattice,
- ax=ax,
- color='red',
- linewidth=lw)
-
- lines = ax.lines[3:]
-
- ax.set(
- title='Jumps between sites',
- xlabel="x' (ang)",
- ylabel="y' (ang)",
- zlabel="z' (ang)",
- )
-
- ax.set_aspect('equal') # only auto is supported
-
- def update(frame_no):
- t_frame = t_start + (frame_no * skip)
-
- for i, event in events.iterrows():
-
- if event['start time'] > t_frame:
- break
-
- lw = max(maxwidth - decay * (t_frame - event['start time']),
- minwidth)
-
- line = lines[i]
- line.set_color('red')
- line.set_linewidth(lw)
-
- points[event['start site']].set_facecolor(
- color_from[event['atom index'] % len(color_from)])
- points[event['destination site']].set_facecolor(
- color_to[event['atom index'] % len(color_to)])
-
- start_time = event['start time']
- ax.set_title(f'T: {t_frame} | Next jump: {start_time}')
-
- n_frames = int((t_stop - t_start) / skip)
-
- return animation.FuncAnimation(fig=fig,
- func=update,
- frames=n_frames,
- interval=interval,
- repeat=False)
diff --git a/src/gemdat/plots/matplotlib/_jumps_3d.py b/src/gemdat/plots/matplotlib/_jumps_3d.py
new file mode 100644
index 00000000..4a540519
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_jumps_3d.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import matplotlib.pyplot as plt
+import numpy as np
+from pymatgen.electronic_structure import plotter
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_3d(*, jumps: Jumps) -> plt.Figure:
+ """Plot jumps in 3D.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ trajectory = jumps.trajectory
+ sites = jumps.sites
+
+ class LabelItems:
+
+ def __init__(self, labels, coords):
+ self.labels = labels
+ self.coords = coords
+
+ def items(self):
+ yield from zip(self.labels, self.coords)
+
+ coords = sites.frac_coords
+ lattice = trajectory.get_lattice()
+
+ fig = plt.figure()
+ ax = fig.add_subplot(projection='3d')
+
+ site_labels = LabelItems(jumps.sites.labels, coords)
+
+ xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
+ [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])
+
+ plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)
+ plotter.plot_labels(xyz_labels,
+ lattice=lattice,
+ ax=ax,
+ color='green',
+ size=12)
+ plotter.plot_points(coords, lattice=lattice, ax=ax)
+
+ for i, j in zip(*np.triu_indices(len(coords), k=1)):
+ count = jumps.matrix()[i, j] + jumps.matrix()[j, i]
+ if count == 0:
+ continue
+
+ coord_i = coords[i]
+ coord_j = coords[j]
+
+ lw = 1 + np.log(count)
+
+ length, image = lattice.get_distance_and_image(coord_i, coord_j)
+
+ # NOTE: might need to plot `line = [coord_i - image, coord_j]` as well
+ if np.any(image != 0):
+ lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)]
+ else:
+ lines = [(coord_i, coord_j)]
+
+ for line in lines:
+ plotter.plot_path(line,
+ lattice=lattice,
+ ax=ax,
+ color='red',
+ linewidth=lw)
+
+ plotter.plot_labels(site_labels,
+ lattice=lattice,
+ ax=ax,
+ color='black',
+ size=8)
+
+ ax.set(
+ title='Jumps between sites',
+ xlabel="x' (Å)",
+ ylabel="y' (Å)",
+ zlabel="z' (Å)",
+ )
+
+ ax.set_aspect('equal') # only auto is supported
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_jumps_3d_animation.py b/src/gemdat/plots/matplotlib/_jumps_3d_animation.py
new file mode 100644
index 00000000..74031e95
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_jumps_3d_animation.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import matplotlib.animation as animation
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
+from pymatgen.electronic_structure import plotter
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_3d_animation(
+ *,
+ jumps: Jumps,
+ t_start: int,
+ t_stop: int,
+ decay: float = 0.05,
+ skip: int = 5,
+ interval: int = 20,
+) -> animation.FuncAnimation:
+ """Plot jumps in 3D as an animation over time.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+ t_start : int
+ Time step to start animation (relative to equilibration time)
+ t_stop : int
+ Time step to stop animation (relative to equilibration time)
+ decay : float, optional
+ Controls the decay of the line width (higher = faster decay)
+ skip : float, optional
+ Skip frames (increase for faster, but less accurate rendering)
+ interval : int, optional
+ Delay between frames in milliseconds.
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ minwidth = 0.2
+ maxwidth = 5.0
+
+ trajectory = jumps.trajectory
+
+ class LabelItems:
+
+ def __init__(self, labels, coords):
+ self.labels = labels
+ self.coords = coords
+
+ def items(self):
+ yield from zip(self.labels, self.coords)
+
+ coords = jumps.sites.frac_coords
+ lattice = trajectory.get_lattice()
+
+ color_from = colormaps['Set1'].colors # type: ignore
+ color_to = colormaps['Pastel1'].colors # type: ignore
+
+ fig = plt.figure()
+ ax = fig.add_subplot(projection='3d')
+
+ xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1],
+ [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]])
+
+ plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1)
+
+ plotter.plot_labels(xyz_labels,
+ lattice=lattice,
+ ax=ax,
+ color='green',
+ size=12)
+
+ assert len(ax.collections) == 0
+ plotter.plot_points(coords,
+ lattice=lattice,
+ ax=ax,
+ s=50,
+ color='white',
+ edgecolor='black')
+ points = ax.collections
+
+ events = jumps.data.sort_values('start time', ignore_index=True)
+
+ for _, event in events.iterrows():
+ site_i = event['start site']
+ site_j = event['destination site']
+
+ coord_i = coords[site_i]
+ coord_j = coords[site_j]
+
+ lw = 0
+
+ _, image = lattice.get_distance_and_image(coord_i, coord_j)
+
+ line = [coord_i, coord_j + image]
+
+ plotter.plot_path(line,
+ lattice=lattice,
+ ax=ax,
+ color='red',
+ linewidth=lw)
+
+ lines = ax.lines[3:]
+
+ ax.set(
+ title='Jumps between sites',
+ xlabel="x' (Å)",
+ ylabel="y' (Å)",
+ zlabel="z' (Å)",
+ )
+
+ ax.set_aspect('equal') # only auto is supported
+
+ def update(frame_no):
+ t_frame = t_start + (frame_no * skip)
+
+ for i, event in events.iterrows():
+
+ if event['start time'] > t_frame:
+ break
+
+ lw = max(maxwidth - decay * (t_frame - event['start time']),
+ minwidth)
+
+ line = lines[i]
+ line.set_color('red')
+ line.set_linewidth(lw)
+
+ points[event['start site']].set_facecolor(
+ color_from[event['atom index'] % len(color_from)])
+ points[event['destination site']].set_facecolor(
+ color_to[event['atom index'] % len(color_to)])
+
+ start_time = event['start time']
+ ax.set_title(f'T: {t_frame} | Next jump: {start_time}')
+
+ n_frames = int((t_stop - t_start) / skip)
+
+ return animation.FuncAnimation(fig=fig,
+ func=update,
+ frames=n_frames,
+ interval=interval,
+ repeat=False)
diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_distance.py b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py
new file mode 100644
index 00000000..ba6b1057
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_vs_distance(
+ *,
+ jumps: Jumps,
+ jump_res: float = 0.1,
+) -> plt.Figure:
+ """Plot jumps vs. distance histogram.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+ jump_res : float, optional
+ Resolution of the bins in Angstrom
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ sites = jumps.sites
+
+ trajectory = jumps.trajectory
+ lattice = trajectory.get_lattice()
+
+ pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
+
+ bin_max = (1 + pdist.max() // jump_res) * jump_res
+ n_bins = int(bin_max / jump_res) + 1
+ x = np.linspace(0, bin_max, n_bins)
+ counts = np.zeros_like(x)
+
+ bin_idx = np.digitize(pdist, bins=x)
+ for idx, n in zip(bin_idx.flatten(), jumps.matrix().flatten()):
+ counts[idx] += n
+
+ fig, ax = plt.subplots()
+
+ ax.bar(x, counts, width=(jump_res * 0.8))
+
+ ax.set(title='Jumps vs. Distance',
+ xlabel='Distance (Å)',
+ ylabel='Number of jumps')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_time.py b/src/gemdat/plots/matplotlib/_jumps_vs_time.py
new file mode 100644
index 00000000..d0ac2a21
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_jumps_vs_time.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure:
+ """Plot jumps vs. time histogram.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+ binsize : int, optional
+ Width of each bin in number of time steps
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+
+ trajectory = jumps.trajectory
+
+ n_steps = len(trajectory)
+ bins = np.arange(0, n_steps + binsize, binsize)
+
+ fig, ax = plt.subplots()
+
+ ax.hist(jumps.data['stop time'], bins=bins, width=0.8 * binsize)
+
+ ax.set(title='Jumps vs. time',
+ xlabel='Time (steps)',
+ ylabel='Number of jumps')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_msd_per_element.py b/src/gemdat/plots/matplotlib/_msd_per_element.py
new file mode 100644
index 00000000..f60d2b3e
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_msd_per_element.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from gemdat.trajectory import Trajectory
+
+
+def msd_per_element(
+ *,
+ trajectory: Trajectory,
+ show_traces: bool = True,
+) -> plt.Figure:
+ """Plot mean squared displacement per element.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory
+ show_traces : bool
+ If True, show individual traces for each element
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ species = list(set(trajectory.species))
+
+ fig, ax = plt.subplots()
+
+ # Since we want to plot in picosecond, we convert the time units
+ time_ps = trajectory.time_step * 1e12
+ t_values = np.arange(len(trajectory)) * time_ps
+
+ for sp in species:
+ traj = trajectory.filter(sp.symbol)
+ msd = traj.mean_squared_displacement()
+
+ msd_mean = np.mean(msd, axis=0)
+ msd_std = np.std(msd, axis=0)
+
+ ax.plot(t_values, msd_mean, lw=0.5, label=f'{sp.symbol} mean')
+
+ last_color = ax.lines[-1].get_color()
+
+ if show_traces:
+ for i, traj in enumerate(msd):
+ label = f'{sp.symbol} trajectories' if (i == 0) else None
+ ax.plot(t_values, traj, lw=0.1, c=last_color, label=label)
+
+ ax.fill_between(t_values,
+ msd_mean - msd_std,
+ msd_mean + msd_std,
+ color=last_color,
+ alpha=0.2,
+ label=f'{sp.symbol} std')
+
+ ax.legend()
+ ax.set(title='Mean squared displacement per element',
+ xlabel='Time lag (ps)',
+ ylabel='MSD (Å$^2$)')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_orientations.py b/src/gemdat/plots/matplotlib/_orientations.py
deleted file mode 100644
index e1647bea..00000000
--- a/src/gemdat/plots/matplotlib/_orientations.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from __future__ import annotations
-
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.optimize import curve_fit
-from scipy.stats import skewnorm
-
-from gemdat.orientations import (
- Orientations,
- calculate_spherical_areas,
-)
-
-
-def rectilinear(*,
- orientations: Orientations,
- shape: tuple[int, int] = (90, 360),
- normalize_histo: bool = True) -> plt.Figure:
- """Plot a rectilinear projection of a spherical function. This function
- uses the transformed trajectory.
-
- Parameters
- ----------
- orientations : Orientations
- The unit vector trajectories
- shape : tuple
- The shape of the spherical sector in which the trajectory is plotted
- normalize_histo : bool, optional
- If True, normalize the histogram by the area of the bins, by default True
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- # Convert the vectors to spherical coordinates
- az, el, _ = orientations.vectors_spherical.T
- az = az.flatten()
- el = el.flatten()
-
- hist, xedges, yedges = np.histogram2d(el, az, shape)
-
- if normalize_histo:
- # Normalize by the area of the bins
- areas = calculate_spherical_areas(shape)
- hist = np.divide(hist, areas)
- # Drop the bins at the poles where normalization is not possible
- hist = hist[1:-1, :]
-
- values = hist.T
- axis_phi, axis_theta = values.shape
-
- phi = np.linspace(0, 360, axis_phi)
- theta = np.linspace(0, 180, axis_theta)
-
- theta, phi = np.meshgrid(theta, phi)
-
- fig, ax = plt.subplots(subplot_kw=dict(projection='rectilinear'))
- cs = ax.contourf(phi, theta, values, cmap='viridis')
- ax.set_yticks(np.arange(0, 190, 45))
- ax.set_xticks(np.arange(0, 370, 45))
-
- ax.set_xlabel(r'azimuthal angle φ $[\degree$]')
- ax.set_ylabel(r'elevation θ $[\degree$]')
-
- ax.grid(visible=True)
- cbar = fig.colorbar(cs, label='areal probability', format='')
-
- # Rotate the colorbar label by 180 degrees
- cbar.ax.yaxis.set_label_coords(2.5,
- 0.5) # Adjust the position of the label
- cbar.set_label('areal probability', rotation=270, labelpad=15)
- return fig
-
-
-def bond_length_distribution(*,
- orientations: Orientations,
- bins: int = 1000) -> plt.Figure:
- """Plot the bond length probability distribution.
-
- Parameters
- ----------
- orientations : Orientations
- The unit vector trajectories
- bins : int, optional
- The number of bins, by default 1000
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- *_, bond_lengths = orientations.vectors_spherical.T
- bond_lengths = bond_lengths.flatten()
-
- fig, ax = plt.subplots()
-
- # Plot the normalized histogram
- hist, edges = np.histogram(bond_lengths, bins=bins, density=True)
- bin_centers = (edges[:-1] + edges[1:]) / 2
-
- # Fit a skewed Gaussian distribution to the orientations
- params, covariance = curve_fit(skewnorm.pdf,
- bin_centers,
- hist,
- p0=[1.5, 1, 1.5])
-
- # Create a new function using the fitted parameters
- def _skewnorm_fit(x):
- return skewnorm.pdf(x, *params)
-
- # Plot the histogram
- ax.hist(bond_lengths,
- bins=bins,
- density=True,
- color='blue',
- alpha=0.7,
- label='Data')
-
- # Plot the fitted skewed Gaussian distribution
- x_fit = np.linspace(min(bin_centers), max(bin_centers), 1000)
- ax.plot(x_fit, _skewnorm_fit(x_fit), 'r-', label='Skewed Gaussian Fit')
-
- ax.set_xlabel(r'Bond length $[\AA]$')
- ax.set_ylabel(r'Probability density $[\AA^{-1}]$')
- ax.set_title('Bond Length Probability Distribution')
- ax.legend()
- ax.grid(True)
-
- return fig
-
-
-def autocorrelation(
- *,
- orientations: Orientations,
-) -> plt.Figure:
- """Plot the autocorrelation function of the unit vectors series.
-
- Parameters
- ----------
- orientations : Orientations
- The unit vector trajectories
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- ac = orientations.autocorrelation()
- ac_std = ac.std(axis=0)
- ac_mean = ac.mean(axis=0)
-
- # Since we want to plot in picosecond, we convert the time units
- time_ps = orientations._time_step * 1e12
- tgrid = np.arange(ac_mean.shape[0]) * time_ps
-
- # and now we can plot the autocorrelation function
- fig, ax = plt.subplots()
-
- ax.plot(tgrid, ac_mean, label='FFT-Autocorrelation')
- ax.fill_between(tgrid, ac_mean - ac_std, ac_mean + ac_std, alpha=0.2)
- ax.set_xlabel('Time lag [ps]')
- ax.set_ylabel('Autocorrelation')
-
- return fig
diff --git a/src/gemdat/plots/matplotlib/_path_on_grid.py b/src/gemdat/plots/matplotlib/_path_on_grid.py
new file mode 100644
index 00000000..1412942c
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_path_on_grid.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+
+from gemdat.path import Pathway
+
+
+def path_on_grid(path: Pathway) -> plt.Figure:
+ """Plot the 3d coordinates of the points that define a path.
+
+ Parameters
+ ----------
+ path : Pathway
+ Pathway to plot
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ # Create a colormap to visualize the path
+ colormap = plt.get_cmap()
+ normalize = plt.Normalize(0, len(path.energy))
+
+ fig, ax = plt.subplots()
+ ax = fig.add_subplot(111, projection='3d')
+
+ path_x, path_y, path_z = zip(*path.sites)
+
+ for i in range(len(path.energy) - 1):
+ ax.plot(path_x[i:i + 1],
+ path_y[i:i + 1],
+ path_z[i:i + 1],
+ color=colormap(normalize(i)),
+ marker='o',
+ linestyle='-')
+
+ ax.set_xlabel('X')
+ ax.set_ylabel('Y')
+ sm = plt.cm.ScalarMappable(cmap=colormap, norm=normalize)
+ sm.set_array([])
+ cbar = plt.colorbar(sm, ax=ax)
+ cbar.set_label('Steps')
+
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_rdf.py b/src/gemdat/plots/matplotlib/_radial_distribution.py
similarity index 95%
rename from src/gemdat/plots/matplotlib/_rdf.py
rename to src/gemdat/plots/matplotlib/_radial_distribution.py
index 22564e7f..fc75d37b 100644
--- a/src/gemdat/plots/matplotlib/_rdf.py
+++ b/src/gemdat/plots/matplotlib/_radial_distribution.py
@@ -30,7 +30,7 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure:
ax.legend()
ax.set(title=f'Radial distribution function ({states})',
- xlabel='Distance (Ang)',
+ xlabel='Distance (Å)',
ylabel='Counts')
return fig
diff --git a/src/gemdat/plots/matplotlib/_rectilinear.py b/src/gemdat/plots/matplotlib/_rectilinear.py
new file mode 100644
index 00000000..fd12001b
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_rectilinear.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from gemdat.orientations import (
+ Orientations,
+ calculate_spherical_areas,
+)
+
+
+def rectilinear(*,
+ orientations: Orientations,
+ shape: tuple[int, int] = (90, 360),
+ normalize_histo: bool = True) -> plt.Figure:
+ """Plot a rectilinear projection of a spherical function. This function
+ uses the transformed trajectory.
+
+ Parameters
+ ----------
+ orientations : Orientations
+ The unit vector trajectories
+ shape : tuple
+ The shape of the spherical sector in which the trajectory is plotted
+ normalize_histo : bool, optional
+ If True, normalize the histogram by the area of the bins, by default True
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ # Convert the vectors to spherical coordinates
+ az, el, _ = orientations.vectors_spherical.T
+ az = az.flatten()
+ el = el.flatten()
+
+ hist, xedges, yedges = np.histogram2d(el, az, shape)
+
+ if normalize_histo:
+ # Normalize by the area of the bins
+ areas = calculate_spherical_areas(shape)
+ hist = np.divide(hist, areas)
+ # Drop the bins at the poles where normalization is not possible
+ hist = hist[1:-1, :]
+
+ values = hist.T
+ axis_phi, axis_theta = values.shape
+
+ phi = np.linspace(0, 360, axis_phi)
+ theta = np.linspace(0, 180, axis_theta)
+
+ theta, phi = np.meshgrid(theta, phi)
+
+ fig, ax = plt.subplots(subplot_kw=dict(projection='rectilinear'))
+ cs = ax.contourf(phi, theta, values, cmap='viridis')
+ ax.set_yticks(np.arange(0, 190, 45))
+ ax.set_xticks(np.arange(0, 370, 45))
+
+ ax.set_xlabel('Azimuthal angle φ (°)')
+ ax.set_ylabel('Elevation θ (°)')
+
+ ax.grid(visible=True)
+ cbar = fig.colorbar(cs, label='Areal probability', format='')
+
+ # Rotate the colorbar label by 180 degrees
+ cbar.ax.yaxis.set_label_coords(2.5,
+ 0.5) # Adjust the position of the label
+ cbar.set_label('Areal probability', rotation=270, labelpad=15)
+ return fig
diff --git a/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
new file mode 100644
index 00000000..64e12cd2
--- /dev/null
+++ b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy import stats
+
+from gemdat.simulation_metrics import SimulationMetrics
+from gemdat.trajectory import Trajectory
+
+
+def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
+ """Plot histogram of vibrational amplitudes with fitted Gaussian.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ metrics = SimulationMetrics(trajectory)
+
+ fig, ax = plt.subplots()
+ ax.hist(metrics.amplitudes(), bins=100, density=True)
+
+ x = np.linspace(-2, 2, 100)
+ y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude())
+ ax.plot(x, y_gauss, 'r')
+
+ ax.set(title='Histogram of vibrational amplitudes with fitted Gaussian',
+ xlabel='Amplitude (Å)',
+ ylabel='Occurrence (a.u.)')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/__init__.py b/src/gemdat/plots/plotly/__init__.py
index 22fc1156..9f94a00f 100644
--- a/src/gemdat/plots/plotly/__init__.py
+++ b/src/gemdat/plots/plotly/__init__.py
@@ -2,23 +2,17 @@
from __future__ import annotations
from ._density import density
-from ._displacements import (
- displacement_histogram,
- displacement_per_atom,
- displacement_per_element,
- msd_per_element,
-)
-from ._jumps import (
- collective_jumps,
- jumps_3d,
- jumps_vs_distance,
- jumps_vs_time,
-)
+from ._displacement_histogram import displacement_histogram
+from ._displacement_per_atom import displacement_per_atom
+from ._displacement_per_element import displacement_per_element
+from ._msd_per_element import msd_per_element
+from ._collective_jumps import collective_jumps
+from ._jumps_3d import jumps_3d
+from ._jumps_vs_distance import jumps_vs_distance
+from ._jumps_vs_time import jumps_vs_time
from ._plot3d import plot_3d
-from ._vibration import (
- frequency_vs_occurence,
- vibrational_amplitudes,
-)
+from ._frequency_vs_occurence import frequency_vs_occurence
+from ._vibrational_amplitudes import vibrational_amplitudes
__all__ = [
'collective_jumps',
diff --git a/src/gemdat/plots/plotly/_collective_jumps.py b/src/gemdat/plots/plotly/_collective_jumps.py
new file mode 100644
index 00000000..6bbc310a
--- /dev/null
+++ b/src/gemdat/plots/plotly/_collective_jumps.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import plotly.express as px
+import plotly.graph_objects as go
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def collective_jumps(*, jumps: Jumps) -> go.Figure:
+ """Plot collective jumps per jump-type combination.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+
+ Returns
+ -------
+ fig : plotly.graph_objects.Figure
+ Output figure
+ """
+ collective = jumps.collective()
+ matrix = collective.site_pair_count_matrix()
+
+ fig = px.imshow(matrix)
+
+ labels = collective.site_pair_count_matrix_labels()
+
+ ticks = list(range(len(labels)))
+
+ fig.update_layout(xaxis={
+ 'tickmode': 'array',
+ 'tickvals': ticks,
+ 'ticktext': labels
+ },
+ yaxis={
+ 'tickmode': 'array',
+ 'tickvals': ticks,
+ 'ticktext': labels
+ },
+ title='Cooperative jumps per jump-type combination')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_displacement_histogram.py b/src/gemdat/plots/plotly/_displacement_histogram.py
new file mode 100644
index 00000000..3ea1d608
--- /dev/null
+++ b/src/gemdat/plots/plotly/_displacement_histogram.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+
+from gemdat.trajectory import Trajectory
+
+
+def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame:
+ """_trajectory_to_dataframe.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ trajectory
+
+ Returns
+ -------
+ pd.DataFrame
+ """
+ data = []
+ for specie, distance in zip(
+ trajectory.species,
+ trajectory.distances_from_base_position()[:, -1]):
+ data.append((specie, round(distance)))
+
+ df = pd.DataFrame(columns=['Element', 'Displacement'], data=data)
+ df = df.groupby(['Displacement', 'Element'
+ ]).size().reset_index().rename(columns={0: 'count'})
+ return df
+
+
+def displacement_histogram(trajectory: Trajectory,
+ n_parts: int = 1) -> go.Figure:
+ """Plot histogram of total displacement at final timestep.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+ n_parts : int
+ Plot error bars by dividing data into n parts
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+
+ if n_parts == 1:
+ df = _trajectory_to_dataframe(trajectory)
+
+ fig = px.bar(df,
+ x='Displacement',
+ y='count',
+ color='Element',
+ barmode='stack')
+
+ fig.update_layout(title='Displacement per element',
+ xaxis_title='Displacement (Å)',
+ yaxis_title='Nr. of atoms')
+ else:
+ interval = np.linspace(0, len(trajectory) - 1, n_parts + 1)
+ dfs = [
+ _trajectory_to_dataframe(part)
+ for part in trajectory.split(n_parts)
+ ]
+
+ all_df = pd.concat(dfs)
+
+ # Get the mean and standard deviation
+ grouped = all_df.groupby(['Displacement', 'Element'])
+ mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
+ std = grouped.std().reset_index().rename(columns={'count': 'std'})
+ df = mean.merge(std, how='inner')
+
+ fig = px.bar(df,
+ x='Displacement',
+ y='mean',
+ color='Element',
+ error_y='std',
+ barmode='group')
+
+ fig.update_layout(
+ title=
+ f'Displacement per element after {int(interval[1]-interval[0])} timesteps',
+ xaxis_title='Displacement (Å)',
+ yaxis_title='Nr. of atoms')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_displacement_per_atom.py b/src/gemdat/plots/plotly/_displacement_per_atom.py
new file mode 100644
index 00000000..b0213f6b
--- /dev/null
+++ b/src/gemdat/plots/plotly/_displacement_per_atom.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+
+import plotly.graph_objects as go
+
+from gemdat.trajectory import Trajectory
+
+
+def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
+ """Plot displacement per atom.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+
+ fig = go.Figure()
+
+ distances = [dist for dist in trajectory.distances_from_base_position()]
+
+ for i, distance in enumerate(distances):
+ fig.add_trace(
+ go.Scatter(y=distance,
+ name=i,
+ mode='lines',
+ line={'width': 1},
+ showlegend=False))
+
+ fig.update_layout(title='Displacement per atom',
+ xaxis_title='Time step',
+ yaxis_title='Displacement (Å)')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_displacement_per_element.py b/src/gemdat/plots/plotly/_displacement_per_element.py
new file mode 100644
index 00000000..3b3068c8
--- /dev/null
+++ b/src/gemdat/plots/plotly/_displacement_per_element.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import plotly.graph_objects as go
+
+from gemdat.trajectory import Trajectory
+from gemdat.plots._shared import _mean_displacements_per_element
+
+
+def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
+ """Plot displacement per element.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ displacements = _mean_displacements_per_element(trajectory)
+
+ fig = go.Figure()
+
+ for symbol, (mean, std) in displacements.items():
+ fig.add_trace(
+ go.Scatter(y=mean,
+ name=symbol + ' + std',
+ mode='lines',
+ line={'width': 3},
+ legendgroup=symbol))
+ fig.add_trace(
+ go.Scatter(y=mean + std,
+ name=symbol + ' + std',
+ mode='lines',
+ line={'width': 0},
+ legendgroup=symbol,
+ showlegend=False))
+ fig.add_trace(
+ go.Scatter(y=mean - std,
+ name=symbol + ' + std',
+ mode='lines',
+ line={'width': 0},
+ legendgroup=symbol,
+ showlegend=False,
+ fill='tonexty'))
+
+ fig.update_layout(title='Displacement per element',
+ xaxis_title='Time step',
+ yaxis_title='Displacement (Å)')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_displacements.py b/src/gemdat/plots/plotly/_displacements.py
deleted file mode 100644
index 0ed56ed6..00000000
--- a/src/gemdat/plots/plotly/_displacements.py
+++ /dev/null
@@ -1,231 +0,0 @@
-from __future__ import annotations
-
-from collections import defaultdict
-
-import numpy as np
-import pandas as pd
-import plotly.express as px
-import plotly.graph_objects as go
-
-from gemdat.trajectory import Trajectory
-
-
-def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure:
- """Plot displacement per atom.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
-
- fig = go.Figure()
-
- distances = [dist for dist in trajectory.distances_from_base_position()]
-
- for i, distance in enumerate(distances):
- fig.add_trace(
- go.Scatter(y=distance,
- name=i,
- mode='lines',
- line={'width': 1},
- showlegend=False))
-
- fig.update_layout(title='Displacement per atom',
- xaxis_title='Time step',
- yaxis_title='Displacement (Angstrom)')
-
- return fig
-
-
-def displacement_per_element(*, trajectory: Trajectory) -> go.Figure:
- """Plot displacement per element.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
-
- fig = go.Figure()
-
- grouped = defaultdict(list)
-
- species = trajectory.species
-
- for sp, distances in zip(species,
- trajectory.distances_from_base_position()):
- grouped[sp.symbol].append(distances)
-
- for symbol, distances in grouped.items():
- mean_disp = np.mean(distances, axis=0)
- std_disp = np.std(distances, axis=0)
- fig.add_trace(
- go.Scatter(y=mean_disp,
- name=symbol + ' + std',
- mode='lines',
- line={'width': 3},
- legendgroup=symbol))
- fig.add_trace(
- go.Scatter(y=mean_disp + std_disp,
- name=symbol + ' + std',
- mode='lines',
- line={'width': 0},
- legendgroup=symbol,
- showlegend=False))
- fig.add_trace(
- go.Scatter(y=mean_disp - std_disp,
- name=symbol + ' + std',
- mode='lines',
- line={'width': 0},
- legendgroup=symbol,
- showlegend=False,
- fill='tonexty'))
-
- fig.update_layout(title='Displacement per element',
- xaxis_title='Time step',
- yaxis_title='Displacement (Angstrom)')
-
- return fig
-
-
-def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
- """Plot mean squared displacement per element.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
-
- fig = go.Figure()
-
- species = list(set(trajectory.species))
-
- # Since we want to plot in picosecond, we convert the time units
- time_ps = trajectory.time_step * 1e12
-
- for sp in species:
- traj = trajectory.filter(sp.symbol)
-
- msd = traj.mean_squared_displacement()
- msd_mean = np.mean(msd, axis=0)
- msd_std = np.std(msd, axis=0)
- t_values = np.arange(len(msd_mean)) * time_ps
-
- fig.add_trace(
- go.Scatter(x=t_values,
- y=msd_mean,
- error_y=dict(type='data',
- array=msd_std,
- width=0.1,
- thickness=0.1),
- name=sp.symbol,
- mode='lines',
- line={'width': 3},
- legendgroup=sp.symbol))
-
- fig.update_layout(title='Mean squared displacement per element',
- xaxis_title='Time lag [ps]',
- yaxis_title='MSD (Angstrom2)')
-
- return fig
-
-
-def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame:
- """_trajectory_to_dataframe.
-
- Parameters
- ----------
- trajectory : Trajectory
- trajectory
-
- Returns
- -------
- pd.DataFrame
- """
- data = []
- for specie, distance in zip(
- trajectory.species,
- trajectory.distances_from_base_position()[:, -1]):
- data.append((specie, round(distance)))
-
- df = pd.DataFrame(columns=['Element', 'Displacement'], data=data)
- df = df.groupby(['Displacement', 'Element'
- ]).size().reset_index().rename(columns={0: 'count'})
- return df
-
-
-def displacement_histogram(trajectory: Trajectory,
- n_parts: int = 1) -> go.Figure:
- """Plot histogram of total displacement at final timestep.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
- n_parts : int
- Plot error bars by dividing data into n parts
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
-
- if n_parts == 1:
- df = _trajectory_to_dataframe(trajectory)
-
- fig = px.bar(df,
- x='Displacement',
- y='count',
- color='Element',
- barmode='stack')
-
- fig.update_layout(title='Displacement per element',
- xaxis_title='Displacement (Angstrom)',
- yaxis_title='Nr. of atoms')
- else:
- interval = np.linspace(0, len(trajectory) - 1, n_parts + 1)
- dfs = [
- _trajectory_to_dataframe(part)
- for part in trajectory.split(n_parts)
- ]
-
- all_df = pd.concat(dfs)
-
- # Get the mean and standard deviation
- grouped = all_df.groupby(['Displacement', 'Element'])
- mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
- std = grouped.std().reset_index().rename(columns={'count': 'std'})
- df = mean.merge(std, how='inner')
-
- fig = px.bar(df,
- x='Displacement',
- y='mean',
- color='Element',
- error_y='std',
- barmode='group')
-
- fig.update_layout(
- title=
- f'Displacement per element after {int(interval[1]-interval[0])} timesteps',
- xaxis_title='Displacement (Angstrom)',
- yaxis_title='Nr. of atoms')
-
- return fig
diff --git a/src/gemdat/plots/plotly/_frequency_vs_occurence.py b/src/gemdat/plots/plotly/_frequency_vs_occurence.py
new file mode 100644
index 00000000..0984ee7f
--- /dev/null
+++ b/src/gemdat/plots/plotly/_frequency_vs_occurence.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import numpy as np
+import plotly.graph_objects as go
+
+from gemdat.simulation_metrics import SimulationMetrics
+from gemdat.trajectory import Trajectory
+
+
+def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
+ """Plot attempt frequency vs occurence.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory, i.e. for the diffusing atom
+
+ Returns
+ -------
+ fig : go.figure.Figure
+ Output figure
+ """
+ metrics = SimulationMetrics(trajectory)
+ speed = metrics.speed()
+
+ length = speed.shape[1]
+ half_length = length // 2 + 1
+
+ trans = np.fft.fft(speed)
+
+ two_sided = np.abs(trans / length)
+ one_sided = two_sided[:, :half_length]
+
+ fig = go.Figure()
+
+ f = trajectory.sampling_frequency * np.arange(half_length) / length
+
+ sum_freqs = np.sum(one_sided, axis=0)
+ smoothed = np.convolve(sum_freqs, np.ones(51), 'same') / 51
+ fig.add_trace(
+ go.Scatter(y=smoothed,
+ x=f,
+ mode='lines',
+ line={
+ 'width': 3,
+ 'color': 'blue'
+ },
+ showlegend=False))
+
+ y_max = np.max(sum_freqs)
+
+ attempt_freq, attempt_freq_std = metrics.attempt_frequency()
+
+ if attempt_freq:
+ fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})
+ if attempt_freq and attempt_freq_std:
+ fig.add_vline(x=attempt_freq + attempt_freq_std,
+ line={
+ 'width': 2,
+ 'color': 'red',
+ 'dash': 'dash'
+ })
+ fig.add_vline(x=attempt_freq - attempt_freq_std,
+ line={
+ 'width': 2,
+ 'color': 'red',
+ 'dash': 'dash'
+ })
+
+ fig.update_layout(title='Frequency vs Occurence',
+ xaxis_title='Frequency (Hz)',
+ yaxis_title='Occurrence (a.u.)',
+ xaxis_range=[-0.1e13, 2.5e13],
+ yaxis_range=[0, y_max],
+ width=600,
+ height=500)
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_jumps.py b/src/gemdat/plots/plotly/_jumps.py
deleted file mode 100644
index 8c17ae76..00000000
--- a/src/gemdat/plots/plotly/_jumps.py
+++ /dev/null
@@ -1,181 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-import numpy as np
-import pandas as pd
-import plotly.express as px
-import plotly.graph_objects as go
-
-if TYPE_CHECKING:
-
- from gemdat import Jumps
-
-
-def jumps_vs_distance(*,
- jumps: Jumps,
- jump_res: float = 0.1,
- n_parts: int = 1) -> go.Figure:
- """Plot jumps vs. distance histogram.
-
- Parameters
- ----------
- jumps : Jumps
- Input jumps data
- jump_res : float, optional
- Resolution of the bins in Angstrom
- n_parts : int
- Number of parts for error analysis
-
- Returns
- -------
- fig : plotly.graph_objects.Figure
- Output figure
- """
- sites = jumps.sites
- trajectory = jumps.trajectory
- lattice = trajectory.get_lattice()
-
- pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
-
- bin_max = (1 + pdist.max() // jump_res) * jump_res
- n_bins = int(bin_max / jump_res) + 1
- x = np.linspace(0, bin_max, n_bins)
-
- bin_idx = np.digitize(pdist, bins=x)
- data = []
- for transitions_part in jumps.split(n_parts=n_parts):
- counts = np.zeros_like(x)
- for idx, n in zip(bin_idx.flatten(),
- transitions_part.matrix().flatten()):
- counts[idx] += n
- for idx in range(n_bins):
- if counts[idx] > 0:
- data.append((x[idx], counts[idx]))
-
- df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
-
- grouped = df.groupby(['Displacement'])
- mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
- std = grouped.std().reset_index().rename(columns={'count': 'std'})
- df = mean.merge(std, how='inner')
-
- if n_parts == 1:
- fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
- else:
- fig = px.bar(df,
- x='Displacement',
- y='mean',
- error_y='std',
- barmode='stack')
-
- fig.update_layout(title='Jumps vs. Distance',
- xaxis_title='Distance (Angstrom)',
- yaxis_title='Number of jumps')
-
- return fig
-
-
-def jumps_vs_time(*,
- jumps: Jumps,
- bins: int = 8,
- n_parts: int = 1) -> go.Figure:
- """Plot jumps vs. distance histogram.
-
- Parameters
- ----------
- jumps : Jumps
- Input jumps data
- bins : int, optional
- Number of bins
- n_parts : int
- Number of parts for error analysis
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- Output figure
- """
- maxlen = len(jumps.trajectory) / n_parts
- binsize = maxlen / bins + 1
- data = []
-
- for jumps_part in jumps.split(n_parts=n_parts):
- data.append(
- np.histogram(jumps_part.data['start time'],
- bins=bins,
- range=(0., maxlen))[0])
-
- df = pd.DataFrame(data=data)
- columns = [binsize / 2 + binsize * col for col in range(bins)]
-
- mean = [df[col].mean() for col in df.columns]
- std = [df[col].std() for col in df.columns]
-
- df = pd.DataFrame(data=zip(columns, mean, std),
- columns=['time', 'count', 'std'])
-
- if n_parts > 1:
- fig = px.bar(df, x='time', y='count', error_y='std')
- else:
- fig = px.bar(df, x='time', y='count')
-
- fig.update_layout(bargap=0.2,
- title='Jumps vs. time',
- xaxis_title='Time (steps)',
- yaxis_title='Number of jumps')
-
- return fig
-
-
-def collective_jumps(*, jumps: Jumps) -> go.Figure:
- """Plot collective jumps per jump-type combination.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
-
- Returns
- -------
- fig : plotly.graph_objects.Figure
- Output figure
- """
-
- matrix = jumps.collective().site_pair_count_matrix()
- labels = jumps.collective().site_pair_count_matrix_labels()
-
- fig = px.imshow(matrix)
-
- ticks = list(range(len(labels)))
-
- fig.update_layout(xaxis={
- 'tickmode': 'array',
- 'tickvals': ticks,
- 'ticktext': labels
- },
- yaxis={
- 'tickmode': 'array',
- 'tickvals': ticks,
- 'ticktext': labels
- },
- title='Cooperative jumps per jump-type combination')
-
- return fig
-
-
-def jumps_3d(*, jumps: Jumps) -> go.Figure:
- """Plot jumps in 3D.
-
- Parameters
- ----------
- jumps : Jumps
- Input data
-
- Returns
- -------
- fig : plotly.graph_objects.Figure
- Output figure
- """
- from ._plot3d import plot_3d
- return plot_3d(jumps=jumps, structure=jumps.sites)
diff --git a/src/gemdat/plots/plotly/_jumps_3d.py b/src/gemdat/plots/plotly/_jumps_3d.py
new file mode 100644
index 00000000..fcad5ae5
--- /dev/null
+++ b/src/gemdat/plots/plotly/_jumps_3d.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import plotly.graph_objects as go
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_3d(*, jumps: Jumps) -> go.Figure:
+ """Plot jumps in 3D.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input data
+
+ Returns
+ -------
+ fig : plotly.graph_objects.Figure
+ Output figure
+ """
+ from ._plot3d import plot_3d
+ return plot_3d(jumps=jumps, structure=jumps.sites)
diff --git a/src/gemdat/plots/plotly/_jumps_vs_distance.py b/src/gemdat/plots/plotly/_jumps_vs_distance.py
new file mode 100644
index 00000000..3570ba77
--- /dev/null
+++ b/src/gemdat/plots/plotly/_jumps_vs_distance.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_vs_distance(*,
+ jumps: Jumps,
+ jump_res: float = 0.1,
+ n_parts: int = 1) -> go.Figure:
+ """Plot jumps vs. distance histogram.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input jumps data
+ jump_res : float, optional
+ Resolution of the bins in Angstrom
+ n_parts : int
+ Number of parts for error analysis
+
+ Returns
+ -------
+ fig : plotly.graph_objects.Figure
+ Output figure
+ """
+ sites = jumps.sites
+ trajectory = jumps.trajectory
+ lattice = trajectory.get_lattice()
+
+ pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords)
+
+ bin_max = (1 + pdist.max() // jump_res) * jump_res
+ n_bins = int(bin_max / jump_res) + 1
+ x = np.linspace(0, bin_max, n_bins)
+
+ bin_idx = np.digitize(pdist, bins=x)
+ data = []
+ for transitions_part in jumps.split(n_parts=n_parts):
+ counts = np.zeros_like(x)
+ for idx, n in zip(bin_idx.flatten(),
+ transitions_part.matrix().flatten()):
+ counts[idx] += n
+ for idx in range(n_bins):
+ if counts[idx] > 0:
+ data.append((x[idx], counts[idx]))
+
+ df = pd.DataFrame(data=data, columns=['Displacement', 'count'])
+
+ grouped = df.groupby(['Displacement'])
+ mean = grouped.mean().reset_index().rename(columns={'count': 'mean'})
+ std = grouped.std().reset_index().rename(columns={'count': 'std'})
+ df = mean.merge(std, how='inner')
+
+ if n_parts == 1:
+ fig = px.bar(df, x='Displacement', y='mean', barmode='stack')
+ else:
+ fig = px.bar(df,
+ x='Displacement',
+ y='mean',
+ error_y='std',
+ barmode='stack')
+
+ fig.update_layout(title='Jumps vs. Distance',
+ xaxis_title='Distance (Å)',
+ yaxis_title='Number of jumps')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_jumps_vs_time.py b/src/gemdat/plots/plotly/_jumps_vs_time.py
new file mode 100644
index 00000000..8116d1c6
--- /dev/null
+++ b/src/gemdat/plots/plotly/_jumps_vs_time.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+
+if TYPE_CHECKING:
+ from gemdat import Jumps
+
+
+def jumps_vs_time(*,
+ jumps: Jumps,
+ bins: int = 8,
+ n_parts: int = 1) -> go.Figure:
+ """Plot jumps vs. distance histogram.
+
+ Parameters
+ ----------
+ jumps : Jumps
+ Input jumps data
+ bins : int, optional
+ Number of bins
+ n_parts : int
+ Number of parts for error analysis
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+ maxlen = len(jumps.trajectory) / n_parts
+ binsize = maxlen / bins + 1
+ data = []
+
+ for jumps_part in jumps.split(n_parts=n_parts):
+ data.append(
+ np.histogram(jumps_part.data['start time'],
+ bins=bins,
+ range=(0., maxlen))[0])
+
+ df = pd.DataFrame(data=data)
+ columns = [binsize / 2 + binsize * col for col in range(bins)]
+
+ mean = [df[col].mean() for col in df.columns]
+ std = [df[col].std() for col in df.columns]
+
+ df = pd.DataFrame(data=zip(columns, mean, std),
+ columns=['time', 'count', 'std'])
+
+ if n_parts > 1:
+ fig = px.bar(df, x='time', y='count', error_y='std')
+ else:
+ fig = px.bar(df, x='time', y='count')
+
+ fig.update_layout(bargap=0.2,
+ title='Jumps vs. time',
+ xaxis_title='Time (steps)',
+ yaxis_title='Number of jumps')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_msd_per_element.py b/src/gemdat/plots/plotly/_msd_per_element.py
new file mode 100644
index 00000000..6b9ad729
--- /dev/null
+++ b/src/gemdat/plots/plotly/_msd_per_element.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import numpy as np
+import plotly.graph_objects as go
+
+from gemdat.trajectory import Trajectory
+
+
+def msd_per_element(*, trajectory: Trajectory) -> go.Figure:
+ """Plot mean squared displacement per element.
+
+ Parameters
+ ----------
+ trajectory : Trajectory
+ Input trajectory
+
+ Returns
+ -------
+ fig : matplotlib.figure.Figure
+ Output figure
+ """
+
+ fig = go.Figure()
+
+ species = list(set(trajectory.species))
+
+ # Since we want to plot in picosecond, we convert the time units
+ time_ps = trajectory.time_step * 1e12
+
+ for sp in species:
+ traj = trajectory.filter(sp.symbol)
+
+ msd = traj.mean_squared_displacement()
+ msd_mean = np.mean(msd, axis=0)
+ msd_std = np.std(msd, axis=0)
+ t_values = np.arange(len(msd_mean)) * time_ps
+
+ fig.add_trace(
+ go.Scatter(x=t_values,
+ y=msd_mean,
+ error_y=dict(type='data',
+ array=msd_std,
+ width=0.1,
+ thickness=0.1),
+ name=f'{sp.symbol} mean+std',
+ mode='lines',
+ line={'width': 3},
+ legendgroup=sp.symbol))
+
+ fig.update_layout(title='Mean squared displacement per element',
+ xaxis_title='Time lag (ps)',
+ yaxis_title=r'MSD (Å2)')
+
+ return fig
diff --git a/src/gemdat/plots/plotly/_plot3d.py b/src/gemdat/plots/plotly/_plot3d.py
index a919d0c9..4f5d6785 100644
--- a/src/gemdat/plots/plotly/_plot3d.py
+++ b/src/gemdat/plots/plotly/_plot3d.py
@@ -307,9 +307,9 @@ def update_layout(*,
'y': lattice.b * zoom,
'z': lattice.c * zoom,
},
- 'xaxis_title': 'X (Ångstrom)',
- 'yaxis_title': 'Y (Ångstrom)',
- 'zaxis_title': 'Z (Ångstrom)'
+ 'xaxis_title': 'X (Å)',
+ 'yaxis_title': 'Y (Å)',
+ 'zaxis_title': 'Z (Å)'
},
legend={
'orientation': 'h',
diff --git a/src/gemdat/plots/plotly/_vibration.py b/src/gemdat/plots/plotly/_vibrational_amplitudes.py
similarity index 52%
rename from src/gemdat/plots/plotly/_vibration.py
rename to src/gemdat/plots/plotly/_vibrational_amplitudes.py
index 05eb9796..d11d0e15 100644
--- a/src/gemdat/plots/plotly/_vibration.py
+++ b/src/gemdat/plots/plotly/_vibrational_amplitudes.py
@@ -10,77 +10,6 @@
from gemdat.trajectory import Trajectory
-def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure:
- """Plot attempt frequency vs occurence.
-
- Parameters
- ----------
- trajectory : Trajectory
- Input trajectory, i.e. for the diffusing atom
-
- Returns
- -------
- fig : go.figure.Figure
- Output figure
- """
- metrics = SimulationMetrics(trajectory)
- speed = metrics.speed()
-
- length = speed.shape[1]
- half_length = length // 2 + 1
-
- trans = np.fft.fft(speed)
-
- two_sided = np.abs(trans / length)
- one_sided = two_sided[:, :half_length]
-
- fig = go.Figure()
-
- f = trajectory.sampling_frequency * np.arange(half_length) / length
-
- sum_freqs = np.sum(one_sided, axis=0)
- smoothed = np.convolve(sum_freqs, np.ones(51), 'same') / 51
- fig.add_trace(
- go.Scatter(y=smoothed,
- x=f,
- mode='lines',
- line={
- 'width': 3,
- 'color': 'blue'
- },
- showlegend=False))
-
- y_max = np.max(sum_freqs)
-
- attempt_freq, attempt_freq_std = metrics.attempt_frequency()
-
- if attempt_freq:
- fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'})
- if attempt_freq and attempt_freq_std:
- fig.add_vline(x=attempt_freq + attempt_freq_std,
- line={
- 'width': 2,
- 'color': 'red',
- 'dash': 'dash'
- })
- fig.add_vline(x=attempt_freq - attempt_freq_std,
- line={
- 'width': 2,
- 'color': 'red',
- 'dash': 'dash'
- })
-
- fig.update_layout(title='Frequency vs Occurence',
- xaxis_title='Frequency (Hz)',
- yaxis_title='Occurrence (a.u.)',
- xaxis_range=[-0.1e13, 2.5e13],
- yaxis_range=[0, y_max],
- width=600,
- height=500)
-
- return fig
-
-
def vibrational_amplitudes(*,
trajectory: Trajectory,
bins: int = 50,
@@ -149,7 +78,7 @@ def vibrational_amplitudes(*,
fig.update_layout(
title='Histogram of vibrational amplitudes with fitted Gaussian',
- xaxis_title='Amplitude (Ångstrom)',
+ xaxis_title='Amplitude (Å)',
yaxis_title='Occurrence (a.u.)')
return fig
diff --git a/tests/integration/baseline_images/plot_test/autocorrelation.png b/tests/integration/baseline_images/plot_test/autocorrelation.png
new file mode 100644
index 00000000..c8afd429
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/autocorrelation.png differ
diff --git a/tests/integration/baseline_images/plot_test/bond_length_distribution.png b/tests/integration/baseline_images/plot_test/bond_length_distribution.png
index 59475ea8..51320cc6 100644
Binary files a/tests/integration/baseline_images/plot_test/bond_length_distribution.png and b/tests/integration/baseline_images/plot_test/bond_length_distribution.png differ
diff --git a/tests/integration/baseline_images/plot_test/displacement_histogram.png b/tests/integration/baseline_images/plot_test/displacement_histogram.png
index 1fa9ae53..7e84f4ba 100644
Binary files a/tests/integration/baseline_images/plot_test/displacement_histogram.png and b/tests/integration/baseline_images/plot_test/displacement_histogram.png differ
diff --git a/tests/integration/baseline_images/plot_test/displacement_per_atom.png b/tests/integration/baseline_images/plot_test/displacement_per_atom.png
index 1f2b6e33..45e22f7c 100644
Binary files a/tests/integration/baseline_images/plot_test/displacement_per_atom.png and b/tests/integration/baseline_images/plot_test/displacement_per_atom.png differ
diff --git a/tests/integration/baseline_images/plot_test/displacement_per_element.png b/tests/integration/baseline_images/plot_test/displacement_per_element.png
index 406cfbc1..38a36fdb 100644
Binary files a/tests/integration/baseline_images/plot_test/displacement_per_element.png and b/tests/integration/baseline_images/plot_test/displacement_per_element.png differ
diff --git a/tests/integration/baseline_images/plot_test/energy_along_path.png b/tests/integration/baseline_images/plot_test/energy_along_path.png
new file mode 100644
index 00000000..fc8f3e28
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/energy_along_path.png differ
diff --git a/tests/integration/baseline_images/plot_test/jumps_3d.png b/tests/integration/baseline_images/plot_test/jumps_3d.png
index acef3067..bbaa4597 100644
Binary files a/tests/integration/baseline_images/plot_test/jumps_3d.png and b/tests/integration/baseline_images/plot_test/jumps_3d.png differ
diff --git a/tests/integration/baseline_images/plot_test/jumps_3d_animation.png b/tests/integration/baseline_images/plot_test/jumps_3d_animation.png
index 2cc85801..5e0ff8b2 100644
Binary files a/tests/integration/baseline_images/plot_test/jumps_3d_animation.png and b/tests/integration/baseline_images/plot_test/jumps_3d_animation.png differ
diff --git a/tests/integration/baseline_images/plot_test/jumps_vs_distance.png b/tests/integration/baseline_images/plot_test/jumps_vs_distance.png
index 2eabe155..28d78e5c 100644
Binary files a/tests/integration/baseline_images/plot_test/jumps_vs_distance.png and b/tests/integration/baseline_images/plot_test/jumps_vs_distance.png differ
diff --git a/tests/integration/baseline_images/plot_test/msd.png b/tests/integration/baseline_images/plot_test/msd.png
deleted file mode 100644
index fd479f23..00000000
Binary files a/tests/integration/baseline_images/plot_test/msd.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/msd_per_element.png b/tests/integration/baseline_images/plot_test/msd_per_element.png
new file mode 100644
index 00000000..89a8ed6a
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/msd_per_element.png differ
diff --git a/tests/integration/baseline_images/plot_test/path_energy.png b/tests/integration/baseline_images/plot_test/path_energy.png
deleted file mode 100644
index 08494d4b..00000000
Binary files a/tests/integration/baseline_images/plot_test/path_energy.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_1.png b/tests/integration/baseline_images/plot_test/radial_distribution_1.png
new file mode 100644
index 00000000..6ace96a0
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_1.png differ
diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_2.png b/tests/integration/baseline_images/plot_test/radial_distribution_2.png
new file mode 100644
index 00000000..8a0fab15
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_2.png differ
diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_3.png b/tests/integration/baseline_images/plot_test/radial_distribution_3.png
new file mode 100644
index 00000000..8daa103e
Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_3.png differ
diff --git a/tests/integration/baseline_images/plot_test/rdf1.png b/tests/integration/baseline_images/plot_test/rdf1.png
deleted file mode 100644
index ef78e5ee..00000000
Binary files a/tests/integration/baseline_images/plot_test/rdf1.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/rdf2.png b/tests/integration/baseline_images/plot_test/rdf2.png
deleted file mode 100644
index 9e5a7011..00000000
Binary files a/tests/integration/baseline_images/plot_test/rdf2.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/rdf3.png b/tests/integration/baseline_images/plot_test/rdf3.png
deleted file mode 100644
index 463e47df..00000000
Binary files a/tests/integration/baseline_images/plot_test/rdf3.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/rectilinear.png b/tests/integration/baseline_images/plot_test/rectilinear.png
index 2fd5123b..fa21df9b 100644
Binary files a/tests/integration/baseline_images/plot_test/rectilinear.png and b/tests/integration/baseline_images/plot_test/rectilinear.png differ
diff --git a/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png b/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png
deleted file mode 100644
index d7cb667c..00000000
Binary files a/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png and /dev/null differ
diff --git a/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png b/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png
index 1624bb8e..1571ca6f 100644
Binary files a/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png and b/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png differ
diff --git a/tests/integration/plot_test.py b/tests/integration/plot_test.py
index ab4a7a19..ca2ecd8b 100644
--- a/tests/integration/plot_test.py
+++ b/tests/integration/plot_test.py
@@ -61,7 +61,9 @@ def test_jumps_3d_animation(vasp_jumps):
plots.jumps_3d_animation(jumps=vasp_jumps, t_start=1000, t_stop=1001)
-@image_comparison2(baseline_images=['rdf1', 'rdf2', 'rdf3'])
+@image_comparison2(baseline_images=[
+ 'radial_distribution_1', 'radial_distribution_2', 'radial_distribution_3'
+])
def test_rdf(vasp_rdf_data):
assert len(vasp_rdf_data) == 3
for rdfs in vasp_rdf_data.values():
@@ -75,12 +77,12 @@ def test_shape(vasp_shape_data):
plots.shape(shape)
-@image_comparison2(baseline_images=['msd'])
+@image_comparison2(baseline_images=['msd_per_element'])
def test_msd_per_element(vasp_traj):
plots.msd_per_element(trajectory=vasp_traj[-500:])
-@image_comparison2(baseline_images=['path_energy'])
+@image_comparison2(baseline_images=['energy_along_path'])
def test_path_energy(vasp_path):
structure = load_known_material('argyrodite')
plots.energy_along_path(path=vasp_path, structure=structure)
@@ -101,6 +103,6 @@ def test_bond_length_distribution(vasp_orientations):
vasp_orientations.plot_bond_length_distribution(bins=1000)
-@image_comparison2(baseline_images=['unit_vector_autocorrelation'])
+@image_comparison2(baseline_images=['autocorrelation'])
def test_unit_vector_autocorrelation(vasp_orientations):
vasp_orientations.plot_autocorrelation()
diff --git a/tests/plots_tests.py b/tests/plots_tests.py
new file mode 100644
index 00000000..45306ae5
--- /dev/null
+++ b/tests/plots_tests.py
@@ -0,0 +1,8 @@
+def test_matplotlib_imports():
+ from gemdat.plots import matplotlib
+ assert matplotlib.__all__
+
+
+def test_plotly_imports():
+ from gemdat.plots import plotly
+ assert plotly.__all__