diff --git a/src/gemdat/plots/_shared.py b/src/gemdat/plots/_shared.py new file mode 100644 index 00000000..4da40845 --- /dev/null +++ b/src/gemdat/plots/_shared.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import numpy as np +from typing import TYPE_CHECKING +from collections import defaultdict +if TYPE_CHECKING: + from gemdat.trajectory import Trajectory + + +def _mean_displacements_per_element( + trajectory: Trajectory) -> dict[str, tuple[np.ndarray, np.ndarray]]: + """Calculate mean displacements per element type. + + Helper for `displacement_per_atom`. + """ + species = trajectory.species + + grouped = defaultdict(list) + for sp, distances in zip(species, + trajectory.distances_from_base_position()): + grouped[sp.symbol].append(distances) + + means = {} + for sp, dists in grouped.items(): + mean = np.mean(dists, axis=0) + std = np.std(dists, axis=0) + + means[sp] = (mean, std) + + return means diff --git a/src/gemdat/plots/matplotlib/__init__.py b/src/gemdat/plots/matplotlib/__init__.py index cc90af17..ba17f569 100644 --- a/src/gemdat/plots/matplotlib/__init__.py +++ b/src/gemdat/plots/matplotlib/__init__.py @@ -2,34 +2,29 @@ matplotlib.""" from __future__ import annotations -from ._displacements import ( - displacement_histogram, - displacement_per_atom, - displacement_per_element, - msd_per_element, -) -from ._jumps import ( - collective_jumps, - jumps_3d, - jumps_3d_animation, - jumps_vs_distance, - jumps_vs_time, -) -from ._paths import ( - energy_along_path, - path_on_grid, -) -from ._rdf import radial_distribution -from ._orientations import ( - autocorrelation, - bond_length_distribution, - rectilinear, -) +from ._displacement_histogram import displacement_histogram +from ._displacement_per_atom import displacement_per_atom +from ._displacement_per_element import displacement_per_element +from ._msd_per_element import msd_per_element + +from ._collective_jumps import collective_jumps +from ._jumps_3d import jumps_3d +from ._jumps_3d_animation import jumps_3d_animation +from ._jumps_vs_distance import jumps_vs_distance +from ._jumps_vs_time import jumps_vs_time + +from ._energy_along_path import energy_along_path +from ._path_on_grid import path_on_grid +from ._radial_distribution import radial_distribution + +from ._autocorrelation import autocorrelation +from ._bond_length_distribution import bond_length_distribution +from ._rectilinear import rectilinear + from ._shape import shape -from ._vibration import ( - frequency_vs_occurence, - vibrational_amplitudes, -) + +from ._frequency_vs_occurence import frequency_vs_occurence +from ._vibrational_amplitudes import vibrational_amplitudes __all__ = [ 'bond_length_distribution', diff --git a/src/gemdat/plots/matplotlib/_autocorrelation.py b/src/gemdat/plots/matplotlib/_autocorrelation.py new file mode 100644 index 00000000..1d9770c7 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_autocorrelation.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np + +from gemdat.orientations import ( + Orientations, ) + + +def autocorrelation( + *, + orientations: Orientations, + show_traces: bool = True, +) -> plt.Figure: + """Plot the autocorrelation function of the unit vectors series. + + Parameters + ---------- + orientations : Orientations + The unit vector trajectories + show_traces : bool + If True, show traces of individual trajectories + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + ac = orientations.autocorrelation() + ac_std = ac.std(axis=0) + ac_mean = ac.mean(axis=0) + + # Since we want to plot in picosecond, we convert the time units + time_ps = orientations._time_step * 1e12 + tgrid = np.arange(ac_mean.shape[0]) * time_ps + + fig, ax = plt.subplots() + + ax.plot(tgrid, ac_mean, label='FFT autocorrelation') + + last_color = ax.lines[-1].get_color() + + if show_traces: + for i, ac_i in enumerate(ac): + label = 'Trajectories' if (i == 0) else None + ax.plot(tgrid, ac_i, lw=0.1, c=last_color, label=label) + + ax.fill_between(tgrid, + ac_mean - ac_std, + ac_mean + ac_std, + alpha=0.2, + label='Standard deviation') + ax.set_xlabel('Time lag (ps)') + ax.set_ylabel('Autocorrelation') + ax.legend() + + return fig diff --git a/src/gemdat/plots/matplotlib/_bond_length_distribution.py b/src/gemdat/plots/matplotlib/_bond_length_distribution.py new file mode 100644 index 00000000..94f62812 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_bond_length_distribution.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +from scipy.optimize import curve_fit +from scipy.stats import skewnorm + +from gemdat.orientations import Orientations + + +def bond_length_distribution(*, + orientations: Orientations, + bins: int = 1000) -> plt.Figure: + """Plot the bond length probability distribution. + + Parameters + ---------- + orientations : Orientations + The unit vector trajectories + bins : int, optional + The number of bins, by default 1000 + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + *_, bond_lengths = orientations.vectors_spherical.T + bond_lengths = bond_lengths.flatten() + + fig, ax = plt.subplots() + + # Plot the normalized histogram + hist, edges = np.histogram(bond_lengths, bins=bins, density=True) + bin_centers = (edges[:-1] + edges[1:]) / 2 + + # Fit a skewed Gaussian distribution to the orientations + params, covariance = curve_fit(skewnorm.pdf, + bin_centers, + hist, + p0=[1.5, 1, 1.5]) + + # Create a new function using the fitted parameters + def _skewnorm_fit(x): + return skewnorm.pdf(x, *params) + + # Plot the histogram + ax.hist(bond_lengths, + bins=bins, + density=True, + color='blue', + alpha=0.7, + label='Data') + + # Plot the fitted skewed Gaussian distribution + x_fit = np.linspace(min(bin_centers), max(bin_centers), 1000) + ax.plot(x_fit, _skewnorm_fit(x_fit), 'r-', label='Skewed Gaussian Fit') + + ax.set_xlabel('Bond length (Å)') + ax.set_ylabel(r'Probability density (Å$^{-1}$)') + ax.set_title('Bond Length Probability Distribution') + ax.legend() + ax.grid(True) + + return fig diff --git a/src/gemdat/plots/matplotlib/_collective_jumps.py b/src/gemdat/plots/matplotlib/_collective_jumps.py new file mode 100644 index 00000000..cebb8418 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_collective_jumps.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt + +if TYPE_CHECKING: + from gemdat import Jumps + + +def collective_jumps(*, jumps: Jumps) -> plt.Figure: + """Plot collective jumps per jump-type combination. + + Parameters + ---------- + jumps : Jumps + Input data + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + fig, ax = plt.subplots() + + collective = jumps.collective() + + matrix = collective.site_pair_count_matrix() + + im = ax.imshow(matrix) + + labels = collective.site_pair_count_matrix_labels() + ticks = range(len(labels)) + + ax.set_xticks(ticks, labels=labels, rotation=90) + ax.set_yticks(ticks, labels=labels) + + fig.colorbar(im, ax=ax) + + ax.set(title='Cooperative jumps per jump-type combination') + + return fig diff --git a/src/gemdat/plots/matplotlib/_displacement_histogram.py b/src/gemdat/plots/matplotlib/_displacement_histogram.py new file mode 100644 index 00000000..af214536 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_displacement_histogram.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt + +from gemdat.trajectory import Trajectory + + +def displacement_histogram(trajectory: Trajectory) -> plt.Figure: + """Plot histogram of total displacement at final timestep. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + fig, ax = plt.subplots() + ax.hist(trajectory.distances_from_base_position()[:, -1]) + ax.set(title='Displacement per element', + xlabel='Displacement (Å)', + ylabel='Nr. of atoms') + + return fig diff --git a/src/gemdat/plots/matplotlib/_displacement_per_atom.py b/src/gemdat/plots/matplotlib/_displacement_per_atom.py new file mode 100644 index 00000000..2b33ae76 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_displacement_per_atom.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt + +from gemdat.trajectory import Trajectory + + +def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure: + """Plot displacement per atom. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + fig, ax = plt.subplots() + + for distances in trajectory.distances_from_base_position(): + ax.plot(distances, lw=0.3) + + ax.set(title='Displacement per site', + xlabel='Time step', + ylabel='Displacement (Å)') + + return fig diff --git a/src/gemdat/plots/matplotlib/_displacement_per_element.py b/src/gemdat/plots/matplotlib/_displacement_per_element.py new file mode 100644 index 00000000..626c7248 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_displacement_per_element.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt + +from gemdat.trajectory import Trajectory + +from gemdat.plots._shared import _mean_displacements_per_element + + +def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure: + """Plot displacement per element. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + displacements = _mean_displacements_per_element(trajectory) + + fig, ax = plt.subplots() + + for symbol, (mean, _) in displacements.items(): + ax.plot(mean, lw=0.3, label=symbol) + + ax.legend() + ax.set(title='Displacement per element', + xlabel='Time step', + ylabel='Displacement (Å)') + + return fig diff --git a/src/gemdat/plots/matplotlib/_displacements.py b/src/gemdat/plots/matplotlib/_displacements.py deleted file mode 100644 index 6508db4c..00000000 --- a/src/gemdat/plots/matplotlib/_displacements.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict - -import matplotlib.pyplot as plt -import numpy as np - -from gemdat.trajectory import Trajectory - - -def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure: - """Plot displacement per atom. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - fig, ax = plt.subplots() - - for distances in trajectory.distances_from_base_position(): - ax.plot(distances, lw=0.3) - - ax.set(title='Displacement per site', - xlabel='Time step', - ylabel='Displacement (Angstrom)') - - return fig - - -def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure: - """Plot displacement per element. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - grouped = defaultdict(list) - - species = trajectory.species - - for sp, distances in zip(species, - trajectory.distances_from_base_position()): - grouped[sp.symbol].append(distances) - - fig, ax = plt.subplots() - - for symbol, distances in grouped.items(): - mean_disp = np.mean(distances, axis=0) - ax.plot(mean_disp, lw=0.3, label=symbol) - - ax.legend() - ax.set(title='Displacement per element', - xlabel='Time step', - ylabel='Displacement (Angstrom)') - - return fig - - -def msd_per_element( - *, - trajectory: Trajectory, -) -> plt.Figure: - """Plot mean squared displacement per element. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - species = list(set(trajectory.species)) - - fig, ax = plt.subplots() - - # Since we want to plot in picosecond, we convert the time units - time_ps = trajectory.time_step * 1e12 - - for sp in species: - traj = trajectory.filter(sp.symbol) - msd = traj.mean_squared_displacement() - msd_mean = np.mean(msd, axis=0) - msd_std = np.std(msd, axis=0) - t_values = np.arange(len(msd_mean)) * time_ps - ax.plot(t_values, msd_mean, lw=0.5, label=sp.symbol) - last_color = ax.lines[-1].get_color() - ax.fill_between(t_values, - msd_mean - msd_std, - msd_mean + msd_std, - color=last_color, - alpha=0.2) - - ax.legend() - ax.set(title='Mean squared displacement per element', - xlabel='Time lag [ps]', - ylabel='MSD (Angstrom$^2$)') - - return fig - - -def displacement_histogram(trajectory: Trajectory) -> plt.Figure: - """Plot histogram of total displacement at final timestep. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - fig, ax = plt.subplots() - ax.hist(trajectory.distances_from_base_position()[:, -1]) - ax.set(title='Histogram of displacements', - xlabel='Displacement (Angstrom)', - ylabel='Nr. of atoms') - - return fig diff --git a/src/gemdat/plots/matplotlib/_paths.py b/src/gemdat/plots/matplotlib/_energy_along_path.py similarity index 75% rename from src/gemdat/plots/matplotlib/_paths.py rename to src/gemdat/plots/matplotlib/_energy_along_path.py index bb40fa0a..370248c2 100644 --- a/src/gemdat/plots/matplotlib/_paths.py +++ b/src/gemdat/plots/matplotlib/_energy_along_path.py @@ -32,7 +32,7 @@ def energy_along_path( fig, ax = plt.subplots(figsize=(8, 4)) ax.plot(path.energy, marker='o', color='r', label='Optimal path') - ax.set(ylabel='Free energy [eV]') + ax.set(ylabel='Free energy (eV)') nearest_sites = path.path_over_structure(structure) @@ -98,43 +98,3 @@ def energy_along_path( ax.legend(fontsize=8) return fig - - -def path_on_grid(path: Pathway) -> plt.Figure: - """Plot the 3d coordinates of the points that define a path. - - Parameters - ---------- - path : Pathway - Pathway to plot - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - # Create a colormap to visualize the path - colormap = plt.get_cmap() - normalize = plt.Normalize(0, len(path.energy)) - - fig, ax = plt.subplots() - ax = fig.add_subplot(111, projection='3d') - - path_x, path_y, path_z = zip(*path.sites) - - for i in range(len(path.energy) - 1): - ax.plot(path_x[i:i + 1], - path_y[i:i + 1], - path_z[i:i + 1], - color=colormap(normalize(i)), - marker='o', - linestyle='-') - - ax.set_xlabel('X') - ax.set_ylabel('Y') - sm = plt.cm.ScalarMappable(cmap=colormap, norm=normalize) - sm.set_array([]) - cbar = plt.colorbar(sm, ax=ax) - cbar.set_label('Steps') - - return fig diff --git a/src/gemdat/plots/matplotlib/_vibration.py b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py similarity index 66% rename from src/gemdat/plots/matplotlib/_vibration.py rename to src/gemdat/plots/matplotlib/_frequency_vs_occurence.py index e5e2bb7e..c89531f2 100644 --- a/src/gemdat/plots/matplotlib/_vibration.py +++ b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import numpy as np -from scipy import stats from gemdat.simulation_metrics import SimulationMetrics from gemdat.trajectory import Trajectory @@ -62,32 +61,3 @@ def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure: ax.set_xlim([-0.1e13, 2.5e13]) return fig - - -def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure: - """Plot histogram of vibrational amplitudes with fitted Gaussian. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - metrics = SimulationMetrics(trajectory) - - fig, ax = plt.subplots() - ax.hist(metrics.amplitudes(), bins=100, density=True) - - x = np.linspace(-2, 2, 100) - y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude()) - ax.plot(x, y_gauss, 'r') - - ax.set(title='Histogram of vibrational amplitudes with fitted Gaussian', - xlabel='Amplitude (Angstrom)', - ylabel='Occurrence (a.u.)') - - return fig diff --git a/src/gemdat/plots/matplotlib/_jumps.py b/src/gemdat/plots/matplotlib/_jumps.py deleted file mode 100644 index f5db9306..00000000 --- a/src/gemdat/plots/matplotlib/_jumps.py +++ /dev/null @@ -1,347 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import matplotlib.animation as animation -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import colormaps -from pymatgen.electronic_structure import plotter - -if TYPE_CHECKING: - - from gemdat import Jumps - - -def jumps_vs_distance( - *, - jumps: Jumps, - jump_res: float = 0.1, -) -> plt.Figure: - """Plot jumps vs. distance histogram. - - Parameters - ---------- - jumps : Jumps - Input data - jump_res : float, optional - Resolution of the bins in Angstrom - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - sites = jumps.sites - - trajectory = jumps.trajectory - lattice = trajectory.get_lattice() - - pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords) - - bin_max = (1 + pdist.max() // jump_res) * jump_res - n_bins = int(bin_max / jump_res) + 1 - x = np.linspace(0, bin_max, n_bins) - counts = np.zeros_like(x) - - bin_idx = np.digitize(pdist, bins=x) - for idx, n in zip(bin_idx.flatten(), jumps.matrix().flatten()): - counts[idx] += n - - fig, ax = plt.subplots() - - ax.bar(x, counts, width=(jump_res * 0.8)) - - ax.set(title='Jumps vs. Distance', - xlabel='Distance (Angstrom)', - ylabel='Number of jumps') - - return fig - - -def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure: - """Plot jumps vs. time histogram. - - Parameters - ---------- - jumps : Jumps - Input data - binsize : int, optional - Width of each bin in number of time steps - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - - trajectory = jumps.trajectory - - n_steps = len(trajectory) - bins = np.arange(0, n_steps + binsize, binsize) - - fig, ax = plt.subplots() - - ax.hist(jumps.data['stop time'], bins=bins, width=0.8 * binsize) - - ax.set(title='Jumps vs. time', - xlabel='Time (steps)', - ylabel='Number of jumps') - - return fig - - -def collective_jumps(*, jumps: Jumps) -> plt.Figure: - """Plot collective jumps per jump-type combination. - - Parameters - ---------- - jumps : Jumps - Input data - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - fig, ax = plt.subplots() - matrix = jumps.collective().site_pair_count_matrix() - labels = jumps.collective().site_pair_count_matrix_labels() - - mat = ax.imshow(matrix) - - ticks = range(len(labels)) - - ax.set_xticks(ticks, labels=labels, rotation=90) - ax.set_yticks(ticks, labels=labels) - - fig.colorbar(mat, ax=ax) - - ax.set(title='Cooperative jumps per jump-type combination') - - return fig - - -def jumps_3d(*, jumps: Jumps) -> plt.Figure: - """Plot jumps in 3D. - - Parameters - ---------- - jumps : Jumps - Input data - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - trajectory = jumps.trajectory - sites = jumps.sites - - class LabelItems: - - def __init__(self, labels, coords): - self.labels = labels - self.coords = coords - - def items(self): - yield from zip(self.labels, self.coords) - - coords = sites.frac_coords - lattice = trajectory.get_lattice() - - fig = plt.figure() - ax = fig.add_subplot(projection='3d') - - site_labels = LabelItems(jumps.sites.labels, coords) - - xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1], - [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]]) - - plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1) - plotter.plot_labels(xyz_labels, - lattice=lattice, - ax=ax, - color='green', - size=12) - plotter.plot_points(coords, lattice=lattice, ax=ax) - - for i, j in zip(*np.triu_indices(len(coords), k=1)): - count = jumps.matrix()[i, j] + jumps.matrix()[j, i] - if count == 0: - continue - - coord_i = coords[i] - coord_j = coords[j] - - lw = 1 + np.log(count) - - length, image = lattice.get_distance_and_image(coord_i, coord_j) - - # NOTE: might need to plot `line = [coord_i - image, coord_j]` as well - if np.any(image != 0): - lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)] - else: - lines = [(coord_i, coord_j)] - - for line in lines: - plotter.plot_path(line, - lattice=lattice, - ax=ax, - color='red', - linewidth=lw) - - plotter.plot_labels(site_labels, - lattice=lattice, - ax=ax, - color='black', - size=8) - - ax.set( - title='Jumps between sites', - xlabel="x' (ang)", - ylabel="y' (ang)", - zlabel="z' (ang)", - ) - - ax.set_aspect('equal') # only auto is supported - - return fig - - -def jumps_3d_animation( - *, - jumps: Jumps, - t_start: int, - t_stop: int, - decay: float = 0.05, - skip: int = 5, - interval: int = 20, -) -> animation.FuncAnimation: - """Plot jumps in 3D as an animation over time. - - Parameters - ---------- - jumps : Jumps - Input data - t_start : int - Time step to start animation (relative to equilibration time) - t_stop : int - Time step to stop animation (relative to equilibration time) - decay : float, optional - Controls the decay of the line width (higher = faster decay) - skip : float, optional - Skip frames (increase for faster, but less accurate rendering) - interval : int, optional - Delay between frames in milliseconds. - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - minwidth = 0.2 - maxwidth = 5.0 - - trajectory = jumps.trajectory - - class LabelItems: - - def __init__(self, labels, coords): - self.labels = labels - self.coords = coords - - def items(self): - yield from zip(self.labels, self.coords) - - coords = jumps.sites.frac_coords - lattice = trajectory.get_lattice() - - color_from = colormaps['Set1'].colors # type: ignore - color_to = colormaps['Pastel1'].colors # type: ignore - - fig = plt.figure() - ax = fig.add_subplot(projection='3d') - - xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1], - [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]]) - - plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1) - - plotter.plot_labels(xyz_labels, - lattice=lattice, - ax=ax, - color='green', - size=12) - - assert len(ax.collections) == 0 - plotter.plot_points(coords, - lattice=lattice, - ax=ax, - s=50, - color='white', - edgecolor='black') - points = ax.collections - - events = jumps.data.sort_values('start time', ignore_index=True) - - for _, event in events.iterrows(): - site_i = event['start site'] - site_j = event['destination site'] - - coord_i = coords[site_i] - coord_j = coords[site_j] - - lw = 0 - - _, image = lattice.get_distance_and_image(coord_i, coord_j) - - line = [coord_i, coord_j + image] - - plotter.plot_path(line, - lattice=lattice, - ax=ax, - color='red', - linewidth=lw) - - lines = ax.lines[3:] - - ax.set( - title='Jumps between sites', - xlabel="x' (ang)", - ylabel="y' (ang)", - zlabel="z' (ang)", - ) - - ax.set_aspect('equal') # only auto is supported - - def update(frame_no): - t_frame = t_start + (frame_no * skip) - - for i, event in events.iterrows(): - - if event['start time'] > t_frame: - break - - lw = max(maxwidth - decay * (t_frame - event['start time']), - minwidth) - - line = lines[i] - line.set_color('red') - line.set_linewidth(lw) - - points[event['start site']].set_facecolor( - color_from[event['atom index'] % len(color_from)]) - points[event['destination site']].set_facecolor( - color_to[event['atom index'] % len(color_to)]) - - start_time = event['start time'] - ax.set_title(f'T: {t_frame} | Next jump: {start_time}') - - n_frames = int((t_stop - t_start) / skip) - - return animation.FuncAnimation(fig=fig, - func=update, - frames=n_frames, - interval=interval, - repeat=False) diff --git a/src/gemdat/plots/matplotlib/_jumps_3d.py b/src/gemdat/plots/matplotlib/_jumps_3d.py new file mode 100644 index 00000000..4a540519 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_jumps_3d.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np +from pymatgen.electronic_structure import plotter + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_3d(*, jumps: Jumps) -> plt.Figure: + """Plot jumps in 3D. + + Parameters + ---------- + jumps : Jumps + Input data + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + trajectory = jumps.trajectory + sites = jumps.sites + + class LabelItems: + + def __init__(self, labels, coords): + self.labels = labels + self.coords = coords + + def items(self): + yield from zip(self.labels, self.coords) + + coords = sites.frac_coords + lattice = trajectory.get_lattice() + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + site_labels = LabelItems(jumps.sites.labels, coords) + + xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1], + [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]]) + + plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1) + plotter.plot_labels(xyz_labels, + lattice=lattice, + ax=ax, + color='green', + size=12) + plotter.plot_points(coords, lattice=lattice, ax=ax) + + for i, j in zip(*np.triu_indices(len(coords), k=1)): + count = jumps.matrix()[i, j] + jumps.matrix()[j, i] + if count == 0: + continue + + coord_i = coords[i] + coord_j = coords[j] + + lw = 1 + np.log(count) + + length, image = lattice.get_distance_and_image(coord_i, coord_j) + + # NOTE: might need to plot `line = [coord_i - image, coord_j]` as well + if np.any(image != 0): + lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)] + else: + lines = [(coord_i, coord_j)] + + for line in lines: + plotter.plot_path(line, + lattice=lattice, + ax=ax, + color='red', + linewidth=lw) + + plotter.plot_labels(site_labels, + lattice=lattice, + ax=ax, + color='black', + size=8) + + ax.set( + title='Jumps between sites', + xlabel="x' (Å)", + ylabel="y' (Å)", + zlabel="z' (Å)", + ) + + ax.set_aspect('equal') # only auto is supported + + return fig diff --git a/src/gemdat/plots/matplotlib/_jumps_3d_animation.py b/src/gemdat/plots/matplotlib/_jumps_3d_animation.py new file mode 100644 index 00000000..74031e95 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_jumps_3d_animation.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.animation as animation +import matplotlib.pyplot as plt +from matplotlib import colormaps +from pymatgen.electronic_structure import plotter + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_3d_animation( + *, + jumps: Jumps, + t_start: int, + t_stop: int, + decay: float = 0.05, + skip: int = 5, + interval: int = 20, +) -> animation.FuncAnimation: + """Plot jumps in 3D as an animation over time. + + Parameters + ---------- + jumps : Jumps + Input data + t_start : int + Time step to start animation (relative to equilibration time) + t_stop : int + Time step to stop animation (relative to equilibration time) + decay : float, optional + Controls the decay of the line width (higher = faster decay) + skip : float, optional + Skip frames (increase for faster, but less accurate rendering) + interval : int, optional + Delay between frames in milliseconds. + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + minwidth = 0.2 + maxwidth = 5.0 + + trajectory = jumps.trajectory + + class LabelItems: + + def __init__(self, labels, coords): + self.labels = labels + self.coords = coords + + def items(self): + yield from zip(self.labels, self.coords) + + coords = jumps.sites.frac_coords + lattice = trajectory.get_lattice() + + color_from = colormaps['Set1'].colors # type: ignore + color_to = colormaps['Pastel1'].colors # type: ignore + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + xyz_labels = LabelItems('OABC', [[-0.1, -0.1, -0.1], [1.1, -0.1, -0.1], + [-0.1, 1.1, -0.1], [-0.1, -0.1, 1.1]]) + + plotter.plot_lattice_vectors(lattice, ax=ax, linewidth=1) + + plotter.plot_labels(xyz_labels, + lattice=lattice, + ax=ax, + color='green', + size=12) + + assert len(ax.collections) == 0 + plotter.plot_points(coords, + lattice=lattice, + ax=ax, + s=50, + color='white', + edgecolor='black') + points = ax.collections + + events = jumps.data.sort_values('start time', ignore_index=True) + + for _, event in events.iterrows(): + site_i = event['start site'] + site_j = event['destination site'] + + coord_i = coords[site_i] + coord_j = coords[site_j] + + lw = 0 + + _, image = lattice.get_distance_and_image(coord_i, coord_j) + + line = [coord_i, coord_j + image] + + plotter.plot_path(line, + lattice=lattice, + ax=ax, + color='red', + linewidth=lw) + + lines = ax.lines[3:] + + ax.set( + title='Jumps between sites', + xlabel="x' (Å)", + ylabel="y' (Å)", + zlabel="z' (Å)", + ) + + ax.set_aspect('equal') # only auto is supported + + def update(frame_no): + t_frame = t_start + (frame_no * skip) + + for i, event in events.iterrows(): + + if event['start time'] > t_frame: + break + + lw = max(maxwidth - decay * (t_frame - event['start time']), + minwidth) + + line = lines[i] + line.set_color('red') + line.set_linewidth(lw) + + points[event['start site']].set_facecolor( + color_from[event['atom index'] % len(color_from)]) + points[event['destination site']].set_facecolor( + color_to[event['atom index'] % len(color_to)]) + + start_time = event['start time'] + ax.set_title(f'T: {t_frame} | Next jump: {start_time}') + + n_frames = int((t_stop - t_start) / skip) + + return animation.FuncAnimation(fig=fig, + func=update, + frames=n_frames, + interval=interval, + repeat=False) diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_distance.py b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py new file mode 100644 index 00000000..ba6b1057 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_vs_distance( + *, + jumps: Jumps, + jump_res: float = 0.1, +) -> plt.Figure: + """Plot jumps vs. distance histogram. + + Parameters + ---------- + jumps : Jumps + Input data + jump_res : float, optional + Resolution of the bins in Angstrom + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + sites = jumps.sites + + trajectory = jumps.trajectory + lattice = trajectory.get_lattice() + + pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords) + + bin_max = (1 + pdist.max() // jump_res) * jump_res + n_bins = int(bin_max / jump_res) + 1 + x = np.linspace(0, bin_max, n_bins) + counts = np.zeros_like(x) + + bin_idx = np.digitize(pdist, bins=x) + for idx, n in zip(bin_idx.flatten(), jumps.matrix().flatten()): + counts[idx] += n + + fig, ax = plt.subplots() + + ax.bar(x, counts, width=(jump_res * 0.8)) + + ax.set(title='Jumps vs. Distance', + xlabel='Distance (Å)', + ylabel='Number of jumps') + + return fig diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_time.py b/src/gemdat/plots/matplotlib/_jumps_vs_time.py new file mode 100644 index 00000000..d0ac2a21 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_jumps_vs_time.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure: + """Plot jumps vs. time histogram. + + Parameters + ---------- + jumps : Jumps + Input data + binsize : int, optional + Width of each bin in number of time steps + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + + trajectory = jumps.trajectory + + n_steps = len(trajectory) + bins = np.arange(0, n_steps + binsize, binsize) + + fig, ax = plt.subplots() + + ax.hist(jumps.data['stop time'], bins=bins, width=0.8 * binsize) + + ax.set(title='Jumps vs. time', + xlabel='Time (steps)', + ylabel='Number of jumps') + + return fig diff --git a/src/gemdat/plots/matplotlib/_msd_per_element.py b/src/gemdat/plots/matplotlib/_msd_per_element.py new file mode 100644 index 00000000..f60d2b3e --- /dev/null +++ b/src/gemdat/plots/matplotlib/_msd_per_element.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np + +from gemdat.trajectory import Trajectory + + +def msd_per_element( + *, + trajectory: Trajectory, + show_traces: bool = True, +) -> plt.Figure: + """Plot mean squared displacement per element. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory + show_traces : bool + If True, show individual traces for each element + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + species = list(set(trajectory.species)) + + fig, ax = plt.subplots() + + # Since we want to plot in picosecond, we convert the time units + time_ps = trajectory.time_step * 1e12 + t_values = np.arange(len(trajectory)) * time_ps + + for sp in species: + traj = trajectory.filter(sp.symbol) + msd = traj.mean_squared_displacement() + + msd_mean = np.mean(msd, axis=0) + msd_std = np.std(msd, axis=0) + + ax.plot(t_values, msd_mean, lw=0.5, label=f'{sp.symbol} mean') + + last_color = ax.lines[-1].get_color() + + if show_traces: + for i, traj in enumerate(msd): + label = f'{sp.symbol} trajectories' if (i == 0) else None + ax.plot(t_values, traj, lw=0.1, c=last_color, label=label) + + ax.fill_between(t_values, + msd_mean - msd_std, + msd_mean + msd_std, + color=last_color, + alpha=0.2, + label=f'{sp.symbol} std') + + ax.legend() + ax.set(title='Mean squared displacement per element', + xlabel='Time lag (ps)', + ylabel='MSD (Å$^2$)') + + return fig diff --git a/src/gemdat/plots/matplotlib/_orientations.py b/src/gemdat/plots/matplotlib/_orientations.py deleted file mode 100644 index e1647bea..00000000 --- a/src/gemdat/plots/matplotlib/_orientations.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -import matplotlib.pyplot as plt -import numpy as np -from scipy.optimize import curve_fit -from scipy.stats import skewnorm - -from gemdat.orientations import ( - Orientations, - calculate_spherical_areas, -) - - -def rectilinear(*, - orientations: Orientations, - shape: tuple[int, int] = (90, 360), - normalize_histo: bool = True) -> plt.Figure: - """Plot a rectilinear projection of a spherical function. This function - uses the transformed trajectory. - - Parameters - ---------- - orientations : Orientations - The unit vector trajectories - shape : tuple - The shape of the spherical sector in which the trajectory is plotted - normalize_histo : bool, optional - If True, normalize the histogram by the area of the bins, by default True - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - # Convert the vectors to spherical coordinates - az, el, _ = orientations.vectors_spherical.T - az = az.flatten() - el = el.flatten() - - hist, xedges, yedges = np.histogram2d(el, az, shape) - - if normalize_histo: - # Normalize by the area of the bins - areas = calculate_spherical_areas(shape) - hist = np.divide(hist, areas) - # Drop the bins at the poles where normalization is not possible - hist = hist[1:-1, :] - - values = hist.T - axis_phi, axis_theta = values.shape - - phi = np.linspace(0, 360, axis_phi) - theta = np.linspace(0, 180, axis_theta) - - theta, phi = np.meshgrid(theta, phi) - - fig, ax = plt.subplots(subplot_kw=dict(projection='rectilinear')) - cs = ax.contourf(phi, theta, values, cmap='viridis') - ax.set_yticks(np.arange(0, 190, 45)) - ax.set_xticks(np.arange(0, 370, 45)) - - ax.set_xlabel(r'azimuthal angle φ $[\degree$]') - ax.set_ylabel(r'elevation θ $[\degree$]') - - ax.grid(visible=True) - cbar = fig.colorbar(cs, label='areal probability', format='') - - # Rotate the colorbar label by 180 degrees - cbar.ax.yaxis.set_label_coords(2.5, - 0.5) # Adjust the position of the label - cbar.set_label('areal probability', rotation=270, labelpad=15) - return fig - - -def bond_length_distribution(*, - orientations: Orientations, - bins: int = 1000) -> plt.Figure: - """Plot the bond length probability distribution. - - Parameters - ---------- - orientations : Orientations - The unit vector trajectories - bins : int, optional - The number of bins, by default 1000 - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - *_, bond_lengths = orientations.vectors_spherical.T - bond_lengths = bond_lengths.flatten() - - fig, ax = plt.subplots() - - # Plot the normalized histogram - hist, edges = np.histogram(bond_lengths, bins=bins, density=True) - bin_centers = (edges[:-1] + edges[1:]) / 2 - - # Fit a skewed Gaussian distribution to the orientations - params, covariance = curve_fit(skewnorm.pdf, - bin_centers, - hist, - p0=[1.5, 1, 1.5]) - - # Create a new function using the fitted parameters - def _skewnorm_fit(x): - return skewnorm.pdf(x, *params) - - # Plot the histogram - ax.hist(bond_lengths, - bins=bins, - density=True, - color='blue', - alpha=0.7, - label='Data') - - # Plot the fitted skewed Gaussian distribution - x_fit = np.linspace(min(bin_centers), max(bin_centers), 1000) - ax.plot(x_fit, _skewnorm_fit(x_fit), 'r-', label='Skewed Gaussian Fit') - - ax.set_xlabel(r'Bond length $[\AA]$') - ax.set_ylabel(r'Probability density $[\AA^{-1}]$') - ax.set_title('Bond Length Probability Distribution') - ax.legend() - ax.grid(True) - - return fig - - -def autocorrelation( - *, - orientations: Orientations, -) -> plt.Figure: - """Plot the autocorrelation function of the unit vectors series. - - Parameters - ---------- - orientations : Orientations - The unit vector trajectories - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - ac = orientations.autocorrelation() - ac_std = ac.std(axis=0) - ac_mean = ac.mean(axis=0) - - # Since we want to plot in picosecond, we convert the time units - time_ps = orientations._time_step * 1e12 - tgrid = np.arange(ac_mean.shape[0]) * time_ps - - # and now we can plot the autocorrelation function - fig, ax = plt.subplots() - - ax.plot(tgrid, ac_mean, label='FFT-Autocorrelation') - ax.fill_between(tgrid, ac_mean - ac_std, ac_mean + ac_std, alpha=0.2) - ax.set_xlabel('Time lag [ps]') - ax.set_ylabel('Autocorrelation') - - return fig diff --git a/src/gemdat/plots/matplotlib/_path_on_grid.py b/src/gemdat/plots/matplotlib/_path_on_grid.py new file mode 100644 index 00000000..1412942c --- /dev/null +++ b/src/gemdat/plots/matplotlib/_path_on_grid.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt + +from gemdat.path import Pathway + + +def path_on_grid(path: Pathway) -> plt.Figure: + """Plot the 3d coordinates of the points that define a path. + + Parameters + ---------- + path : Pathway + Pathway to plot + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + # Create a colormap to visualize the path + colormap = plt.get_cmap() + normalize = plt.Normalize(0, len(path.energy)) + + fig, ax = plt.subplots() + ax = fig.add_subplot(111, projection='3d') + + path_x, path_y, path_z = zip(*path.sites) + + for i in range(len(path.energy) - 1): + ax.plot(path_x[i:i + 1], + path_y[i:i + 1], + path_z[i:i + 1], + color=colormap(normalize(i)), + marker='o', + linestyle='-') + + ax.set_xlabel('X') + ax.set_ylabel('Y') + sm = plt.cm.ScalarMappable(cmap=colormap, norm=normalize) + sm.set_array([]) + cbar = plt.colorbar(sm, ax=ax) + cbar.set_label('Steps') + + return fig diff --git a/src/gemdat/plots/matplotlib/_rdf.py b/src/gemdat/plots/matplotlib/_radial_distribution.py similarity index 95% rename from src/gemdat/plots/matplotlib/_rdf.py rename to src/gemdat/plots/matplotlib/_radial_distribution.py index 22564e7f..fc75d37b 100644 --- a/src/gemdat/plots/matplotlib/_rdf.py +++ b/src/gemdat/plots/matplotlib/_radial_distribution.py @@ -30,7 +30,7 @@ def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure: ax.legend() ax.set(title=f'Radial distribution function ({states})', - xlabel='Distance (Ang)', + xlabel='Distance (Å)', ylabel='Counts') return fig diff --git a/src/gemdat/plots/matplotlib/_rectilinear.py b/src/gemdat/plots/matplotlib/_rectilinear.py new file mode 100644 index 00000000..fd12001b --- /dev/null +++ b/src/gemdat/plots/matplotlib/_rectilinear.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np + +from gemdat.orientations import ( + Orientations, + calculate_spherical_areas, +) + + +def rectilinear(*, + orientations: Orientations, + shape: tuple[int, int] = (90, 360), + normalize_histo: bool = True) -> plt.Figure: + """Plot a rectilinear projection of a spherical function. This function + uses the transformed trajectory. + + Parameters + ---------- + orientations : Orientations + The unit vector trajectories + shape : tuple + The shape of the spherical sector in which the trajectory is plotted + normalize_histo : bool, optional + If True, normalize the histogram by the area of the bins, by default True + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + # Convert the vectors to spherical coordinates + az, el, _ = orientations.vectors_spherical.T + az = az.flatten() + el = el.flatten() + + hist, xedges, yedges = np.histogram2d(el, az, shape) + + if normalize_histo: + # Normalize by the area of the bins + areas = calculate_spherical_areas(shape) + hist = np.divide(hist, areas) + # Drop the bins at the poles where normalization is not possible + hist = hist[1:-1, :] + + values = hist.T + axis_phi, axis_theta = values.shape + + phi = np.linspace(0, 360, axis_phi) + theta = np.linspace(0, 180, axis_theta) + + theta, phi = np.meshgrid(theta, phi) + + fig, ax = plt.subplots(subplot_kw=dict(projection='rectilinear')) + cs = ax.contourf(phi, theta, values, cmap='viridis') + ax.set_yticks(np.arange(0, 190, 45)) + ax.set_xticks(np.arange(0, 370, 45)) + + ax.set_xlabel('Azimuthal angle φ (°)') + ax.set_ylabel('Elevation θ (°)') + + ax.grid(visible=True) + cbar = fig.colorbar(cs, label='Areal probability', format='') + + # Rotate the colorbar label by 180 degrees + cbar.ax.yaxis.set_label_coords(2.5, + 0.5) # Adjust the position of the label + cbar.set_label('Areal probability', rotation=270, labelpad=15) + return fig diff --git a/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py new file mode 100644 index 00000000..64e12cd2 --- /dev/null +++ b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +from scipy import stats + +from gemdat.simulation_metrics import SimulationMetrics +from gemdat.trajectory import Trajectory + + +def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure: + """Plot histogram of vibrational amplitudes with fitted Gaussian. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + metrics = SimulationMetrics(trajectory) + + fig, ax = plt.subplots() + ax.hist(metrics.amplitudes(), bins=100, density=True) + + x = np.linspace(-2, 2, 100) + y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude()) + ax.plot(x, y_gauss, 'r') + + ax.set(title='Histogram of vibrational amplitudes with fitted Gaussian', + xlabel='Amplitude (Å)', + ylabel='Occurrence (a.u.)') + + return fig diff --git a/src/gemdat/plots/plotly/__init__.py b/src/gemdat/plots/plotly/__init__.py index 22fc1156..9f94a00f 100644 --- a/src/gemdat/plots/plotly/__init__.py +++ b/src/gemdat/plots/plotly/__init__.py @@ -2,23 +2,17 @@ from __future__ import annotations from ._density import density -from ._displacements import ( - displacement_histogram, - displacement_per_atom, - displacement_per_element, - msd_per_element, -) -from ._jumps import ( - collective_jumps, - jumps_3d, - jumps_vs_distance, - jumps_vs_time, -) +from ._displacement_histogram import displacement_histogram +from ._displacement_per_atom import displacement_per_atom +from ._displacement_per_element import displacement_per_element +from ._msd_per_element import msd_per_element +from ._collective_jumps import collective_jumps +from ._jumps_3d import jumps_3d +from ._jumps_vs_distance import jumps_vs_distance +from ._jumps_vs_time import jumps_vs_time from ._plot3d import plot_3d -from ._vibration import ( - frequency_vs_occurence, - vibrational_amplitudes, -) +from ._frequency_vs_occurence import frequency_vs_occurence +from ._vibrational_amplitudes import vibrational_amplitudes __all__ = [ 'collective_jumps', diff --git a/src/gemdat/plots/plotly/_collective_jumps.py b/src/gemdat/plots/plotly/_collective_jumps.py new file mode 100644 index 00000000..6bbc310a --- /dev/null +++ b/src/gemdat/plots/plotly/_collective_jumps.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import plotly.express as px +import plotly.graph_objects as go + +if TYPE_CHECKING: + from gemdat import Jumps + + +def collective_jumps(*, jumps: Jumps) -> go.Figure: + """Plot collective jumps per jump-type combination. + + Parameters + ---------- + jumps : Jumps + Input data + + Returns + ------- + fig : plotly.graph_objects.Figure + Output figure + """ + collective = jumps.collective() + matrix = collective.site_pair_count_matrix() + + fig = px.imshow(matrix) + + labels = collective.site_pair_count_matrix_labels() + + ticks = list(range(len(labels))) + + fig.update_layout(xaxis={ + 'tickmode': 'array', + 'tickvals': ticks, + 'ticktext': labels + }, + yaxis={ + 'tickmode': 'array', + 'tickvals': ticks, + 'ticktext': labels + }, + title='Cooperative jumps per jump-type combination') + + return fig diff --git a/src/gemdat/plots/plotly/_displacement_histogram.py b/src/gemdat/plots/plotly/_displacement_histogram.py new file mode 100644 index 00000000..3ea1d608 --- /dev/null +++ b/src/gemdat/plots/plotly/_displacement_histogram.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go + +from gemdat.trajectory import Trajectory + + +def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame: + """_trajectory_to_dataframe. + + Parameters + ---------- + trajectory : Trajectory + trajectory + + Returns + ------- + pd.DataFrame + """ + data = [] + for specie, distance in zip( + trajectory.species, + trajectory.distances_from_base_position()[:, -1]): + data.append((specie, round(distance))) + + df = pd.DataFrame(columns=['Element', 'Displacement'], data=data) + df = df.groupby(['Displacement', 'Element' + ]).size().reset_index().rename(columns={0: 'count'}) + return df + + +def displacement_histogram(trajectory: Trajectory, + n_parts: int = 1) -> go.Figure: + """Plot histogram of total displacement at final timestep. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + n_parts : int + Plot error bars by dividing data into n parts + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + + if n_parts == 1: + df = _trajectory_to_dataframe(trajectory) + + fig = px.bar(df, + x='Displacement', + y='count', + color='Element', + barmode='stack') + + fig.update_layout(title='Displacement per element', + xaxis_title='Displacement (Å)', + yaxis_title='Nr. of atoms') + else: + interval = np.linspace(0, len(trajectory) - 1, n_parts + 1) + dfs = [ + _trajectory_to_dataframe(part) + for part in trajectory.split(n_parts) + ] + + all_df = pd.concat(dfs) + + # Get the mean and standard deviation + grouped = all_df.groupby(['Displacement', 'Element']) + mean = grouped.mean().reset_index().rename(columns={'count': 'mean'}) + std = grouped.std().reset_index().rename(columns={'count': 'std'}) + df = mean.merge(std, how='inner') + + fig = px.bar(df, + x='Displacement', + y='mean', + color='Element', + error_y='std', + barmode='group') + + fig.update_layout( + title= + f'Displacement per element after {int(interval[1]-interval[0])} timesteps', + xaxis_title='Displacement (Å)', + yaxis_title='Nr. of atoms') + + return fig diff --git a/src/gemdat/plots/plotly/_displacement_per_atom.py b/src/gemdat/plots/plotly/_displacement_per_atom.py new file mode 100644 index 00000000..b0213f6b --- /dev/null +++ b/src/gemdat/plots/plotly/_displacement_per_atom.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import plotly.graph_objects as go + +from gemdat.trajectory import Trajectory + + +def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure: + """Plot displacement per atom. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + + fig = go.Figure() + + distances = [dist for dist in trajectory.distances_from_base_position()] + + for i, distance in enumerate(distances): + fig.add_trace( + go.Scatter(y=distance, + name=i, + mode='lines', + line={'width': 1}, + showlegend=False)) + + fig.update_layout(title='Displacement per atom', + xaxis_title='Time step', + yaxis_title='Displacement (Å)') + + return fig diff --git a/src/gemdat/plots/plotly/_displacement_per_element.py b/src/gemdat/plots/plotly/_displacement_per_element.py new file mode 100644 index 00000000..3b3068c8 --- /dev/null +++ b/src/gemdat/plots/plotly/_displacement_per_element.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import plotly.graph_objects as go + +from gemdat.trajectory import Trajectory +from gemdat.plots._shared import _mean_displacements_per_element + + +def displacement_per_element(*, trajectory: Trajectory) -> go.Figure: + """Plot displacement per element. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + displacements = _mean_displacements_per_element(trajectory) + + fig = go.Figure() + + for symbol, (mean, std) in displacements.items(): + fig.add_trace( + go.Scatter(y=mean, + name=symbol + ' + std', + mode='lines', + line={'width': 3}, + legendgroup=symbol)) + fig.add_trace( + go.Scatter(y=mean + std, + name=symbol + ' + std', + mode='lines', + line={'width': 0}, + legendgroup=symbol, + showlegend=False)) + fig.add_trace( + go.Scatter(y=mean - std, + name=symbol + ' + std', + mode='lines', + line={'width': 0}, + legendgroup=symbol, + showlegend=False, + fill='tonexty')) + + fig.update_layout(title='Displacement per element', + xaxis_title='Time step', + yaxis_title='Displacement (Å)') + + return fig diff --git a/src/gemdat/plots/plotly/_displacements.py b/src/gemdat/plots/plotly/_displacements.py deleted file mode 100644 index 0ed56ed6..00000000 --- a/src/gemdat/plots/plotly/_displacements.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -from collections import defaultdict - -import numpy as np -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go - -from gemdat.trajectory import Trajectory - - -def displacement_per_atom(*, trajectory: Trajectory) -> go.Figure: - """Plot displacement per atom. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - - fig = go.Figure() - - distances = [dist for dist in trajectory.distances_from_base_position()] - - for i, distance in enumerate(distances): - fig.add_trace( - go.Scatter(y=distance, - name=i, - mode='lines', - line={'width': 1}, - showlegend=False)) - - fig.update_layout(title='Displacement per atom', - xaxis_title='Time step', - yaxis_title='Displacement (Angstrom)') - - return fig - - -def displacement_per_element(*, trajectory: Trajectory) -> go.Figure: - """Plot displacement per element. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - - fig = go.Figure() - - grouped = defaultdict(list) - - species = trajectory.species - - for sp, distances in zip(species, - trajectory.distances_from_base_position()): - grouped[sp.symbol].append(distances) - - for symbol, distances in grouped.items(): - mean_disp = np.mean(distances, axis=0) - std_disp = np.std(distances, axis=0) - fig.add_trace( - go.Scatter(y=mean_disp, - name=symbol + ' + std', - mode='lines', - line={'width': 3}, - legendgroup=symbol)) - fig.add_trace( - go.Scatter(y=mean_disp + std_disp, - name=symbol + ' + std', - mode='lines', - line={'width': 0}, - legendgroup=symbol, - showlegend=False)) - fig.add_trace( - go.Scatter(y=mean_disp - std_disp, - name=symbol + ' + std', - mode='lines', - line={'width': 0}, - legendgroup=symbol, - showlegend=False, - fill='tonexty')) - - fig.update_layout(title='Displacement per element', - xaxis_title='Time step', - yaxis_title='Displacement (Angstrom)') - - return fig - - -def msd_per_element(*, trajectory: Trajectory) -> go.Figure: - """Plot mean squared displacement per element. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - - fig = go.Figure() - - species = list(set(trajectory.species)) - - # Since we want to plot in picosecond, we convert the time units - time_ps = trajectory.time_step * 1e12 - - for sp in species: - traj = trajectory.filter(sp.symbol) - - msd = traj.mean_squared_displacement() - msd_mean = np.mean(msd, axis=0) - msd_std = np.std(msd, axis=0) - t_values = np.arange(len(msd_mean)) * time_ps - - fig.add_trace( - go.Scatter(x=t_values, - y=msd_mean, - error_y=dict(type='data', - array=msd_std, - width=0.1, - thickness=0.1), - name=sp.symbol, - mode='lines', - line={'width': 3}, - legendgroup=sp.symbol)) - - fig.update_layout(title='Mean squared displacement per element', - xaxis_title='Time lag [ps]', - yaxis_title='MSD (Angstrom2)') - - return fig - - -def _trajectory_to_dataframe(trajectory: Trajectory) -> pd.DataFrame: - """_trajectory_to_dataframe. - - Parameters - ---------- - trajectory : Trajectory - trajectory - - Returns - ------- - pd.DataFrame - """ - data = [] - for specie, distance in zip( - trajectory.species, - trajectory.distances_from_base_position()[:, -1]): - data.append((specie, round(distance))) - - df = pd.DataFrame(columns=['Element', 'Displacement'], data=data) - df = df.groupby(['Displacement', 'Element' - ]).size().reset_index().rename(columns={0: 'count'}) - return df - - -def displacement_histogram(trajectory: Trajectory, - n_parts: int = 1) -> go.Figure: - """Plot histogram of total displacement at final timestep. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - n_parts : int - Plot error bars by dividing data into n parts - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - - if n_parts == 1: - df = _trajectory_to_dataframe(trajectory) - - fig = px.bar(df, - x='Displacement', - y='count', - color='Element', - barmode='stack') - - fig.update_layout(title='Displacement per element', - xaxis_title='Displacement (Angstrom)', - yaxis_title='Nr. of atoms') - else: - interval = np.linspace(0, len(trajectory) - 1, n_parts + 1) - dfs = [ - _trajectory_to_dataframe(part) - for part in trajectory.split(n_parts) - ] - - all_df = pd.concat(dfs) - - # Get the mean and standard deviation - grouped = all_df.groupby(['Displacement', 'Element']) - mean = grouped.mean().reset_index().rename(columns={'count': 'mean'}) - std = grouped.std().reset_index().rename(columns={'count': 'std'}) - df = mean.merge(std, how='inner') - - fig = px.bar(df, - x='Displacement', - y='mean', - color='Element', - error_y='std', - barmode='group') - - fig.update_layout( - title= - f'Displacement per element after {int(interval[1]-interval[0])} timesteps', - xaxis_title='Displacement (Angstrom)', - yaxis_title='Nr. of atoms') - - return fig diff --git a/src/gemdat/plots/plotly/_frequency_vs_occurence.py b/src/gemdat/plots/plotly/_frequency_vs_occurence.py new file mode 100644 index 00000000..0984ee7f --- /dev/null +++ b/src/gemdat/plots/plotly/_frequency_vs_occurence.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import numpy as np +import plotly.graph_objects as go + +from gemdat.simulation_metrics import SimulationMetrics +from gemdat.trajectory import Trajectory + + +def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure: + """Plot attempt frequency vs occurence. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory, i.e. for the diffusing atom + + Returns + ------- + fig : go.figure.Figure + Output figure + """ + metrics = SimulationMetrics(trajectory) + speed = metrics.speed() + + length = speed.shape[1] + half_length = length // 2 + 1 + + trans = np.fft.fft(speed) + + two_sided = np.abs(trans / length) + one_sided = two_sided[:, :half_length] + + fig = go.Figure() + + f = trajectory.sampling_frequency * np.arange(half_length) / length + + sum_freqs = np.sum(one_sided, axis=0) + smoothed = np.convolve(sum_freqs, np.ones(51), 'same') / 51 + fig.add_trace( + go.Scatter(y=smoothed, + x=f, + mode='lines', + line={ + 'width': 3, + 'color': 'blue' + }, + showlegend=False)) + + y_max = np.max(sum_freqs) + + attempt_freq, attempt_freq_std = metrics.attempt_frequency() + + if attempt_freq: + fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'}) + if attempt_freq and attempt_freq_std: + fig.add_vline(x=attempt_freq + attempt_freq_std, + line={ + 'width': 2, + 'color': 'red', + 'dash': 'dash' + }) + fig.add_vline(x=attempt_freq - attempt_freq_std, + line={ + 'width': 2, + 'color': 'red', + 'dash': 'dash' + }) + + fig.update_layout(title='Frequency vs Occurence', + xaxis_title='Frequency (Hz)', + yaxis_title='Occurrence (a.u.)', + xaxis_range=[-0.1e13, 2.5e13], + yaxis_range=[0, y_max], + width=600, + height=500) + + return fig diff --git a/src/gemdat/plots/plotly/_jumps.py b/src/gemdat/plots/plotly/_jumps.py deleted file mode 100644 index 8c17ae76..00000000 --- a/src/gemdat/plots/plotly/_jumps.py +++ /dev/null @@ -1,181 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go - -if TYPE_CHECKING: - - from gemdat import Jumps - - -def jumps_vs_distance(*, - jumps: Jumps, - jump_res: float = 0.1, - n_parts: int = 1) -> go.Figure: - """Plot jumps vs. distance histogram. - - Parameters - ---------- - jumps : Jumps - Input jumps data - jump_res : float, optional - Resolution of the bins in Angstrom - n_parts : int - Number of parts for error analysis - - Returns - ------- - fig : plotly.graph_objects.Figure - Output figure - """ - sites = jumps.sites - trajectory = jumps.trajectory - lattice = trajectory.get_lattice() - - pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords) - - bin_max = (1 + pdist.max() // jump_res) * jump_res - n_bins = int(bin_max / jump_res) + 1 - x = np.linspace(0, bin_max, n_bins) - - bin_idx = np.digitize(pdist, bins=x) - data = [] - for transitions_part in jumps.split(n_parts=n_parts): - counts = np.zeros_like(x) - for idx, n in zip(bin_idx.flatten(), - transitions_part.matrix().flatten()): - counts[idx] += n - for idx in range(n_bins): - if counts[idx] > 0: - data.append((x[idx], counts[idx])) - - df = pd.DataFrame(data=data, columns=['Displacement', 'count']) - - grouped = df.groupby(['Displacement']) - mean = grouped.mean().reset_index().rename(columns={'count': 'mean'}) - std = grouped.std().reset_index().rename(columns={'count': 'std'}) - df = mean.merge(std, how='inner') - - if n_parts == 1: - fig = px.bar(df, x='Displacement', y='mean', barmode='stack') - else: - fig = px.bar(df, - x='Displacement', - y='mean', - error_y='std', - barmode='stack') - - fig.update_layout(title='Jumps vs. Distance', - xaxis_title='Distance (Angstrom)', - yaxis_title='Number of jumps') - - return fig - - -def jumps_vs_time(*, - jumps: Jumps, - bins: int = 8, - n_parts: int = 1) -> go.Figure: - """Plot jumps vs. distance histogram. - - Parameters - ---------- - jumps : Jumps - Input jumps data - bins : int, optional - Number of bins - n_parts : int - Number of parts for error analysis - - Returns - ------- - fig : matplotlib.figure.Figure - Output figure - """ - maxlen = len(jumps.trajectory) / n_parts - binsize = maxlen / bins + 1 - data = [] - - for jumps_part in jumps.split(n_parts=n_parts): - data.append( - np.histogram(jumps_part.data['start time'], - bins=bins, - range=(0., maxlen))[0]) - - df = pd.DataFrame(data=data) - columns = [binsize / 2 + binsize * col for col in range(bins)] - - mean = [df[col].mean() for col in df.columns] - std = [df[col].std() for col in df.columns] - - df = pd.DataFrame(data=zip(columns, mean, std), - columns=['time', 'count', 'std']) - - if n_parts > 1: - fig = px.bar(df, x='time', y='count', error_y='std') - else: - fig = px.bar(df, x='time', y='count') - - fig.update_layout(bargap=0.2, - title='Jumps vs. time', - xaxis_title='Time (steps)', - yaxis_title='Number of jumps') - - return fig - - -def collective_jumps(*, jumps: Jumps) -> go.Figure: - """Plot collective jumps per jump-type combination. - - Parameters - ---------- - jumps : Jumps - Input data - - Returns - ------- - fig : plotly.graph_objects.Figure - Output figure - """ - - matrix = jumps.collective().site_pair_count_matrix() - labels = jumps.collective().site_pair_count_matrix_labels() - - fig = px.imshow(matrix) - - ticks = list(range(len(labels))) - - fig.update_layout(xaxis={ - 'tickmode': 'array', - 'tickvals': ticks, - 'ticktext': labels - }, - yaxis={ - 'tickmode': 'array', - 'tickvals': ticks, - 'ticktext': labels - }, - title='Cooperative jumps per jump-type combination') - - return fig - - -def jumps_3d(*, jumps: Jumps) -> go.Figure: - """Plot jumps in 3D. - - Parameters - ---------- - jumps : Jumps - Input data - - Returns - ------- - fig : plotly.graph_objects.Figure - Output figure - """ - from ._plot3d import plot_3d - return plot_3d(jumps=jumps, structure=jumps.sites) diff --git a/src/gemdat/plots/plotly/_jumps_3d.py b/src/gemdat/plots/plotly/_jumps_3d.py new file mode 100644 index 00000000..fcad5ae5 --- /dev/null +++ b/src/gemdat/plots/plotly/_jumps_3d.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import plotly.graph_objects as go + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_3d(*, jumps: Jumps) -> go.Figure: + """Plot jumps in 3D. + + Parameters + ---------- + jumps : Jumps + Input data + + Returns + ------- + fig : plotly.graph_objects.Figure + Output figure + """ + from ._plot3d import plot_3d + return plot_3d(jumps=jumps, structure=jumps.sites) diff --git a/src/gemdat/plots/plotly/_jumps_vs_distance.py b/src/gemdat/plots/plotly/_jumps_vs_distance.py new file mode 100644 index 00000000..3570ba77 --- /dev/null +++ b/src/gemdat/plots/plotly/_jumps_vs_distance.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_vs_distance(*, + jumps: Jumps, + jump_res: float = 0.1, + n_parts: int = 1) -> go.Figure: + """Plot jumps vs. distance histogram. + + Parameters + ---------- + jumps : Jumps + Input jumps data + jump_res : float, optional + Resolution of the bins in Angstrom + n_parts : int + Number of parts for error analysis + + Returns + ------- + fig : plotly.graph_objects.Figure + Output figure + """ + sites = jumps.sites + trajectory = jumps.trajectory + lattice = trajectory.get_lattice() + + pdist = lattice.get_all_distances(sites.frac_coords, sites.frac_coords) + + bin_max = (1 + pdist.max() // jump_res) * jump_res + n_bins = int(bin_max / jump_res) + 1 + x = np.linspace(0, bin_max, n_bins) + + bin_idx = np.digitize(pdist, bins=x) + data = [] + for transitions_part in jumps.split(n_parts=n_parts): + counts = np.zeros_like(x) + for idx, n in zip(bin_idx.flatten(), + transitions_part.matrix().flatten()): + counts[idx] += n + for idx in range(n_bins): + if counts[idx] > 0: + data.append((x[idx], counts[idx])) + + df = pd.DataFrame(data=data, columns=['Displacement', 'count']) + + grouped = df.groupby(['Displacement']) + mean = grouped.mean().reset_index().rename(columns={'count': 'mean'}) + std = grouped.std().reset_index().rename(columns={'count': 'std'}) + df = mean.merge(std, how='inner') + + if n_parts == 1: + fig = px.bar(df, x='Displacement', y='mean', barmode='stack') + else: + fig = px.bar(df, + x='Displacement', + y='mean', + error_y='std', + barmode='stack') + + fig.update_layout(title='Jumps vs. Distance', + xaxis_title='Distance (Å)', + yaxis_title='Number of jumps') + + return fig diff --git a/src/gemdat/plots/plotly/_jumps_vs_time.py b/src/gemdat/plots/plotly/_jumps_vs_time.py new file mode 100644 index 00000000..8116d1c6 --- /dev/null +++ b/src/gemdat/plots/plotly/_jumps_vs_time.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go + +if TYPE_CHECKING: + from gemdat import Jumps + + +def jumps_vs_time(*, + jumps: Jumps, + bins: int = 8, + n_parts: int = 1) -> go.Figure: + """Plot jumps vs. distance histogram. + + Parameters + ---------- + jumps : Jumps + Input jumps data + bins : int, optional + Number of bins + n_parts : int + Number of parts for error analysis + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + maxlen = len(jumps.trajectory) / n_parts + binsize = maxlen / bins + 1 + data = [] + + for jumps_part in jumps.split(n_parts=n_parts): + data.append( + np.histogram(jumps_part.data['start time'], + bins=bins, + range=(0., maxlen))[0]) + + df = pd.DataFrame(data=data) + columns = [binsize / 2 + binsize * col for col in range(bins)] + + mean = [df[col].mean() for col in df.columns] + std = [df[col].std() for col in df.columns] + + df = pd.DataFrame(data=zip(columns, mean, std), + columns=['time', 'count', 'std']) + + if n_parts > 1: + fig = px.bar(df, x='time', y='count', error_y='std') + else: + fig = px.bar(df, x='time', y='count') + + fig.update_layout(bargap=0.2, + title='Jumps vs. time', + xaxis_title='Time (steps)', + yaxis_title='Number of jumps') + + return fig diff --git a/src/gemdat/plots/plotly/_msd_per_element.py b/src/gemdat/plots/plotly/_msd_per_element.py new file mode 100644 index 00000000..6b9ad729 --- /dev/null +++ b/src/gemdat/plots/plotly/_msd_per_element.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import numpy as np +import plotly.graph_objects as go + +from gemdat.trajectory import Trajectory + + +def msd_per_element(*, trajectory: Trajectory) -> go.Figure: + """Plot mean squared displacement per element. + + Parameters + ---------- + trajectory : Trajectory + Input trajectory + + Returns + ------- + fig : matplotlib.figure.Figure + Output figure + """ + + fig = go.Figure() + + species = list(set(trajectory.species)) + + # Since we want to plot in picosecond, we convert the time units + time_ps = trajectory.time_step * 1e12 + + for sp in species: + traj = trajectory.filter(sp.symbol) + + msd = traj.mean_squared_displacement() + msd_mean = np.mean(msd, axis=0) + msd_std = np.std(msd, axis=0) + t_values = np.arange(len(msd_mean)) * time_ps + + fig.add_trace( + go.Scatter(x=t_values, + y=msd_mean, + error_y=dict(type='data', + array=msd_std, + width=0.1, + thickness=0.1), + name=f'{sp.symbol} mean+std', + mode='lines', + line={'width': 3}, + legendgroup=sp.symbol)) + + fig.update_layout(title='Mean squared displacement per element', + xaxis_title='Time lag (ps)', + yaxis_title=r'MSD (Å2)') + + return fig diff --git a/src/gemdat/plots/plotly/_plot3d.py b/src/gemdat/plots/plotly/_plot3d.py index a919d0c9..4f5d6785 100644 --- a/src/gemdat/plots/plotly/_plot3d.py +++ b/src/gemdat/plots/plotly/_plot3d.py @@ -307,9 +307,9 @@ def update_layout(*, 'y': lattice.b * zoom, 'z': lattice.c * zoom, }, - 'xaxis_title': 'X (Ångstrom)', - 'yaxis_title': 'Y (Ångstrom)', - 'zaxis_title': 'Z (Ångstrom)' + 'xaxis_title': 'X (Å)', + 'yaxis_title': 'Y (Å)', + 'zaxis_title': 'Z (Å)' }, legend={ 'orientation': 'h', diff --git a/src/gemdat/plots/plotly/_vibration.py b/src/gemdat/plots/plotly/_vibrational_amplitudes.py similarity index 52% rename from src/gemdat/plots/plotly/_vibration.py rename to src/gemdat/plots/plotly/_vibrational_amplitudes.py index 05eb9796..d11d0e15 100644 --- a/src/gemdat/plots/plotly/_vibration.py +++ b/src/gemdat/plots/plotly/_vibrational_amplitudes.py @@ -10,77 +10,6 @@ from gemdat.trajectory import Trajectory -def frequency_vs_occurence(*, trajectory: Trajectory) -> go.Figure: - """Plot attempt frequency vs occurence. - - Parameters - ---------- - trajectory : Trajectory - Input trajectory, i.e. for the diffusing atom - - Returns - ------- - fig : go.figure.Figure - Output figure - """ - metrics = SimulationMetrics(trajectory) - speed = metrics.speed() - - length = speed.shape[1] - half_length = length // 2 + 1 - - trans = np.fft.fft(speed) - - two_sided = np.abs(trans / length) - one_sided = two_sided[:, :half_length] - - fig = go.Figure() - - f = trajectory.sampling_frequency * np.arange(half_length) / length - - sum_freqs = np.sum(one_sided, axis=0) - smoothed = np.convolve(sum_freqs, np.ones(51), 'same') / 51 - fig.add_trace( - go.Scatter(y=smoothed, - x=f, - mode='lines', - line={ - 'width': 3, - 'color': 'blue' - }, - showlegend=False)) - - y_max = np.max(sum_freqs) - - attempt_freq, attempt_freq_std = metrics.attempt_frequency() - - if attempt_freq: - fig.add_vline(x=attempt_freq, line={'width': 2, 'color': 'red'}) - if attempt_freq and attempt_freq_std: - fig.add_vline(x=attempt_freq + attempt_freq_std, - line={ - 'width': 2, - 'color': 'red', - 'dash': 'dash' - }) - fig.add_vline(x=attempt_freq - attempt_freq_std, - line={ - 'width': 2, - 'color': 'red', - 'dash': 'dash' - }) - - fig.update_layout(title='Frequency vs Occurence', - xaxis_title='Frequency (Hz)', - yaxis_title='Occurrence (a.u.)', - xaxis_range=[-0.1e13, 2.5e13], - yaxis_range=[0, y_max], - width=600, - height=500) - - return fig - - def vibrational_amplitudes(*, trajectory: Trajectory, bins: int = 50, @@ -149,7 +78,7 @@ def vibrational_amplitudes(*, fig.update_layout( title='Histogram of vibrational amplitudes with fitted Gaussian', - xaxis_title='Amplitude (Ångstrom)', + xaxis_title='Amplitude (Å)', yaxis_title='Occurrence (a.u.)') return fig diff --git a/tests/integration/baseline_images/plot_test/autocorrelation.png b/tests/integration/baseline_images/plot_test/autocorrelation.png new file mode 100644 index 00000000..c8afd429 Binary files /dev/null and b/tests/integration/baseline_images/plot_test/autocorrelation.png differ diff --git a/tests/integration/baseline_images/plot_test/bond_length_distribution.png b/tests/integration/baseline_images/plot_test/bond_length_distribution.png index 59475ea8..51320cc6 100644 Binary files a/tests/integration/baseline_images/plot_test/bond_length_distribution.png and b/tests/integration/baseline_images/plot_test/bond_length_distribution.png differ diff --git a/tests/integration/baseline_images/plot_test/displacement_histogram.png b/tests/integration/baseline_images/plot_test/displacement_histogram.png index 1fa9ae53..7e84f4ba 100644 Binary files a/tests/integration/baseline_images/plot_test/displacement_histogram.png and b/tests/integration/baseline_images/plot_test/displacement_histogram.png differ diff --git a/tests/integration/baseline_images/plot_test/displacement_per_atom.png b/tests/integration/baseline_images/plot_test/displacement_per_atom.png index 1f2b6e33..45e22f7c 100644 Binary files a/tests/integration/baseline_images/plot_test/displacement_per_atom.png and b/tests/integration/baseline_images/plot_test/displacement_per_atom.png differ diff --git a/tests/integration/baseline_images/plot_test/displacement_per_element.png b/tests/integration/baseline_images/plot_test/displacement_per_element.png index 406cfbc1..38a36fdb 100644 Binary files a/tests/integration/baseline_images/plot_test/displacement_per_element.png and b/tests/integration/baseline_images/plot_test/displacement_per_element.png differ diff --git a/tests/integration/baseline_images/plot_test/energy_along_path.png b/tests/integration/baseline_images/plot_test/energy_along_path.png new file mode 100644 index 00000000..fc8f3e28 Binary files /dev/null and b/tests/integration/baseline_images/plot_test/energy_along_path.png differ diff --git a/tests/integration/baseline_images/plot_test/jumps_3d.png b/tests/integration/baseline_images/plot_test/jumps_3d.png index acef3067..bbaa4597 100644 Binary files a/tests/integration/baseline_images/plot_test/jumps_3d.png and b/tests/integration/baseline_images/plot_test/jumps_3d.png differ diff --git a/tests/integration/baseline_images/plot_test/jumps_3d_animation.png b/tests/integration/baseline_images/plot_test/jumps_3d_animation.png index 2cc85801..5e0ff8b2 100644 Binary files a/tests/integration/baseline_images/plot_test/jumps_3d_animation.png and b/tests/integration/baseline_images/plot_test/jumps_3d_animation.png differ diff --git a/tests/integration/baseline_images/plot_test/jumps_vs_distance.png b/tests/integration/baseline_images/plot_test/jumps_vs_distance.png index 2eabe155..28d78e5c 100644 Binary files a/tests/integration/baseline_images/plot_test/jumps_vs_distance.png and b/tests/integration/baseline_images/plot_test/jumps_vs_distance.png differ diff --git a/tests/integration/baseline_images/plot_test/msd.png b/tests/integration/baseline_images/plot_test/msd.png deleted file mode 100644 index fd479f23..00000000 Binary files a/tests/integration/baseline_images/plot_test/msd.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/msd_per_element.png b/tests/integration/baseline_images/plot_test/msd_per_element.png new file mode 100644 index 00000000..89a8ed6a Binary files /dev/null and b/tests/integration/baseline_images/plot_test/msd_per_element.png differ diff --git a/tests/integration/baseline_images/plot_test/path_energy.png b/tests/integration/baseline_images/plot_test/path_energy.png deleted file mode 100644 index 08494d4b..00000000 Binary files a/tests/integration/baseline_images/plot_test/path_energy.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_1.png b/tests/integration/baseline_images/plot_test/radial_distribution_1.png new file mode 100644 index 00000000..6ace96a0 Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_1.png differ diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_2.png b/tests/integration/baseline_images/plot_test/radial_distribution_2.png new file mode 100644 index 00000000..8a0fab15 Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_2.png differ diff --git a/tests/integration/baseline_images/plot_test/radial_distribution_3.png b/tests/integration/baseline_images/plot_test/radial_distribution_3.png new file mode 100644 index 00000000..8daa103e Binary files /dev/null and b/tests/integration/baseline_images/plot_test/radial_distribution_3.png differ diff --git a/tests/integration/baseline_images/plot_test/rdf1.png b/tests/integration/baseline_images/plot_test/rdf1.png deleted file mode 100644 index ef78e5ee..00000000 Binary files a/tests/integration/baseline_images/plot_test/rdf1.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/rdf2.png b/tests/integration/baseline_images/plot_test/rdf2.png deleted file mode 100644 index 9e5a7011..00000000 Binary files a/tests/integration/baseline_images/plot_test/rdf2.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/rdf3.png b/tests/integration/baseline_images/plot_test/rdf3.png deleted file mode 100644 index 463e47df..00000000 Binary files a/tests/integration/baseline_images/plot_test/rdf3.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/rectilinear.png b/tests/integration/baseline_images/plot_test/rectilinear.png index 2fd5123b..fa21df9b 100644 Binary files a/tests/integration/baseline_images/plot_test/rectilinear.png and b/tests/integration/baseline_images/plot_test/rectilinear.png differ diff --git a/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png b/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png deleted file mode 100644 index d7cb667c..00000000 Binary files a/tests/integration/baseline_images/plot_test/unit_vector_autocorrelation.png and /dev/null differ diff --git a/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png b/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png index 1624bb8e..1571ca6f 100644 Binary files a/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png and b/tests/integration/baseline_images/plot_test/vibrational_amplitudes.png differ diff --git a/tests/integration/plot_test.py b/tests/integration/plot_test.py index ab4a7a19..ca2ecd8b 100644 --- a/tests/integration/plot_test.py +++ b/tests/integration/plot_test.py @@ -61,7 +61,9 @@ def test_jumps_3d_animation(vasp_jumps): plots.jumps_3d_animation(jumps=vasp_jumps, t_start=1000, t_stop=1001) -@image_comparison2(baseline_images=['rdf1', 'rdf2', 'rdf3']) +@image_comparison2(baseline_images=[ + 'radial_distribution_1', 'radial_distribution_2', 'radial_distribution_3' +]) def test_rdf(vasp_rdf_data): assert len(vasp_rdf_data) == 3 for rdfs in vasp_rdf_data.values(): @@ -75,12 +77,12 @@ def test_shape(vasp_shape_data): plots.shape(shape) -@image_comparison2(baseline_images=['msd']) +@image_comparison2(baseline_images=['msd_per_element']) def test_msd_per_element(vasp_traj): plots.msd_per_element(trajectory=vasp_traj[-500:]) -@image_comparison2(baseline_images=['path_energy']) +@image_comparison2(baseline_images=['energy_along_path']) def test_path_energy(vasp_path): structure = load_known_material('argyrodite') plots.energy_along_path(path=vasp_path, structure=structure) @@ -101,6 +103,6 @@ def test_bond_length_distribution(vasp_orientations): vasp_orientations.plot_bond_length_distribution(bins=1000) -@image_comparison2(baseline_images=['unit_vector_autocorrelation']) +@image_comparison2(baseline_images=['autocorrelation']) def test_unit_vector_autocorrelation(vasp_orientations): vasp_orientations.plot_autocorrelation() diff --git a/tests/plots_tests.py b/tests/plots_tests.py new file mode 100644 index 00000000..45306ae5 --- /dev/null +++ b/tests/plots_tests.py @@ -0,0 +1,8 @@ +def test_matplotlib_imports(): + from gemdat.plots import matplotlib + assert matplotlib.__all__ + + +def test_plotly_imports(): + from gemdat.plots import plotly + assert plotly.__all__