diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b2681f8..a95e4e18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ --- repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace args: [--markdown-linebreak-ext=md] @@ -20,11 +20,11 @@ repos: hooks: - id: nbcheckorder - repo: https://github.com/myint/docformatter - rev: v1.7.5 + rev: 06907d0 hooks: - id: docformatter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.2 + rev: v0.6.9 hooks: - id: ruff args: [--fix] @@ -32,7 +32,7 @@ repos: - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.1 + rev: v1.11.2 hooks: - id: mypy additional_dependencies: [matplotlib, MDAnalysis, numpy, pymatgen, rich, scikit-image, scipy] diff --git a/src/gemdat/plots/_shared.py b/src/gemdat/plots/_shared.py index 13ccbd15..ec3f427b 100644 --- a/src/gemdat/plots/_shared.py +++ b/src/gemdat/plots/_shared.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import defaultdict +from dataclasses import dataclass from typing import TYPE_CHECKING import numpy as np @@ -84,3 +85,65 @@ def hex2rgba(hex_color: str, *, opacity: float = 1) -> str: b = int(hex_color[5:7], 16) return f'rgba({r},{g},{b},{opacity})' + + +@dataclass +class VibrationalAmplitudeHist: + amplitudes: np.ndarray + counts: np.ndarray + std: np.ndarray + + @property + def min_amp(self): + return self.amplitudes.min() + + @property + def max_amp(self): + return self.amplitudes.max() + + @property + def width(self): + bins = len(self.amplitudes) + return (self.max_amp - self.min_amp) / bins + + @property + def offset(self): + return self.width / 2 + + @property + def centers(self): + return self.amplitudes + self.offset + + @property + def dataframe(self): + return pd.DataFrame( + data=zip(self.centers, self.counts, self.std), columns=['center', 'count', 'std'] + ) + + +def _get_vibrational_amplitudes_hist( + *, trajectories: list[Trajectory], bins: int +) -> VibrationalAmplitudeHist: + """Calculate vabrational amplitudes histogram. + + Helper for `vibrational_amplitudes`. + """ + metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories] + + max_amp = max(max(metric) for metric in metrics) + min_amp = min(min(metric) for metric in metrics) + + max_amp = max(abs(min_amp), max_amp) + min_amp = -max_amp + + data = [] + + for metric in metrics: + data.append(np.histogram(metric, bins=bins, range=(min_amp, max_amp), density=True)[0]) + + amplitudes = np.linspace(min_amp, max_amp, bins, endpoint=False) + + mean = np.mean(data, axis=0) + std = np.std(data, axis=0) + + return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std) diff --git a/src/gemdat/plots/matplotlib/_autocorrelation.py b/src/gemdat/plots/matplotlib/_autocorrelation.py index 5feb540f..34f59fae 100644 --- a/src/gemdat/plots/matplotlib/_autocorrelation.py +++ b/src/gemdat/plots/matplotlib/_autocorrelation.py @@ -6,6 +6,8 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure + from gemdat.orientations import Orientations @@ -14,7 +16,7 @@ def autocorrelation( orientations: Orientations, show_traces: bool = True, show_shaded: bool = True, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot the autocorrelation function of the unit vectors series. Parameters diff --git a/src/gemdat/plots/matplotlib/_bond_length_distribution.py b/src/gemdat/plots/matplotlib/_bond_length_distribution.py index f48b8767..66fe544b 100644 --- a/src/gemdat/plots/matplotlib/_bond_length_distribution.py +++ b/src/gemdat/plots/matplotlib/_bond_length_distribution.py @@ -7,10 +7,14 @@ from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram if TYPE_CHECKING: + import matplotlib.figure + from gemdat.orientations import Orientations -def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> plt.Figure: +def bond_length_distribution( + *, orientations: Orientations, bins: int = 50 +) -> matplotlib.figure.Figure: """Plot the bond length probability distribution. Parameters diff --git a/src/gemdat/plots/matplotlib/_collective_jumps.py b/src/gemdat/plots/matplotlib/_collective_jumps.py index cebb8418..ca0da25b 100644 --- a/src/gemdat/plots/matplotlib/_collective_jumps.py +++ b/src/gemdat/plots/matplotlib/_collective_jumps.py @@ -5,10 +5,12 @@ import matplotlib.pyplot as plt if TYPE_CHECKING: + import matplotlib.figure + from gemdat import Jumps -def collective_jumps(*, jumps: Jumps) -> plt.Figure: +def collective_jumps(*, jumps: Jumps) -> matplotlib.figure.Figure: """Plot collective jumps per jump-type combination. Parameters diff --git a/src/gemdat/plots/matplotlib/_displacement_histogram.py b/src/gemdat/plots/matplotlib/_displacement_histogram.py index 56921b28..b3189d3d 100644 --- a/src/gemdat/plots/matplotlib/_displacement_histogram.py +++ b/src/gemdat/plots/matplotlib/_displacement_histogram.py @@ -5,10 +5,12 @@ import matplotlib.pyplot as plt if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory -def displacement_histogram(trajectory: Trajectory) -> plt.Figure: +def displacement_histogram(trajectory: Trajectory) -> matplotlib.figure.Figure: """Plot histogram of total displacement at final timestep. Parameters diff --git a/src/gemdat/plots/matplotlib/_displacement_per_atom.py b/src/gemdat/plots/matplotlib/_displacement_per_atom.py index 111a8ff1..3bb07aa3 100644 --- a/src/gemdat/plots/matplotlib/_displacement_per_atom.py +++ b/src/gemdat/plots/matplotlib/_displacement_per_atom.py @@ -5,10 +5,12 @@ import matplotlib.pyplot as plt if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory -def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure: +def displacement_per_atom(*, trajectory: Trajectory) -> matplotlib.figure.Figure: """Plot displacement per atom. Parameters diff --git a/src/gemdat/plots/matplotlib/_displacement_per_element.py b/src/gemdat/plots/matplotlib/_displacement_per_element.py index 83f9955c..3a8fda7d 100644 --- a/src/gemdat/plots/matplotlib/_displacement_per_element.py +++ b/src/gemdat/plots/matplotlib/_displacement_per_element.py @@ -7,10 +7,12 @@ from gemdat.plots._shared import _mean_displacements_per_element if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory -def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure: +def displacement_per_element(*, trajectory: Trajectory) -> matplotlib.figure.Figure: """Plot displacement per element. Parameters diff --git a/src/gemdat/plots/matplotlib/_energy_along_path.py b/src/gemdat/plots/matplotlib/_energy_along_path.py index d9a0cd1b..34c968a0 100644 --- a/src/gemdat/plots/matplotlib/_energy_along_path.py +++ b/src/gemdat/plots/matplotlib/_energy_along_path.py @@ -6,6 +6,8 @@ from pymatgen.core import Structure if TYPE_CHECKING: + import matplotlib.figure + from gemdat.path import Pathway @@ -14,7 +16,7 @@ def energy_along_path( *, structure: Structure | None = None, other_paths: list[Pathway] | None = None, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot energy along specified path. Parameters diff --git a/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py index 09551aae..3f69fbb7 100644 --- a/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py +++ b/src/gemdat/plots/matplotlib/_frequency_vs_occurence.py @@ -6,10 +6,12 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory -def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure: +def frequency_vs_occurence(*, trajectory: Trajectory) -> matplotlib.figure.Figure: """Plot attempt frequency vs occurence. Parameters diff --git a/src/gemdat/plots/matplotlib/_jumps_3d.py b/src/gemdat/plots/matplotlib/_jumps_3d.py index 43f9f911..20b78609 100644 --- a/src/gemdat/plots/matplotlib/_jumps_3d.py +++ b/src/gemdat/plots/matplotlib/_jumps_3d.py @@ -7,10 +7,12 @@ from pymatgen.electronic_structure import plotter if TYPE_CHECKING: + import matplotlib.figure + from gemdat import Jumps -def jumps_3d(*, jumps: Jumps) -> plt.Figure: +def jumps_3d(*, jumps: Jumps) -> matplotlib.figure.Figure: """Plot jumps in 3D. Parameters diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_distance.py b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py index ff05c1ea..e1172972 100644 --- a/src/gemdat/plots/matplotlib/_jumps_vs_distance.py +++ b/src/gemdat/plots/matplotlib/_jumps_vs_distance.py @@ -6,6 +6,8 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure + from gemdat import Jumps @@ -13,7 +15,7 @@ def jumps_vs_distance( *, jumps: Jumps, jump_res: float = 0.1, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot jumps vs. distance histogram. Parameters diff --git a/src/gemdat/plots/matplotlib/_jumps_vs_time.py b/src/gemdat/plots/matplotlib/_jumps_vs_time.py index 75888d12..c8ad5d5e 100644 --- a/src/gemdat/plots/matplotlib/_jumps_vs_time.py +++ b/src/gemdat/plots/matplotlib/_jumps_vs_time.py @@ -5,10 +5,12 @@ import matplotlib.pyplot as plt if TYPE_CHECKING: + import matplotlib.figure + from gemdat import Jumps -def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure: +def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> matplotlib.figure.Figure: """Plot jumps vs. time histogram. Parameters diff --git a/src/gemdat/plots/matplotlib/_msd_per_element.py b/src/gemdat/plots/matplotlib/_msd_per_element.py index 05b29c2f..99ad42dc 100644 --- a/src/gemdat/plots/matplotlib/_msd_per_element.py +++ b/src/gemdat/plots/matplotlib/_msd_per_element.py @@ -6,6 +6,8 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory @@ -14,7 +16,7 @@ def msd_per_element( trajectory: Trajectory, show_traces: bool = True, show_shaded: bool = True, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot mean squared displacement per element. Parameters diff --git a/src/gemdat/plots/matplotlib/_radial_distribution.py b/src/gemdat/plots/matplotlib/_radial_distribution.py index 7de79559..b8a4c7c6 100644 --- a/src/gemdat/plots/matplotlib/_radial_distribution.py +++ b/src/gemdat/plots/matplotlib/_radial_distribution.py @@ -5,10 +5,12 @@ import matplotlib.pyplot as plt if TYPE_CHECKING: + import matplotlib.figure + from gemdat.rdf import RDFData -def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure: +def radial_distribution(rdfs: Iterable[RDFData]) -> matplotlib.figure.Figure: """Plot radial distribution function. Parameters diff --git a/src/gemdat/plots/matplotlib/_rectilinear.py b/src/gemdat/plots/matplotlib/_rectilinear.py index 746d30f0..a5983176 100644 --- a/src/gemdat/plots/matplotlib/_rectilinear.py +++ b/src/gemdat/plots/matplotlib/_rectilinear.py @@ -6,6 +6,8 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure + from gemdat.orientations import Orientations @@ -14,7 +16,7 @@ def rectilinear( orientations: Orientations, shape: tuple[int, int] = (90, 360), normalize_histo: bool = True, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot a rectilinear projection of a spherical function. This function uses the transformed trajectory. diff --git a/src/gemdat/plots/matplotlib/_shape.py b/src/gemdat/plots/matplotlib/_shape.py index 665b8254..c8391ef3 100644 --- a/src/gemdat/plots/matplotlib/_shape.py +++ b/src/gemdat/plots/matplotlib/_shape.py @@ -7,6 +7,7 @@ import numpy as np if TYPE_CHECKING: + import matplotlib.figure from pymatgem.core import PeriodicSite from gemdat.shape import ShapeData @@ -16,7 +17,7 @@ def shape( shape: ShapeData, bins: int | Sequence[float] = 50, sites: Collection[PeriodicSite] | None = None, -) -> plt.Figure: +) -> matplotlib.figure.Figure: """Plot site cluster shapes. Parameters diff --git a/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py index da1df8bd..86d2a12c 100644 --- a/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py +++ b/src/gemdat/plots/matplotlib/_vibrational_amplitudes.py @@ -6,17 +6,27 @@ import numpy as np from scipy import stats +from .._shared import _get_vibrational_amplitudes_hist + if TYPE_CHECKING: + import matplotlib.figure + from gemdat.trajectory import Trajectory -def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure: +def vibrational_amplitudes( + *, trajectory: Trajectory, bins: int = 50, n_parts: int = 1 +) -> matplotlib.figure.Figure: """Plot histogram of vibrational amplitudes with fitted Gaussian. Parameters ---------- trajectory : Trajectory Input trajectory, i.e. for the diffusing atom + bins : int + Number of bins for the histogram + n_parts : int + Number of parts for error analysis Returns ------- @@ -25,10 +35,14 @@ def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure: """ metrics = trajectory.metrics() + trajectories = trajectory.split(n_parts) + + hist = _get_vibrational_amplitudes_hist(trajectories=trajectories, bins=bins) fig, ax = plt.subplots() - ax.hist(metrics.amplitudes(), bins=100, density=True) - x = np.linspace(-2, 2, 100) + plt.bar(hist.amplitudes + hist.offset, hist.counts, width=hist.width, yerr=hist.std) + + x = np.linspace(hist.min_amp, hist.max_amp, 100) y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude()) ax.plot(x, y_gauss, 'r') diff --git a/src/gemdat/plots/plotly/_vibrational_amplitudes.py b/src/gemdat/plots/plotly/_vibrational_amplitudes.py index dc7d7657..9fd2a215 100644 --- a/src/gemdat/plots/plotly/_vibrational_amplitudes.py +++ b/src/gemdat/plots/plotly/_vibrational_amplitudes.py @@ -3,11 +3,12 @@ from typing import TYPE_CHECKING import numpy as np -import pandas as pd import plotly.express as px import plotly.graph_objects as go from scipy import stats +from .._shared import _get_vibrational_amplitudes_hist + if TYPE_CHECKING: from gemdat.trajectory import Trajectory @@ -29,41 +30,19 @@ def vibrational_amplitudes( fig : plotly.graph_objects.Figure Output figure """ + metrics = trajectory.metrics() trajectories = trajectory.split(n_parts) - single_metrics = trajectory.metrics() - metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories] - - max_amp = max(max(metric) for metric in metrics) - min_amp = min(min(metric) for metric in metrics) - - max_amp = max(abs(min_amp), max_amp) - min_amp = -max_amp - - data = [] - - for metric in metrics: - data.append(np.histogram(metric, bins=bins, range=(min_amp, max_amp), density=True)[0]) - - df = pd.DataFrame(data=data) - - # offset to middle of bar - offset = (max_amp - min_amp) / (bins * 2) - - columns = np.linspace(min_amp + offset, max_amp + offset, bins, endpoint=False) - - mean = [df[col].mean() for col in df.columns] - std = [df[col].std() for col in df.columns] - df = pd.DataFrame(data=zip(columns, mean, std), columns=['amplitude', 'count', 'std']) + hist = _get_vibrational_amplitudes_hist(trajectories=trajectories, bins=bins) if n_parts == 1: - fig = px.bar(df, x='amplitude', y='count') + fig = px.bar(hist.dataframe, x='center', y='count') else: - fig = px.bar(df, x='amplitude', y='count', error_y='std') + fig = px.bar(hist.dataframe, x='center', y='count', error_y='std') - x = np.linspace(min_amp, max_amp, 100) - y_gauss = stats.norm.pdf(x, 0, single_metrics.vibration_amplitude()) + x = np.linspace(hist.min_amp, hist.max_amp, 100) + hist.offset + y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude()) fig.add_trace(go.Scatter(x=x, y=y_gauss, name='Fitted Gaussian')) fig.update_layout( diff --git a/tests/integration/baseline_images/plot_mpl_test/vibrational_amplitudes.png b/tests/integration/baseline_images/plot_mpl_test/vibrational_amplitudes.png index 1571ca6f..b61407f5 100644 Binary files a/tests/integration/baseline_images/plot_mpl_test/vibrational_amplitudes.png and b/tests/integration/baseline_images/plot_mpl_test/vibrational_amplitudes.png differ diff --git a/tests/integration/baseline_images/plot_plotly_test/vibrational_amplitudes.png b/tests/integration/baseline_images/plot_plotly_test/vibrational_amplitudes.png index e2c773d6..c55ebd45 100644 Binary files a/tests/integration/baseline_images/plot_plotly_test/vibrational_amplitudes.png and b/tests/integration/baseline_images/plot_plotly_test/vibrational_amplitudes.png differ