1
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
2
3
3
from copy import deepcopy
4
- from dataclasses import dataclass
4
+ from dataclasses import dataclass , field
5
5
import warnings
6
6
7
7
from astropy .modeling import fitting , models
10
10
from scipy .interpolate import UnivariateSpline
11
11
import numpy as np
12
12
13
- __all__ = ['Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
13
+ __all__ = ['BaseTrace' , ' Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
14
14
15
15
16
- @dataclass
17
- class Trace :
16
+ @dataclass ( frozen = True )
17
+ class BaseTrace :
18
18
"""
19
- Basic tracing class that by default traces the middle of the image.
20
-
21
- Parameters
22
- ----------
23
- image : `~astropy.nddata.CCDData`
24
- Image to be traced
25
-
26
- Properties
27
- ----------
28
- shape : tuple
29
- Shape of the array describing the trace
19
+ A dataclass common to all Trace objects.
30
20
"""
31
21
image : CCDData
22
+ _trace_pos : (float , np .ndarray ) = field (repr = False )
23
+ _trace : np .ndarray = field (repr = False )
32
24
33
25
def __post_init__ (self ):
34
- self .trace_pos = self .image .shape [0 ] / 2
35
- self .trace = np .ones_like (self .image [0 ]) * self .trace_pos
26
+ # this class only exists to catch __post_init__ calls in its
27
+ # subclasses, so that super().__post_init__ calls work correctly.
28
+ pass
36
29
37
30
def __getitem__ (self , i ):
38
31
return self .trace [i ]
@@ -59,7 +52,7 @@ def _bound_trace(self):
59
52
Mask trace positions that are outside the upper/lower bounds of the image.
60
53
"""
61
54
ny = self .image .shape [0 ]
62
- self . trace = np .ma .masked_outside (self .trace , 0 , ny - 1 )
55
+ object . __setattr__ ( self , '_trace' , np .ma .masked_outside (self ._trace , 0 , ny - 1 ) )
63
56
64
57
def __add__ (self , delta ):
65
58
"""
@@ -77,9 +70,60 @@ def __sub__(self, delta):
77
70
"""
78
71
return self .__add__ (- delta )
79
72
73
+ def shift (self , delta ):
74
+ """
75
+ Shift the trace by delta pixels perpendicular to the axis being traced
76
+
77
+ Parameters
78
+ ----------
79
+ delta : float
80
+ Shift to be applied to the trace
81
+ """
82
+ # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
83
+ object .__setattr__ (self , '_trace' , np .asarray (self ._trace .data ) + delta )
84
+ object .__setattr__ (self , '_trace_pos' , self ._trace_pos + delta )
85
+ self ._bound_trace ()
86
+
87
+ @property
88
+ def shape (self ):
89
+ return self ._trace .shape
90
+
91
+ @property
92
+ def trace (self ):
93
+ return self ._trace
94
+
95
+ @property
96
+ def trace_pos (self ):
97
+ return self ._trace_pos
98
+
99
+ @staticmethod
100
+ def _default_trace_attrs (image ):
101
+ """
102
+ Compute a default trace position and trace array using only
103
+ the image dimensions.
104
+ """
105
+ trace_pos = image .shape [0 ] / 2
106
+ trace = np .ones_like (image [0 ]) * trace_pos
107
+ return trace_pos , trace
108
+
109
+
110
+ @dataclass (init = False , frozen = True )
111
+ class Trace (BaseTrace ):
112
+ """
113
+ Basic tracing class that by default traces the middle of the image.
114
+
115
+ Parameters
116
+ ----------
117
+ image : `~astropy.nddata.CCDData`
118
+ Image to be traced
119
+ """
120
+ def __init__ (self , image ):
121
+ trace_pos , trace = self ._default_trace_attrs (image )
122
+ super ().__init__ (image , trace_pos , trace )
123
+
80
124
81
- @dataclass
82
- class FlatTrace (Trace ):
125
+ @dataclass ( init = False , frozen = True )
126
+ class FlatTrace (BaseTrace ):
83
127
"""
84
128
Trace that is constant along the axis being traced
85
129
@@ -92,10 +136,11 @@ class FlatTrace(Trace):
92
136
trace_pos : float
93
137
Position of the trace
94
138
"""
95
- trace_pos : float
96
139
97
- def __post_init__ (self ):
98
- self .set_position (self .trace_pos )
140
+ def __init__ (self , image , trace_pos ):
141
+ _ , trace = self ._default_trace_attrs (image )
142
+ super ().__init__ (image , trace_pos , trace )
143
+ self .set_position (trace_pos )
99
144
100
145
def set_position (self , trace_pos ):
101
146
"""
@@ -106,13 +151,13 @@ def set_position(self, trace_pos):
106
151
trace_pos : float
107
152
Position of the trace
108
153
"""
109
- self . trace_pos = trace_pos
110
- self . trace = np .ones_like (self .image [0 ]) * self . trace_pos
154
+ object . __setattr__ ( self , '_trace_pos' , trace_pos )
155
+ object . __setattr__ ( self , '_trace' , np .ones_like (self .image [0 ]) * trace_pos )
111
156
self ._bound_trace ()
112
157
113
158
114
- @dataclass
115
- class ArrayTrace (Trace ):
159
+ @dataclass ( init = False , frozen = True )
160
+ class ArrayTrace (BaseTrace ):
116
161
"""
117
162
Define a trace given an array of trace positions
118
163
@@ -121,25 +166,27 @@ class ArrayTrace(Trace):
121
166
trace : `numpy.ndarray`
122
167
Array containing trace positions
123
168
"""
124
- trace : np .ndarray
169
+ def __init__ (self , image , trace ):
170
+ trace_pos , _ = self ._default_trace_attrs (image )
171
+ super ().__init__ (image , trace_pos , trace )
125
172
126
- def __post_init__ (self ):
127
173
nx = self .image .shape [1 ]
128
- nt = len (self . trace )
174
+ nt = len (trace )
129
175
if nt != nx :
130
176
if nt > nx :
131
177
# truncate trace to fit image
132
- self . trace = self . trace [0 :nx ]
178
+ trace = trace [0 :nx ]
133
179
else :
134
180
# assume trace starts at beginning of image and pad out trace to fit.
135
181
# padding will be the last value of the trace, but will be masked out.
136
- padding = np .ma .MaskedArray (np .ones (nx - nt ) * self .trace [- 1 ], mask = True )
137
- self .trace = np .ma .hstack ([self .trace , padding ])
182
+ padding = np .ma .MaskedArray (np .ones (nx - nt ) * trace [- 1 ], mask = True )
183
+ trace = np .ma .hstack ([trace , padding ])
184
+ object .__setattr__ (self , '_trace' , trace )
138
185
self ._bound_trace ()
139
186
140
187
141
- @dataclass
142
- class KosmosTrace (Trace ):
188
+ @dataclass ( init = False , frozen = True )
189
+ class KosmosTrace (BaseTrace ):
143
190
"""
144
191
Trace the spectrum aperture in an image.
145
192
@@ -192,14 +239,25 @@ class KosmosTrace(Trace):
192
239
4) add other interpolation modes besides spline, maybe via
193
240
specutils.manipulation methods?
194
241
"""
195
- bins : int = 20
196
- guess : float = None
197
- window : int = None
198
- peak_method : str = 'gaussian'
242
+ bins : int
243
+ guess : float
244
+ window : int
245
+ peak_method : str
199
246
_crossdisp_axis = 0
200
247
_disp_axis = 1
201
248
202
- def __post_init__ (self ):
249
+ def _process_init_kwargs (self , ** kwargs ):
250
+ for attr , value in kwargs .items ():
251
+ object .__setattr__ (self , attr , value )
252
+
253
+ def __init__ (self , image , bins = 20 , guess = None , window = None , peak_method = 'gaussian' ):
254
+ # This method will assign the user supplied value (or default) to the attrs:
255
+ self ._process_init_kwargs (
256
+ bins = bins , guess = guess , window = window , peak_method = peak_method
257
+ )
258
+ trace_pos , trace = self ._default_trace_attrs (image )
259
+ super ().__init__ (image , trace_pos , trace )
260
+
203
261
# handle multiple image types and mask uncaught invalid values
204
262
if isinstance (self .image , NDData ):
205
263
img = np .ma .masked_invalid (np .ma .masked_array (self .image .data ,
@@ -223,7 +281,7 @@ def __post_init__(self):
223
281
224
282
if not isinstance (self .bins , int ):
225
283
warnings .warn ('TRACE: Converting bins to int' )
226
- self . bins = int (self .bins )
284
+ object . __setattr__ ( self , ' bins' , int (self .bins ) )
227
285
228
286
if self .bins < 4 :
229
287
raise ValueError ('bins must be >= 4' )
@@ -240,7 +298,7 @@ def __post_init__(self):
240
298
"length of the image's spatial direction" )
241
299
elif self .window is not None and not isinstance (self .window , int ):
242
300
warnings .warn ('TRACE: Converting window to int' )
243
- self . window = int (self .window )
301
+ object . __setattr__ ( self , ' window' , int (self .window ) )
244
302
245
303
# set max peak location by user choice or wavelength with max avg flux
246
304
ztot = img .sum (axis = self ._disp_axis ) / img .shape [self ._disp_axis ]
@@ -343,4 +401,4 @@ def __post_init__(self):
343
401
warnings .warn ("TRACE ERROR: No valid points found in trace" )
344
402
trace_y = np .tile (np .nan , len (x_bins ))
345
403
346
- self . trace = np .ma .masked_invalid (trace_y )
404
+ object . __setattr__ ( self , '_trace' , np .ma .masked_invalid (trace_y ) )
0 commit comments