21
21
__all__ = ["WavelengthCalibration1D" ]
22
22
23
23
24
+ def unclutter_text_boxes (labels ):
25
+ to_remove = set ()
26
+ for i in range (len (labels )):
27
+ for j in range (i + 1 , len (labels )):
28
+ l1 = labels [i ]
29
+ l2 = labels [j ]
30
+ bbox1 = l1 .get_window_extent ()
31
+ bbox2 = l2 .get_window_extent ()
32
+ if bbox1 .overlaps (bbox2 ):
33
+ if l1 .zorder < l2 .zorder :
34
+ to_remove .add (l1 )
35
+ else :
36
+ to_remove .add (l2 )
37
+
38
+ for label in to_remove :
39
+ label .remove ()
40
+
41
+
24
42
def _diff_poly1d (m : models .Polynomial1D ) -> models .Polynomial1D :
25
43
"""Compute the derivative of a Polynomial1D model.
26
44
@@ -216,10 +234,13 @@ def find_lines(self, fwhm: float, noise_factor: float = 1.0) -> None:
216
234
217
235
with warnings .catch_warnings ():
218
236
warnings .simplefilter ("ignore" )
219
- lines_obs = [
237
+ line_lists = [
220
238
find_arc_lines (sp , fwhm , noise_factor = noise_factor ) for sp in self .arc_spectra
221
239
]
222
- self ._obs_lines = [np .ma .masked_array (lo ["centroid" ].value ) for lo in lines_obs ]
240
+ self .observed_lines = [
241
+ np .ma .masked_array (np .transpose ([ll ["centroid" ].value , ll ["amplitude" ].value ]))
242
+ for ll in line_lists
243
+ ]
223
244
224
245
def fit_lines (
225
246
self ,
@@ -288,7 +309,7 @@ def fit_lines(
288
309
if match_obs :
289
310
if self ._obs_lines is None :
290
311
raise ValueError ("Cannot fit without observed lines set." )
291
- tree = KDTree (np .concatenate ([c .data for c in self ._obs_lines ])[:, None ])
312
+ tree = KDTree (np .concatenate ([c .data [:, 0 ] for c in self ._obs_lines ])[:, None ])
292
313
ix = tree .query (pixels [:, None ])[1 ]
293
314
pixels = tree .data [ix ][:, 0 ]
294
315
@@ -417,7 +438,7 @@ def refine_fit(self, max_match_distance: float = 5.0, max_iter: int = 5) -> None
417
438
rms = np .nan
418
439
for i in range (max_iter ):
419
440
self .match_lines (max_match_distance )
420
- matched_pix = np .ma .concatenate (self ._obs_lines ).compressed ()
441
+ matched_pix = np .ma .concatenate (self .observed_lines ).compressed ()
421
442
matched_wav = np .ma .concatenate (self ._cat_lines ).compressed ()
422
443
rms_new = np .sqrt (((matched_wav - self .pix_to_wav (matched_pix )) ** 2 ).mean ())
423
444
if rms_new == rms :
@@ -557,22 +578,35 @@ def wav_to_pix(self, wav: MaskedArray | ndarray | float) -> ndarray | float:
557
578
@property
558
579
def observed_lines (self ) -> list [MaskedArray ]:
559
580
"""Pixel positions of the observed lines as a list of masked arrays."""
560
- return self ._obs_lines
581
+ return [ lines [:, 0 ] for lines in self ._obs_lines ]
561
582
562
583
@observed_lines .setter
563
- def observed_lines (self , lines_pix : MaskedArray | ndarray | list [MaskedArray ] | list [ndarray ]):
564
- if not isinstance (lines_pix , Sequence ):
565
- lines_pix = [lines_pix ]
584
+ def observed_lines (self , line_lists : MaskedArray | ndarray | list [MaskedArray ] | list [ndarray ]):
585
+ if not isinstance (line_lists , Sequence ):
586
+ line_lists = [line_lists ]
587
+
566
588
self ._obs_lines = []
567
- for lst in lines_pix :
568
- if isinstance (lst , MaskedArray ) and lst .mask is not np .False_ :
569
- self ._obs_lines .append (lst )
570
- else :
571
- self ._obs_lines .append (np .ma .masked_array (lst , mask = np .zeros (lst .size , bool )))
589
+ for lst in line_lists :
590
+ lst = MaskedArray (lst , copy = True )
591
+
592
+ if (lst .ndim > 2 ) or (lst .ndim == 2 and lst .shape [1 ] > 2 ):
593
+ raise ValueError (
594
+ "Observed line lists must be 1D with a shape [n] (centroids) or "
595
+ "2D with a shape [n, 2] (centroids and amplitudes)."
596
+ )
597
+
598
+ if lst .mask is np .False_ :
599
+ lst .mask = np .zeros (lst .shape [0 ], dtype = bool )
600
+
601
+ if lst .ndim == 1 :
602
+ lst = np .tile (lst [:, None ], [1 , 2 ])
603
+ lst [:, 1 ] = np .arange (lst .shape [0 ])
604
+
605
+ self ._obs_lines .append (lst )
572
606
573
607
@property
574
608
def catalog_lines (self ) -> list [MaskedArray ]:
575
- """Catalogue line wavelengths as a list of masked arrays."""
609
+ """Catalog line wavelengths as a list of masked arrays."""
576
610
return self ._cat_lines
577
611
578
612
@catalog_lines .setter
@@ -615,11 +649,10 @@ def match_lines(self, max_distance: float = 5) -> None:
615
649
The maximum allowed distance between the query points and the KD-tree
616
650
data points for them to be considered a match.
617
651
"""
618
- matched_lines_wav = []
619
- matched_lines_pix = []
652
+
620
653
for iframe , tree in enumerate (self ._trees ):
621
654
l , ix = tree .query (
622
- self ._p2w (self ._obs_lines [iframe ].data )[:, None ],
655
+ self ._p2w (self ._obs_lines [iframe ].data [:, 0 ] )[:, None ],
623
656
distance_upper_bound = max_distance ,
624
657
)
625
658
m = np .isfinite (l )
@@ -636,12 +669,9 @@ def match_lines(self, max_distance: float = 5) -> None:
636
669
r [np .argmin (l [s ])] = True
637
670
m [s ] = r
638
671
639
- matched_lines_wav .append (np .ma .masked_array (tree .data [:, 0 ], mask = True ))
640
- matched_lines_wav [- 1 ].mask [ix [m ]] = False
641
- matched_lines_pix .append (np .ma .masked_array (self ._obs_lines [iframe ].data , mask = ~ m ))
642
-
643
- self ._obs_lines = matched_lines_pix
644
- self ._cat_lines = matched_lines_wav
672
+ self ._cat_lines [iframe ].mask [:] = True
673
+ self ._cat_lines [iframe ].mask [ix [m ]] = False
674
+ self ._obs_lines [iframe ].mask [:, :] = ~ m [:, None ]
645
675
646
676
def remove_ummatched_lines (self ):
647
677
"""Remove unmatched lines from observation and catalog line data."""
@@ -804,10 +834,10 @@ def plot_observed_lines(
804
834
frames : int | Sequence [int ] | None = None ,
805
835
axes : Axes | Sequence [Axes ] | None = None ,
806
836
figsize : tuple [float , float ] | None = None ,
807
- plot_values : bool = True ,
837
+ plot_labels : bool = True ,
808
838
plot_spectra : bool = True ,
809
839
map_to_wav : bool = False ,
810
- value_fontsize : int | str | None = "small" ,
840
+ label_kwargs : dict | None = None ,
811
841
) -> Figure :
812
842
"""Plot observed spectral lines for the given arc spectra.
813
843
@@ -822,55 +852,92 @@ def plot_observed_lines(
822
852
figsize
823
853
Dimensions of the figure to be created, specified as a tuple (width, height). Ignored
824
854
if ``axes`` is provided.
825
- plot_values
855
+ plot_labels
826
856
If True, plots the numerical values of the observed lines at their respective
827
857
locations on the graph. Default is True.
828
858
plot_spectra
829
859
If True, includes the arc spectra on the plot for comparison. Default is True.
830
860
map_to_wav
831
861
Determines whether to map the x-axis values to wavelengths. Default is False.
862
+ label_kwargs
863
+ Specifies the keyword arguments for the line label text objects.
832
864
833
865
Returns
834
866
-------
835
867
Figure
836
868
The matplotlib figure containing the observed lines plot.
837
869
"""
838
- fig = self . _plot_lines (
839
- "observed" ,
840
- frames = frames ,
841
- axs = axes ,
842
- figsize = figsize ,
843
- plot_values = plot_values ,
844
- map_x = map_to_wav ,
845
- value_fontsize = value_fontsize ,
846
- )
870
+
871
+ largs = dict ( backgroundcolor = "w" , rotation = 90 , size = "small" )
872
+ if label_kwargs is not None :
873
+ largs . update ( label_kwargs )
874
+
875
+ if frames is None :
876
+ frames = np . arange ( self . nframes )
877
+ else :
878
+ frames = np . atleast_1d ( frames )
847
879
848
880
if axes is None :
849
- axes = np .atleast_1d (fig .axes )
850
-
851
- if self .arc_spectra is not None and plot_spectra :
852
- if frames is None :
853
- frames = np .arange (self .nframes )
854
- elif np .isscalar (frames ):
855
- frames = [frames ]
856
-
857
- transform = self ._p2w if map_to_wav else lambda x : x
858
- for i , frame in enumerate (frames ):
859
- axes [i ].plot (
860
- transform (self .arc_spectra [frame ].spectral_axis .value ),
861
- self .arc_spectra [frame ].data / (1.2 * self .arc_spectra [frame ].data .max ()),
862
- c = "k" ,
863
- zorder = - 10 ,
864
- )
865
- setp (
866
- axes ,
867
- xlim = transform (
881
+ fig , axes = subplots (frames .size , 1 , figsize = figsize , constrained_layout = True )
882
+ elif isinstance (axes , Axes ):
883
+ fig = axes .figure
884
+ axes = [axes ]
885
+ else :
886
+ fig = axes [0 ].figure
887
+ axes = np .atleast_1d (axes )
888
+
889
+ transform = self .pix_to_wav if map_to_wav else lambda x : x
890
+ xlabel = f"Wavelength [{ self ._unit_str } ]" if map_to_wav else "Pixel"
891
+
892
+ ypad = 1.3
893
+
894
+ for iax , iframe in enumerate (frames ):
895
+ ax = axes .flat [iax ]
896
+ if plot_spectra and self .arc_spectra is not None :
897
+ spc = self .arc_spectra [iframe ]
898
+ vmax = spc .flux .value .max ()
899
+ ax .plot (transform (spc .spectral_axis .value ), spc .flux .value / vmax )
900
+ else :
901
+ vmax = 1.0
902
+
903
+ labels = []
904
+ for i in range (self ._obs_lines [iframe ].shape [0 ]):
905
+ c , a = self ._obs_lines [iframe ].data [i ]
906
+ if self ._obs_lines [iframe ].mask [i , 0 ] is True :
907
+ ls = ":"
908
+ else :
909
+ ls = "-"
910
+
911
+ ax .plot (transform ([c , c ]), [a / vmax + 0.02 , 1.27 ], "0.75" , ls = ls )
912
+ if plot_labels :
913
+ labels .append (
914
+ ax .text (
915
+ transform (c ),
916
+ ypad ,
917
+ f"{ transform (c ):.0f} " ,
918
+ ha = "center" ,
919
+ va = "bottom" ,
920
+ ** largs ,
921
+ )
922
+ )
923
+ labels [- 1 ].zorder = a / vmax
924
+
925
+ if plot_labels :
926
+ fig .canvas .draw ()
927
+ unclutter_text_boxes (labels )
928
+ tr = ax .transData .inverted ()
929
+ ymax = max (
868
930
[
869
- self .arc_spectra [0 ].spectral_axis .min ().value ,
870
- self .arc_spectra [0 ].spectral_axis .max ().value ,
931
+ tr .transform_bbox (label .get_window_extent ()).max [1 ]
932
+ for label in labels
933
+ if label .figure is not None
871
934
]
872
- ),
873
- )
935
+ )
936
+ else :
937
+ ymax = ypad
938
+
939
+ setp (ax , xlabel = xlabel , yticks = [], ylim = (- 0.02 , ymax + 0.02 ))
940
+ ax .autoscale (True , "x" , tight = True )
874
941
return fig
875
942
876
943
def plot_fit (
@@ -925,9 +992,8 @@ def plot_fit(
925
992
self .plot_observed_lines (
926
993
frames ,
927
994
axs [1 ::2 ],
928
- plot_values = plot_values ,
995
+ plot_labels = plot_values ,
929
996
map_to_wav = obs_to_wav ,
930
- value_fontsize = value_fontsize ,
931
997
)
932
998
933
999
xlims = np .array ([ax .get_xlim () for ax in axs [::2 ]])
0 commit comments