Skip to content

Fix vibration amplitudes plot #336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
---
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
Expand All @@ -20,19 +20,19 @@ repos:
hooks:
- id: nbcheckorder
- repo: https://github.com/myint/docformatter
rev: v1.7.5
rev: 06907d0
hooks:
- id: docformatter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
rev: v0.6.9
hooks:
- id: ruff
args: [--fix]
types_or: [python, pyi, jupyter]
- id: ruff-format
types_or: [python, pyi, jupyter]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [matplotlib, MDAnalysis, numpy, pymatgen, rich, scikit-image, scipy]
63 changes: 63 additions & 0 deletions src/gemdat/plots/_shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -84,3 +85,65 @@ def hex2rgba(hex_color: str, *, opacity: float = 1) -> str:
b = int(hex_color[5:7], 16)

return f'rgba({r},{g},{b},{opacity})'


@dataclass
class VibrationalAmplitudeHist:
amplitudes: np.ndarray
counts: np.ndarray
std: np.ndarray

@property
def min_amp(self):
return self.amplitudes.min()

@property
def max_amp(self):
return self.amplitudes.max()

@property
def width(self):
bins = len(self.amplitudes)
return (self.max_amp - self.min_amp) / bins

@property
def offset(self):
return self.width / 2

@property
def centers(self):
return self.amplitudes + self.offset

@property
def dataframe(self):
return pd.DataFrame(
data=zip(self.centers, self.counts, self.std), columns=['center', 'count', 'std']
)


def _get_vibrational_amplitudes_hist(
*, trajectories: list[Trajectory], bins: int
) -> VibrationalAmplitudeHist:
"""Calculate vabrational amplitudes histogram.

Helper for `vibrational_amplitudes`.
"""
metrics = [trajectory.metrics().amplitudes() for trajectory in trajectories]

max_amp = max(max(metric) for metric in metrics)
min_amp = min(min(metric) for metric in metrics)

max_amp = max(abs(min_amp), max_amp)
min_amp = -max_amp

data = []

for metric in metrics:
data.append(np.histogram(metric, bins=bins, range=(min_amp, max_amp), density=True)[0])

amplitudes = np.linspace(min_amp, max_amp, bins, endpoint=False)

mean = np.mean(data, axis=0)
std = np.std(data, axis=0)

return VibrationalAmplitudeHist(amplitudes=amplitudes, counts=mean, std=std)
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_autocorrelation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.orientations import Orientations


Expand All @@ -14,7 +16,7 @@ def autocorrelation(
orientations: Orientations,
show_traces: bool = True,
show_shaded: bool = True,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot the autocorrelation function of the unit vectors series.

Parameters
Expand Down
6 changes: 5 additions & 1 deletion src/gemdat/plots/matplotlib/_bond_length_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
from .._shared import _fit_skewnorm_to_hist, _orientations_to_histogram

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.orientations import Orientations


def bond_length_distribution(*, orientations: Orientations, bins: int = 50) -> plt.Figure:
def bond_length_distribution(
*, orientations: Orientations, bins: int = 50
) -> matplotlib.figure.Figure:
"""Plot the bond length probability distribution.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_collective_jumps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import matplotlib.pyplot as plt

if TYPE_CHECKING:
import matplotlib.figure

from gemdat import Jumps


def collective_jumps(*, jumps: Jumps) -> plt.Figure:
def collective_jumps(*, jumps: Jumps) -> matplotlib.figure.Figure:
"""Plot collective jumps per jump-type combination.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import matplotlib.pyplot as plt

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


def displacement_histogram(trajectory: Trajectory) -> plt.Figure:
def displacement_histogram(trajectory: Trajectory) -> matplotlib.figure.Figure:
"""Plot histogram of total displacement at final timestep.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import matplotlib.pyplot as plt

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


def displacement_per_atom(*, trajectory: Trajectory) -> plt.Figure:
def displacement_per_atom(*, trajectory: Trajectory) -> matplotlib.figure.Figure:
"""Plot displacement per atom.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_displacement_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from gemdat.plots._shared import _mean_displacements_per_element

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


def displacement_per_element(*, trajectory: Trajectory) -> plt.Figure:
def displacement_per_element(*, trajectory: Trajectory) -> matplotlib.figure.Figure:
"""Plot displacement per element.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_energy_along_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pymatgen.core import Structure

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.path import Pathway


Expand All @@ -14,7 +16,7 @@ def energy_along_path(
*,
structure: Structure | None = None,
other_paths: list[Pathway] | None = None,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot energy along specified path.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_frequency_vs_occurence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


def frequency_vs_occurence(*, trajectory: Trajectory) -> plt.Figure:
def frequency_vs_occurence(*, trajectory: Trajectory) -> matplotlib.figure.Figure:
"""Plot attempt frequency vs occurence.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_jumps_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from pymatgen.electronic_structure import plotter

if TYPE_CHECKING:
import matplotlib.figure

from gemdat import Jumps


def jumps_3d(*, jumps: Jumps) -> plt.Figure:
def jumps_3d(*, jumps: Jumps) -> matplotlib.figure.Figure:
"""Plot jumps in 3D.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_jumps_vs_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure

from gemdat import Jumps


def jumps_vs_distance(
*,
jumps: Jumps,
jump_res: float = 0.1,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot jumps vs. distance histogram.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_jumps_vs_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import matplotlib.pyplot as plt

if TYPE_CHECKING:
import matplotlib.figure

from gemdat import Jumps


def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> plt.Figure:
def jumps_vs_time(*, jumps: Jumps, binsize: int = 500) -> matplotlib.figure.Figure:
"""Plot jumps vs. time histogram.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_msd_per_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


Expand All @@ -14,7 +16,7 @@ def msd_per_element(
trajectory: Trajectory,
show_traces: bool = True,
show_shaded: bool = True,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot mean squared displacement per element.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_radial_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import matplotlib.pyplot as plt

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.rdf import RDFData


def radial_distribution(rdfs: Iterable[RDFData]) -> plt.Figure:
def radial_distribution(rdfs: Iterable[RDFData]) -> matplotlib.figure.Figure:
"""Plot radial distribution function.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/gemdat/plots/matplotlib/_rectilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.orientations import Orientations


Expand All @@ -14,7 +16,7 @@ def rectilinear(
orientations: Orientations,
shape: tuple[int, int] = (90, 360),
normalize_histo: bool = True,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot a rectilinear projection of a spherical function. This function
uses the transformed trajectory.

Expand Down
3 changes: 2 additions & 1 deletion src/gemdat/plots/matplotlib/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

if TYPE_CHECKING:
import matplotlib.figure
from pymatgem.core import PeriodicSite

from gemdat.shape import ShapeData
Expand All @@ -16,7 +17,7 @@ def shape(
shape: ShapeData,
bins: int | Sequence[float] = 50,
sites: Collection[PeriodicSite] | None = None,
) -> plt.Figure:
) -> matplotlib.figure.Figure:
"""Plot site cluster shapes.

Parameters
Expand Down
20 changes: 17 additions & 3 deletions src/gemdat/plots/matplotlib/_vibrational_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,27 @@
import numpy as np
from scipy import stats

from .._shared import _get_vibrational_amplitudes_hist

if TYPE_CHECKING:
import matplotlib.figure

from gemdat.trajectory import Trajectory


def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
def vibrational_amplitudes(
*, trajectory: Trajectory, bins: int = 50, n_parts: int = 1
) -> matplotlib.figure.Figure:
"""Plot histogram of vibrational amplitudes with fitted Gaussian.

Parameters
----------
trajectory : Trajectory
Input trajectory, i.e. for the diffusing atom
bins : int
Number of bins for the histogram
n_parts : int
Number of parts for error analysis

Returns
-------
Expand All @@ -25,10 +35,14 @@ def vibrational_amplitudes(*, trajectory: Trajectory) -> plt.Figure:
"""
metrics = trajectory.metrics()

trajectories = trajectory.split(n_parts)

hist = _get_vibrational_amplitudes_hist(trajectories=trajectories, bins=bins)
fig, ax = plt.subplots()
ax.hist(metrics.amplitudes(), bins=100, density=True)

x = np.linspace(-2, 2, 100)
plt.bar(hist.amplitudes + hist.offset, hist.counts, width=hist.width, yerr=hist.std)

x = np.linspace(hist.min_amp, hist.max_amp, 100)
y_gauss = stats.norm.pdf(x, 0, metrics.vibration_amplitude())
ax.plot(x, y_gauss, 'r')

Expand Down
Loading
Loading