|
| 1 | +"""Plotting functions.""" |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import collections |
| 5 | +import logging |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +from matplotlib import pyplot as plt |
| 10 | +from matplotlib.colors import Normalize |
| 11 | +from pymatgen.electronic_structure.core import Spin |
| 12 | + |
| 13 | +from pymatgen.analysis.defects.ccd import HarmonicDefect |
| 14 | + |
| 15 | +__author__ = "Jimmy Shen" |
| 16 | +__copyright__ = "Copyright 2022, The Materials Project" |
| 17 | +__maintainer__ = "Jimmy Shen @jmmshn" |
| 18 | +__date__ = "July 2023" |
| 19 | + |
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +def plot_optical_transitions( |
| 24 | + defect: HarmonicDefect, |
| 25 | + kpt_index: int = 0, |
| 26 | + band_window: int = 5, |
| 27 | + user_defect_band: tuple = tuple(), |
| 28 | + ijdirs=((0, 0), (1, 1), (2, 2)), |
| 29 | + shift_eig: dict[tuple, float] = None, |
| 30 | + x0: float = 0, |
| 31 | + x_width: float = 2, |
| 32 | + ax=None, |
| 33 | + cmap=None, |
| 34 | + norm=None, |
| 35 | +): |
| 36 | + """Plot the optical transitions from the defect state to all other states. |
| 37 | +
|
| 38 | + Only plot the transitions for a specific kpoint index. The arrows present the transitions |
| 39 | + between the defect state of interest and all other states. The color of the arrows |
| 40 | + indicate the magnitude of the matrix element (derivative of the wavefunction) for the |
| 41 | + transition. |
| 42 | +
|
| 43 | + Args: |
| 44 | + defect: |
| 45 | + The HarmonicDefect object, the `relaxed_bandstructure` attribute |
| 46 | + must be set since this contains the eigenvalues. |
| 47 | + Please see the `store_bandstructure` option in the constructor. |
| 48 | + kpt_index: |
| 49 | + The kpoint index to read the eigenvalues from. |
| 50 | + band_window: |
| 51 | + The number of bands above and below the defect state to include in the output. |
| 52 | + user_defect_band: |
| 53 | + (band, kpt, spin) tuple to specify the defect state. If not provided, |
| 54 | + the defect state will be determined automatically using the inverse |
| 55 | + participation ratio and the `kpt_index` argument. |
| 56 | + ijdirs: |
| 57 | + The cartesian direction of the WAVDER tensor to sum over for the plot. |
| 58 | + If not provided, all the absolute values of the matrix for all |
| 59 | + three diagonal entries will be summed. |
| 60 | + shift_eig: |
| 61 | + A dictionary of the format `(band, kpt, spin) -> float` to apply to the |
| 62 | + eigenvalues. This is useful for aligning the defect state with the |
| 63 | + valence or conduction band for plotting and schematic purposes. |
| 64 | + x0: |
| 65 | + The x coordinate of the center of the set of arrows and the eigenvalue plot. |
| 66 | + x_width: |
| 67 | + The width of the set of arrows and the eigenvalue plot. |
| 68 | + ax: |
| 69 | + The matplotlib axis object to plot on. |
| 70 | + cmap: |
| 71 | + The matplotlib color map to use for the color of the arrorws. |
| 72 | + norm: |
| 73 | + The matplotlib normalization to use for the color map of the arrows. |
| 74 | +
|
| 75 | + """ |
| 76 | + d_eigs = get_bs_eigenvalues( |
| 77 | + defect=defect, |
| 78 | + kpt_index=kpt_index, |
| 79 | + band_window=band_window, |
| 80 | + user_defect_band=user_defect_band, |
| 81 | + shift_eig=shift_eig, |
| 82 | + ) |
| 83 | + if user_defect_band: |
| 84 | + defect_band_index = user_defect_band[0] |
| 85 | + else: |
| 86 | + defect_band_index = next( |
| 87 | + filter(lambda x: x[1] == kpt_index, defect.defect_band) |
| 88 | + )[0] |
| 89 | + |
| 90 | + if ax is None: |
| 91 | + ax_ = plt.gca() |
| 92 | + else: # pragma: no cover |
| 93 | + ax_ = ax |
| 94 | + _plot_eigs( |
| 95 | + d_eigs, defect.relaxed_bandstructure.efermi, ax=ax_, x0=x0, x_width=x_width |
| 96 | + ) |
| 97 | + me_plot_data, cmap, norm = _plot_matrix_elements( |
| 98 | + defect.waveder.cder, |
| 99 | + d_eigs, |
| 100 | + defect_band_index=defect_band_index, |
| 101 | + ijdirs=ijdirs, |
| 102 | + ax=ax_, |
| 103 | + x0=x0, |
| 104 | + x_width=x_width, |
| 105 | + cmap=cmap, |
| 106 | + norm=norm, |
| 107 | + ) |
| 108 | + return _get_dataframe(d_eigs=d_eigs, me_plot_data=me_plot_data), cmap, norm |
| 109 | + |
| 110 | + |
| 111 | +def get_bs_eigenvalues( |
| 112 | + defect: HarmonicDefect, |
| 113 | + kpt_index: int = 0, |
| 114 | + band_window: int = 5, |
| 115 | + user_defect_band: tuple = tuple(), |
| 116 | + shift_eig: dict[tuple, float] = None, |
| 117 | +) -> dict[tuple, float]: |
| 118 | + """Read the eigenvalues from `HarmonicDefect.relaxed_bandstructure`. |
| 119 | +
|
| 120 | + Args: |
| 121 | + defect: |
| 122 | + The HarmonicDefect object, the `relaxed_bandstructure` attribute |
| 123 | + must be set since this contains the eigenvalues. |
| 124 | + Please see the `store_bandstructure` option in the constructor. |
| 125 | + kpt_index: |
| 126 | + The kpoint index to read the eigenvalues from. |
| 127 | + band_window: |
| 128 | + The number of bands above and below the Fermi level to include. |
| 129 | + user_defect_band: |
| 130 | + (band, kpt, spin) tuple to specify the defect state. If not provided, |
| 131 | + the defect state will be determined automatically using the inverse |
| 132 | + participation ratio. |
| 133 | + The user provided kpoint index here will overwrite the kpt_index argument. |
| 134 | +
|
| 135 | + Returns: |
| 136 | + Dictionary of the format: (iband, ikpt, ispin) -> eigenvalue |
| 137 | + """ |
| 138 | + |
| 139 | + if defect.relaxed_bandstructure is None: # pragma: no cover |
| 140 | + raise ValueError("The defect object does not have a band structure.") |
| 141 | + |
| 142 | + if user_defect_band: |
| 143 | + def_indices = user_defect_band |
| 144 | + else: |
| 145 | + def_indices = next(filter(lambda x: x[1] == kpt_index, defect.defect_band)) |
| 146 | + |
| 147 | + band_index, kpt_index, spin_index = def_indices |
| 148 | + spin_key = Spin.up if spin_index == 0 else Spin.down |
| 149 | + output: dict[tuple, float] = dict() |
| 150 | + shift_dict: dict = collections.defaultdict(lambda: 0.0) |
| 151 | + if shift_eig is not None: |
| 152 | + shift_dict.update(shift_eig) |
| 153 | + for ib in range(band_index - band_window, band_index + band_window + 1): |
| 154 | + output[(ib, kpt_index, spin_index)] = ( |
| 155 | + defect.relaxed_bandstructure.bands[spin_key][ib, kpt_index] |
| 156 | + + shift_dict[(ib, kpt_index, spin_index)] |
| 157 | + ) |
| 158 | + return output |
| 159 | + |
| 160 | + |
| 161 | +def _plot_eigs( |
| 162 | + d_eigs: dict[tuple, float], |
| 163 | + e_fermi=None, |
| 164 | + ax=None, |
| 165 | + x0: float = 0.0, |
| 166 | + x_width: float = 0.3, |
| 167 | + **kwargs, |
| 168 | +) -> None: |
| 169 | + """Plot the eigenvalues.""" |
| 170 | + if ax is None: # pragma: no cover |
| 171 | + ax = plt.gca() |
| 172 | + |
| 173 | + # Use current color scheme |
| 174 | + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] |
| 175 | + collections.defaultdict(list) |
| 176 | + eigenvalues = np.array(list(d_eigs.values())) |
| 177 | + if e_fermi is None: # pragma: no cover |
| 178 | + e_fermi = -np.inf |
| 179 | + |
| 180 | + eigs_ = eigenvalues[eigenvalues <= e_fermi] |
| 181 | + ax.hlines( |
| 182 | + eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[0], **kwargs |
| 183 | + ) |
| 184 | + eigs_ = eigenvalues[eigenvalues > e_fermi] |
| 185 | + ax.hlines( |
| 186 | + eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[1], **kwargs |
| 187 | + ) |
| 188 | + |
| 189 | + # turn off x-aixs |
| 190 | + ax.get_xaxis().set_visible(False) |
| 191 | + |
| 192 | + |
| 193 | +def _plot_matrix_elements( |
| 194 | + cder, |
| 195 | + d_eig, |
| 196 | + defect_band_index, |
| 197 | + ijdirs=((0, 0), (1, 1), (2, 2)), |
| 198 | + ax=None, |
| 199 | + x0=0, |
| 200 | + x_width=0.6, |
| 201 | + arrow_width=0.1, |
| 202 | + cmap=None, |
| 203 | + norm=None, |
| 204 | +): |
| 205 | + """Plot arrow for the transition from the defect state to all other states. |
| 206 | +
|
| 207 | + Args: |
| 208 | + cder: |
| 209 | + The matrix element (derivative of the wavefunction) for the defect state. |
| 210 | + d_eig: |
| 211 | + The dictionary of eigenvalues for the defect state. In the format of |
| 212 | + (iband, ikpt, ispin) -> eigenvalue |
| 213 | + defect_band_index: |
| 214 | + The band index of the defect state. |
| 215 | + ax: |
| 216 | + The matplotlib axis object to plot on. |
| 217 | + x0: |
| 218 | + The x coordinate of the center of the set of arrows. |
| 219 | + x_width: |
| 220 | + The width of the set of arrows. |
| 221 | + arrow_width: |
| 222 | + The width of the arrow. |
| 223 | + cmap: |
| 224 | + The matplotlib color map to use. |
| 225 | + norm: |
| 226 | + The matplotlib normalization to use for the color map. |
| 227 | + ijdirs: |
| 228 | + The cartesian direction of the WAVDER tensor to sum over for the plot. |
| 229 | + If not provided, all the absolute values of the matrix for all |
| 230 | + three diagonal entries will be summed. |
| 231 | + """ |
| 232 | + if ax is None: # pragma: no cover |
| 233 | + ax = plt.gca() |
| 234 | + ax.set_aspect("equal") |
| 235 | + jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys())) |
| 236 | + y0 = d_eig[jb, jkpt, jspin] |
| 237 | + plot_data = [] |
| 238 | + for (ib, ik, ispin), eig in d_eig.items(): |
| 239 | + A = 0 |
| 240 | + for idir, jdir in ijdirs: |
| 241 | + A += np.abs( |
| 242 | + cder[ib, jb, ik, ispin, idir] |
| 243 | + * np.conjugate(cder[ib, jb, ik, ispin, jdir]) |
| 244 | + ) |
| 245 | + plot_data.append((jb, ib, eig, A)) |
| 246 | + |
| 247 | + if cmap is None: |
| 248 | + cmap = plt.get_cmap("viridis") |
| 249 | + |
| 250 | + # get the range of A values |
| 251 | + if norm is None: |
| 252 | + A_min, A_max = ( |
| 253 | + min(plot_data, key=lambda x: x[3])[3], |
| 254 | + max(plot_data, key=lambda x: x[3])[3], |
| 255 | + ) |
| 256 | + norm = Normalize(vmin=A_min, vmax=A_max) |
| 257 | + |
| 258 | + n_arrows = len(plot_data) |
| 259 | + x_step = x_width / n_arrows |
| 260 | + x = x0 - x_width / 2 + x_step / 2 |
| 261 | + for ib, jb, eig, A in plot_data: |
| 262 | + ax.arrow( |
| 263 | + x=x, |
| 264 | + y=y0, |
| 265 | + dx=0, |
| 266 | + dy=eig - y0, |
| 267 | + width=arrow_width, |
| 268 | + length_includes_head=True, |
| 269 | + head_width=arrow_width * 2, |
| 270 | + head_length=arrow_width * 2, |
| 271 | + color=cmap(norm(A)), |
| 272 | + zorder=20, |
| 273 | + ) |
| 274 | + x += x_step |
| 275 | + return plot_data, cmap, norm |
| 276 | + |
| 277 | + |
| 278 | +def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame: |
| 279 | + """Convert the eigenvalue and matrix element data into a pandas dataframe.""" |
| 280 | + _, ikpt, ispin = next(iter(d_eigs.keys())) |
| 281 | + df = pd.DataFrame( |
| 282 | + me_plot_data, |
| 283 | + columns=["ib", "jb", "eig", "M.E."], |
| 284 | + ) |
| 285 | + df["kpt"] = ikpt |
| 286 | + df["spin"] = ispin |
| 287 | + return df |
0 commit comments