Skip to content

Commit 786bded

Browse files
committed
Refactored observed line handling in wavecal1d.py, added amplitude support, and improved the plotting of observed lines.
1 parent 538d2b8 commit 786bded

File tree

1 file changed

+126
-60
lines changed

1 file changed

+126
-60
lines changed

specreduce/wavecal1d.py

Lines changed: 126 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@
2121
__all__ = ["WavelengthCalibration1D"]
2222

2323

24+
def unclutter_text_boxes(labels):
25+
to_remove = set()
26+
for i in range(len(labels)):
27+
for j in range(i + 1, len(labels)):
28+
l1 = labels[i]
29+
l2 = labels[j]
30+
bbox1 = l1.get_window_extent()
31+
bbox2 = l2.get_window_extent()
32+
if bbox1.overlaps(bbox2):
33+
if l1.zorder < l2.zorder:
34+
to_remove.add(l1)
35+
else:
36+
to_remove.add(l2)
37+
38+
for label in to_remove:
39+
label.remove()
40+
41+
2442
def _diff_poly1d(m: models.Polynomial1D) -> models.Polynomial1D:
2543
"""Compute the derivative of a Polynomial1D model.
2644
@@ -216,10 +234,13 @@ def find_lines(self, fwhm: float, noise_factor: float = 1.0) -> None:
216234

217235
with warnings.catch_warnings():
218236
warnings.simplefilter("ignore")
219-
lines_obs = [
237+
line_lists = [
220238
find_arc_lines(sp, fwhm, noise_factor=noise_factor) for sp in self.arc_spectra
221239
]
222-
self._obs_lines = [np.ma.masked_array(lo["centroid"].value) for lo in lines_obs]
240+
self.observed_lines = [
241+
np.ma.masked_array(np.transpose([ll["centroid"].value, ll["amplitude"].value]))
242+
for ll in line_lists
243+
]
223244

224245
def fit_lines(
225246
self,
@@ -288,7 +309,7 @@ def fit_lines(
288309
if match_obs:
289310
if self._obs_lines is None:
290311
raise ValueError("Cannot fit without observed lines set.")
291-
tree = KDTree(np.concatenate([c.data for c in self._obs_lines])[:, None])
312+
tree = KDTree(np.concatenate([c.data[:, 0] for c in self._obs_lines])[:, None])
292313
ix = tree.query(pixels[:, None])[1]
293314
pixels = tree.data[ix][:, 0]
294315

@@ -417,7 +438,7 @@ def refine_fit(self, max_match_distance: float = 5.0, max_iter: int = 5) -> None
417438
rms = np.nan
418439
for i in range(max_iter):
419440
self.match_lines(max_match_distance)
420-
matched_pix = np.ma.concatenate(self._obs_lines).compressed()
441+
matched_pix = np.ma.concatenate(self.observed_lines).compressed()
421442
matched_wav = np.ma.concatenate(self._cat_lines).compressed()
422443
rms_new = np.sqrt(((matched_wav - self.pix_to_wav(matched_pix)) ** 2).mean())
423444
if rms_new == rms:
@@ -557,22 +578,35 @@ def wav_to_pix(self, wav: MaskedArray | ndarray | float) -> ndarray | float:
557578
@property
558579
def observed_lines(self) -> list[MaskedArray]:
559580
"""Pixel positions of the observed lines as a list of masked arrays."""
560-
return self._obs_lines
581+
return [lines[:, 0] for lines in self._obs_lines]
561582

562583
@observed_lines.setter
563-
def observed_lines(self, lines_pix: MaskedArray | ndarray | list[MaskedArray] | list[ndarray]):
564-
if not isinstance(lines_pix, Sequence):
565-
lines_pix = [lines_pix]
584+
def observed_lines(self, line_lists: MaskedArray | ndarray | list[MaskedArray] | list[ndarray]):
585+
if not isinstance(line_lists, Sequence):
586+
line_lists = [line_lists]
587+
566588
self._obs_lines = []
567-
for lst in lines_pix:
568-
if isinstance(lst, MaskedArray) and lst.mask is not np.False_:
569-
self._obs_lines.append(lst)
570-
else:
571-
self._obs_lines.append(np.ma.masked_array(lst, mask=np.zeros(lst.size, bool)))
589+
for lst in line_lists:
590+
lst = MaskedArray(lst, copy=True)
591+
592+
if (lst.ndim > 2) or (lst.ndim == 2 and lst.shape[1] > 2):
593+
raise ValueError(
594+
"Observed line lists must be 1D with a shape [n] (centroids) or "
595+
"2D with a shape [n, 2] (centroids and amplitudes)."
596+
)
597+
598+
if lst.mask is np.False_:
599+
lst.mask = np.zeros(lst.shape[0], dtype=bool)
600+
601+
if lst.ndim == 1:
602+
lst = np.tile(lst[:, None], [1, 2])
603+
lst[:, 1] = np.arange(lst.shape[0])
604+
605+
self._obs_lines.append(lst)
572606

573607
@property
574608
def catalog_lines(self) -> list[MaskedArray]:
575-
"""Catalogue line wavelengths as a list of masked arrays."""
609+
"""Catalog line wavelengths as a list of masked arrays."""
576610
return self._cat_lines
577611

578612
@catalog_lines.setter
@@ -615,11 +649,10 @@ def match_lines(self, max_distance: float = 5) -> None:
615649
The maximum allowed distance between the query points and the KD-tree
616650
data points for them to be considered a match.
617651
"""
618-
matched_lines_wav = []
619-
matched_lines_pix = []
652+
620653
for iframe, tree in enumerate(self._trees):
621654
l, ix = tree.query(
622-
self._p2w(self._obs_lines[iframe].data)[:, None],
655+
self._p2w(self._obs_lines[iframe].data[:, 0])[:, None],
623656
distance_upper_bound=max_distance,
624657
)
625658
m = np.isfinite(l)
@@ -636,12 +669,9 @@ def match_lines(self, max_distance: float = 5) -> None:
636669
r[np.argmin(l[s])] = True
637670
m[s] = r
638671

639-
matched_lines_wav.append(np.ma.masked_array(tree.data[:, 0], mask=True))
640-
matched_lines_wav[-1].mask[ix[m]] = False
641-
matched_lines_pix.append(np.ma.masked_array(self._obs_lines[iframe].data, mask=~m))
642-
643-
self._obs_lines = matched_lines_pix
644-
self._cat_lines = matched_lines_wav
672+
self._cat_lines[iframe].mask[:] = True
673+
self._cat_lines[iframe].mask[ix[m]] = False
674+
self._obs_lines[iframe].mask[:, :] = ~m[:, None]
645675

646676
def remove_ummatched_lines(self):
647677
"""Remove unmatched lines from observation and catalog line data."""
@@ -804,10 +834,10 @@ def plot_observed_lines(
804834
frames: int | Sequence[int] | None = None,
805835
axes: Axes | Sequence[Axes] | None = None,
806836
figsize: tuple[float, float] | None = None,
807-
plot_values: bool = True,
837+
plot_labels: bool = True,
808838
plot_spectra: bool = True,
809839
map_to_wav: bool = False,
810-
value_fontsize: int | str | None = "small",
840+
label_kwargs: dict | None = None,
811841
) -> Figure:
812842
"""Plot observed spectral lines for the given arc spectra.
813843
@@ -822,55 +852,92 @@ def plot_observed_lines(
822852
figsize
823853
Dimensions of the figure to be created, specified as a tuple (width, height). Ignored
824854
if ``axes`` is provided.
825-
plot_values
855+
plot_labels
826856
If True, plots the numerical values of the observed lines at their respective
827857
locations on the graph. Default is True.
828858
plot_spectra
829859
If True, includes the arc spectra on the plot for comparison. Default is True.
830860
map_to_wav
831861
Determines whether to map the x-axis values to wavelengths. Default is False.
862+
label_kwargs
863+
Specifies the keyword arguments for the line label text objects.
832864
833865
Returns
834866
-------
835867
Figure
836868
The matplotlib figure containing the observed lines plot.
837869
"""
838-
fig = self._plot_lines(
839-
"observed",
840-
frames=frames,
841-
axs=axes,
842-
figsize=figsize,
843-
plot_values=plot_values,
844-
map_x=map_to_wav,
845-
value_fontsize=value_fontsize,
846-
)
870+
871+
largs = dict(backgroundcolor="w", rotation=90, size="small")
872+
if label_kwargs is not None:
873+
largs.update(label_kwargs)
874+
875+
if frames is None:
876+
frames = np.arange(self.nframes)
877+
else:
878+
frames = np.atleast_1d(frames)
847879

848880
if axes is None:
849-
axes = np.atleast_1d(fig.axes)
850-
851-
if self.arc_spectra is not None and plot_spectra:
852-
if frames is None:
853-
frames = np.arange(self.nframes)
854-
elif np.isscalar(frames):
855-
frames = [frames]
856-
857-
transform = self._p2w if map_to_wav else lambda x: x
858-
for i, frame in enumerate(frames):
859-
axes[i].plot(
860-
transform(self.arc_spectra[frame].spectral_axis.value),
861-
self.arc_spectra[frame].data / (1.2 * self.arc_spectra[frame].data.max()),
862-
c="k",
863-
zorder=-10,
864-
)
865-
setp(
866-
axes,
867-
xlim=transform(
881+
fig, axes = subplots(frames.size, 1, figsize=figsize, constrained_layout=True)
882+
elif isinstance(axes, Axes):
883+
fig = axes.figure
884+
axes = [axes]
885+
else:
886+
fig = axes[0].figure
887+
axes = np.atleast_1d(axes)
888+
889+
transform = self.pix_to_wav if map_to_wav else lambda x: x
890+
xlabel = f"Wavelength [{self._unit_str}]" if map_to_wav else "Pixel"
891+
892+
ypad = 1.3
893+
894+
for iax, iframe in enumerate(frames):
895+
ax = axes.flat[iax]
896+
if plot_spectra and self.arc_spectra is not None:
897+
spc = self.arc_spectra[iframe]
898+
vmax = spc.flux.value.max()
899+
ax.plot(transform(spc.spectral_axis.value), spc.flux.value / vmax)
900+
else:
901+
vmax = 1.0
902+
903+
labels = []
904+
for i in range(self._obs_lines[iframe].shape[0]):
905+
c, a = self._obs_lines[iframe].data[i]
906+
if self._obs_lines[iframe].mask[i, 0] is True:
907+
ls = ":"
908+
else:
909+
ls = "-"
910+
911+
ax.plot(transform([c, c]), [a / vmax + 0.02, 1.27], "0.75", ls=ls)
912+
if plot_labels:
913+
labels.append(
914+
ax.text(
915+
transform(c),
916+
ypad,
917+
f"{transform(c):.0f}",
918+
ha="center",
919+
va="bottom",
920+
**largs,
921+
)
922+
)
923+
labels[-1].zorder = a / vmax
924+
925+
if plot_labels:
926+
fig.canvas.draw()
927+
unclutter_text_boxes(labels)
928+
tr = ax.transData.inverted()
929+
ymax = max(
868930
[
869-
self.arc_spectra[0].spectral_axis.min().value,
870-
self.arc_spectra[0].spectral_axis.max().value,
931+
tr.transform_bbox(label.get_window_extent()).max[1]
932+
for label in labels
933+
if label.figure is not None
871934
]
872-
),
873-
)
935+
)
936+
else:
937+
ymax = ypad
938+
939+
setp(ax, xlabel=xlabel, yticks=[], ylim=(-0.02, ymax + 0.02))
940+
ax.autoscale(True, "x", tight=True)
874941
return fig
875942

876943
def plot_fit(
@@ -925,9 +992,8 @@ def plot_fit(
925992
self.plot_observed_lines(
926993
frames,
927994
axs[1::2],
928-
plot_values=plot_values,
995+
plot_labels=plot_values,
929996
map_to_wav=obs_to_wav,
930-
value_fontsize=value_fontsize,
931997
)
932998

933999
xlims = np.array([ax.get_xlim() for ax in axs[::2]])

0 commit comments

Comments
 (0)