Skip to content

Commit dbfedf0

Browse files
authored
Optical Transition Matrix Element Plotting (#137)
* deps * deps * python vers * teset * tests * cleaner strict * cleaner strict * wf * plotting plotting plotting plotting plotting plotting * plotting * plotting * notebook notebook notebook notebook/test * notebook/test
1 parent fd1cda9 commit dbfedf0

File tree

5 files changed

+357
-4
lines changed

5 files changed

+357
-4
lines changed

docs/source/content/photo-conduct.ipynb

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@
6969
"outputs": [],
7070
"source": [
7171
"dir0 = TEST_FILES / \"ccd_0_-1\" / \"optics\"\n",
72-
"hd0 = HarmonicDefect.from_directories(directories=[dir0])\n",
72+
"hd0 = HarmonicDefect.from_directories(directories=[dir0], store_bandstructure=True)\n",
73+
"# Note the `store_bandstructure=True` argument is required for the matrix element plotting later in the notebook.\n",
74+
"# but not required for the dielectric function calculation.\n",
7375
"print(f\"The defect band is {hd0.defect_band}\")\n",
7476
"print(f\"The vibrational frequency is omega={hd0.omega} in this case is gibberish.\")"
7577
]
@@ -107,8 +109,8 @@
107109
" if spin == Spin.up:\n",
108110
" return 0\n",
109111
" return 1\n",
110-
" occ = vr.eigenvalues[Spin.up][0, :, 1]\n",
111-
" fermi_idx = bisect.bisect_left(occ, -0.5, key=lambda x: -x) \n",
112+
" occ = vr.eigenvalues[Spin.up][0, :, 1] * -1\n",
113+
" fermi_idx = bisect.bisect_left(occ, -0.5) \n",
112114
" output = collections.defaultdict(list)\n",
113115
" for k, spin_eigs in vr.eigenvalues.items():\n",
114116
" spin_idx = _get_spin_idx(k)\n",
@@ -195,6 +197,52 @@
195197
"\n",
196198
"Of course for a complete picture of photoconductivity, the Frank-Condon type ofr vibrational state transition should also be considered, but we are already pushing the limits of what is acceptable in the independent-particle picture so we will leave that for another time.\n"
197199
]
200+
},
201+
{
202+
"cell_type": "markdown",
203+
"id": "1fd5f0b0",
204+
"metadata": {},
205+
"source": [
206+
"## Dipole Matrix Elements\n",
207+
"\n",
208+
"We can also check the dipole matrix elements for the (VBM)→(defect) and (defect)→(CBM) transitions explicitly by calling the `plot_optical_transitions` method as shown below.\n",
209+
"The function returns a summary `pandas.DataFrame` object with the dipole matrix elements as well as the `ListedColormap` and `Normalize` objects for plotting the colorbar. These objects can then be passed to other instances of the plotting function to ensure that the colorbar is consistent."
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": null,
215+
"id": "8e00ac9f",
216+
"metadata": {},
217+
"outputs": [],
218+
"source": [
219+
"from pymatgen.analysis.defects.plotting.optical import plot_optical_transitions\n",
220+
"import matplotlib as mpl\n",
221+
"fig, ax = plt.subplots()\n",
222+
"cm_ax = fig.add_axes([0.8,0.1,0.02,0.8])\n",
223+
"df_k0, cmap, norm = plot_optical_transitions(hd0, kpt_index=1, band_window=5, x0=3, ax=ax)\n",
224+
"df_k1, _, _ = plot_optical_transitions(hd0, kpt_index=0, band_window=5, x0=0, ax=ax, cmap=cmap, norm=norm)\n",
225+
"mpl.colorbar.ColorbarBase(cm_ax,cmap=cmap,norm=norm,orientation='vertical')\n",
226+
"ax.set_ylabel(\"Energy (eV)\")"
227+
]
228+
},
229+
{
230+
"cell_type": "markdown",
231+
"id": "43f9e8d0",
232+
"metadata": {},
233+
"source": [
234+
"The `DataFrame` object containing the dipole matrix elements can also be examined directly."
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"id": "d44ad8ef",
241+
"metadata": {},
242+
"outputs": [],
243+
"source": [
244+
"df_k0"
245+
]
198246
}
199247
],
200248
"metadata": {
@@ -208,7 +256,7 @@
208256
"name": "python",
209257
"nbconvert_exporter": "python",
210258
"pygments_lexer": "ipython3",
211-
"version": "3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]"
259+
"version": "3.9.16"
212260
}
213261
},
214262
"nbformat": 4,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Finite size corrections for defects."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Plotting functions."""
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""Plotting functions."""
2+
from __future__ import annotations
3+
4+
import collections
5+
import logging
6+
7+
import numpy as np
8+
import pandas as pd
9+
from matplotlib import pyplot as plt
10+
from matplotlib.colors import Normalize
11+
from pymatgen.electronic_structure.core import Spin
12+
13+
from pymatgen.analysis.defects.ccd import HarmonicDefect
14+
15+
__author__ = "Jimmy Shen"
16+
__copyright__ = "Copyright 2022, The Materials Project"
17+
__maintainer__ = "Jimmy Shen @jmmshn"
18+
__date__ = "July 2023"
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def plot_optical_transitions(
24+
defect: HarmonicDefect,
25+
kpt_index: int = 0,
26+
band_window: int = 5,
27+
user_defect_band: tuple = tuple(),
28+
ijdirs=((0, 0), (1, 1), (2, 2)),
29+
shift_eig: dict[tuple, float] = None,
30+
x0: float = 0,
31+
x_width: float = 2,
32+
ax=None,
33+
cmap=None,
34+
norm=None,
35+
):
36+
"""Plot the optical transitions from the defect state to all other states.
37+
38+
Only plot the transitions for a specific kpoint index. The arrows present the transitions
39+
between the defect state of interest and all other states. The color of the arrows
40+
indicate the magnitude of the matrix element (derivative of the wavefunction) for the
41+
transition.
42+
43+
Args:
44+
defect:
45+
The HarmonicDefect object, the `relaxed_bandstructure` attribute
46+
must be set since this contains the eigenvalues.
47+
Please see the `store_bandstructure` option in the constructor.
48+
kpt_index:
49+
The kpoint index to read the eigenvalues from.
50+
band_window:
51+
The number of bands above and below the defect state to include in the output.
52+
user_defect_band:
53+
(band, kpt, spin) tuple to specify the defect state. If not provided,
54+
the defect state will be determined automatically using the inverse
55+
participation ratio and the `kpt_index` argument.
56+
ijdirs:
57+
The cartesian direction of the WAVDER tensor to sum over for the plot.
58+
If not provided, all the absolute values of the matrix for all
59+
three diagonal entries will be summed.
60+
shift_eig:
61+
A dictionary of the format `(band, kpt, spin) -> float` to apply to the
62+
eigenvalues. This is useful for aligning the defect state with the
63+
valence or conduction band for plotting and schematic purposes.
64+
x0:
65+
The x coordinate of the center of the set of arrows and the eigenvalue plot.
66+
x_width:
67+
The width of the set of arrows and the eigenvalue plot.
68+
ax:
69+
The matplotlib axis object to plot on.
70+
cmap:
71+
The matplotlib color map to use for the color of the arrorws.
72+
norm:
73+
The matplotlib normalization to use for the color map of the arrows.
74+
75+
"""
76+
d_eigs = get_bs_eigenvalues(
77+
defect=defect,
78+
kpt_index=kpt_index,
79+
band_window=band_window,
80+
user_defect_band=user_defect_band,
81+
shift_eig=shift_eig,
82+
)
83+
if user_defect_band:
84+
defect_band_index = user_defect_band[0]
85+
else:
86+
defect_band_index = next(
87+
filter(lambda x: x[1] == kpt_index, defect.defect_band)
88+
)[0]
89+
90+
if ax is None:
91+
ax_ = plt.gca()
92+
else: # pragma: no cover
93+
ax_ = ax
94+
_plot_eigs(
95+
d_eigs, defect.relaxed_bandstructure.efermi, ax=ax_, x0=x0, x_width=x_width
96+
)
97+
me_plot_data, cmap, norm = _plot_matrix_elements(
98+
defect.waveder.cder,
99+
d_eigs,
100+
defect_band_index=defect_band_index,
101+
ijdirs=ijdirs,
102+
ax=ax_,
103+
x0=x0,
104+
x_width=x_width,
105+
cmap=cmap,
106+
norm=norm,
107+
)
108+
return _get_dataframe(d_eigs=d_eigs, me_plot_data=me_plot_data), cmap, norm
109+
110+
111+
def get_bs_eigenvalues(
112+
defect: HarmonicDefect,
113+
kpt_index: int = 0,
114+
band_window: int = 5,
115+
user_defect_band: tuple = tuple(),
116+
shift_eig: dict[tuple, float] = None,
117+
) -> dict[tuple, float]:
118+
"""Read the eigenvalues from `HarmonicDefect.relaxed_bandstructure`.
119+
120+
Args:
121+
defect:
122+
The HarmonicDefect object, the `relaxed_bandstructure` attribute
123+
must be set since this contains the eigenvalues.
124+
Please see the `store_bandstructure` option in the constructor.
125+
kpt_index:
126+
The kpoint index to read the eigenvalues from.
127+
band_window:
128+
The number of bands above and below the Fermi level to include.
129+
user_defect_band:
130+
(band, kpt, spin) tuple to specify the defect state. If not provided,
131+
the defect state will be determined automatically using the inverse
132+
participation ratio.
133+
The user provided kpoint index here will overwrite the kpt_index argument.
134+
135+
Returns:
136+
Dictionary of the format: (iband, ikpt, ispin) -> eigenvalue
137+
"""
138+
139+
if defect.relaxed_bandstructure is None: # pragma: no cover
140+
raise ValueError("The defect object does not have a band structure.")
141+
142+
if user_defect_band:
143+
def_indices = user_defect_band
144+
else:
145+
def_indices = next(filter(lambda x: x[1] == kpt_index, defect.defect_band))
146+
147+
band_index, kpt_index, spin_index = def_indices
148+
spin_key = Spin.up if spin_index == 0 else Spin.down
149+
output: dict[tuple, float] = dict()
150+
shift_dict: dict = collections.defaultdict(lambda: 0.0)
151+
if shift_eig is not None:
152+
shift_dict.update(shift_eig)
153+
for ib in range(band_index - band_window, band_index + band_window + 1):
154+
output[(ib, kpt_index, spin_index)] = (
155+
defect.relaxed_bandstructure.bands[spin_key][ib, kpt_index]
156+
+ shift_dict[(ib, kpt_index, spin_index)]
157+
)
158+
return output
159+
160+
161+
def _plot_eigs(
162+
d_eigs: dict[tuple, float],
163+
e_fermi=None,
164+
ax=None,
165+
x0: float = 0.0,
166+
x_width: float = 0.3,
167+
**kwargs,
168+
) -> None:
169+
"""Plot the eigenvalues."""
170+
if ax is None: # pragma: no cover
171+
ax = plt.gca()
172+
173+
# Use current color scheme
174+
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
175+
collections.defaultdict(list)
176+
eigenvalues = np.array(list(d_eigs.values()))
177+
if e_fermi is None: # pragma: no cover
178+
e_fermi = -np.inf
179+
180+
eigs_ = eigenvalues[eigenvalues <= e_fermi]
181+
ax.hlines(
182+
eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[0], **kwargs
183+
)
184+
eigs_ = eigenvalues[eigenvalues > e_fermi]
185+
ax.hlines(
186+
eigs_, x0 - (x_width / 2.0), x0 + (x_width / 2.0), color=colors[1], **kwargs
187+
)
188+
189+
# turn off x-aixs
190+
ax.get_xaxis().set_visible(False)
191+
192+
193+
def _plot_matrix_elements(
194+
cder,
195+
d_eig,
196+
defect_band_index,
197+
ijdirs=((0, 0), (1, 1), (2, 2)),
198+
ax=None,
199+
x0=0,
200+
x_width=0.6,
201+
arrow_width=0.1,
202+
cmap=None,
203+
norm=None,
204+
):
205+
"""Plot arrow for the transition from the defect state to all other states.
206+
207+
Args:
208+
cder:
209+
The matrix element (derivative of the wavefunction) for the defect state.
210+
d_eig:
211+
The dictionary of eigenvalues for the defect state. In the format of
212+
(iband, ikpt, ispin) -> eigenvalue
213+
defect_band_index:
214+
The band index of the defect state.
215+
ax:
216+
The matplotlib axis object to plot on.
217+
x0:
218+
The x coordinate of the center of the set of arrows.
219+
x_width:
220+
The width of the set of arrows.
221+
arrow_width:
222+
The width of the arrow.
223+
cmap:
224+
The matplotlib color map to use.
225+
norm:
226+
The matplotlib normalization to use for the color map.
227+
ijdirs:
228+
The cartesian direction of the WAVDER tensor to sum over for the plot.
229+
If not provided, all the absolute values of the matrix for all
230+
three diagonal entries will be summed.
231+
"""
232+
if ax is None: # pragma: no cover
233+
ax = plt.gca()
234+
ax.set_aspect("equal")
235+
jb, jkpt, jspin = next(filter(lambda x: x[0] == defect_band_index, d_eig.keys()))
236+
y0 = d_eig[jb, jkpt, jspin]
237+
plot_data = []
238+
for (ib, ik, ispin), eig in d_eig.items():
239+
A = 0
240+
for idir, jdir in ijdirs:
241+
A += np.abs(
242+
cder[ib, jb, ik, ispin, idir]
243+
* np.conjugate(cder[ib, jb, ik, ispin, jdir])
244+
)
245+
plot_data.append((jb, ib, eig, A))
246+
247+
if cmap is None:
248+
cmap = plt.get_cmap("viridis")
249+
250+
# get the range of A values
251+
if norm is None:
252+
A_min, A_max = (
253+
min(plot_data, key=lambda x: x[3])[3],
254+
max(plot_data, key=lambda x: x[3])[3],
255+
)
256+
norm = Normalize(vmin=A_min, vmax=A_max)
257+
258+
n_arrows = len(plot_data)
259+
x_step = x_width / n_arrows
260+
x = x0 - x_width / 2 + x_step / 2
261+
for ib, jb, eig, A in plot_data:
262+
ax.arrow(
263+
x=x,
264+
y=y0,
265+
dx=0,
266+
dy=eig - y0,
267+
width=arrow_width,
268+
length_includes_head=True,
269+
head_width=arrow_width * 2,
270+
head_length=arrow_width * 2,
271+
color=cmap(norm(A)),
272+
zorder=20,
273+
)
274+
x += x_step
275+
return plot_data, cmap, norm
276+
277+
278+
def _get_dataframe(d_eigs, me_plot_data) -> pd.DataFrame:
279+
"""Convert the eigenvalue and matrix element data into a pandas dataframe."""
280+
_, ikpt, ispin = next(iter(d_eigs.keys()))
281+
df = pd.DataFrame(
282+
me_plot_data,
283+
columns=["ib", "jb", "eig", "M.E."],
284+
)
285+
df["kpt"] = ikpt
286+
df["spin"] = ispin
287+
return df

0 commit comments

Comments
 (0)