1
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
2
3
3
from copy import deepcopy
4
- from dataclasses import dataclass
4
+ from dataclasses import dataclass , field
5
5
import warnings
6
6
7
7
from astropy .modeling import fitting , models
10
10
from scipy .interpolate import UnivariateSpline
11
11
import numpy as np
12
12
13
- __all__ = ['Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
13
+ __all__ = ['BaseTrace' , ' Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
14
14
15
15
16
- @dataclass
17
- class Trace :
16
+ @dataclass ( frozen = True )
17
+ class BaseTrace :
18
18
"""
19
- Basic tracing class that by default traces the middle of the image.
20
-
21
- Parameters
22
- ----------
23
- image : `~astropy.nddata.CCDData`
24
- Image to be traced
25
-
26
- Properties
27
- ----------
28
- shape : tuple
29
- Shape of the array describing the trace
19
+ A dataclass common to all Trace objects.
30
20
"""
31
21
image : CCDData
22
+ _trace_pos : (float , np .ndarray ) = field (repr = False )
23
+ _trace : np .ndarray = field (repr = False )
32
24
33
25
def __post_init__ (self ):
34
- self .trace_pos = self .image .shape [0 ] / 2
35
- self .trace = np .ones_like (self .image [0 ]) * self .trace_pos
26
+ # this class only exists to catch __post_init__ calls in its
27
+ # subclasses, so that super().__post_init__ calls work correctly.
28
+ pass
36
29
37
30
def __getitem__ (self , i ):
38
31
return self .trace [i ]
39
32
40
- @property
41
- def shape (self ):
42
- return self .trace .shape
43
-
44
- def shift (self , delta ):
45
- """
46
- Shift the trace by delta pixels perpendicular to the axis being traced
47
-
48
- Parameters
49
- ----------
50
- delta : float
51
- Shift to be applied to the trace
52
- """
53
- # act on self.trace.data to ignore the mask and then re-mask when calling _bound_trace
54
- self .trace = np .asarray (self .trace .data ) + delta
55
- self ._bound_trace ()
56
-
57
33
def _bound_trace (self ):
58
34
"""
59
35
Mask trace positions that are outside the upper/lower bounds of the image.
60
36
"""
61
37
ny = self .image .shape [0 ]
62
- self . trace = np .ma .masked_outside (self .trace , 0 , ny - 1 )
38
+ object . __setattr__ ( self , '_trace' , np .ma .masked_outside (self ._trace , 0 , ny - 1 ) )
63
39
64
40
def __add__ (self , delta ):
65
41
"""
@@ -77,9 +53,60 @@ def __sub__(self, delta):
77
53
"""
78
54
return self .__add__ (- delta )
79
55
56
+ def shift (self , delta ):
57
+ """
58
+ Shift the trace by delta pixels perpendicular to the axis being traced
59
+
60
+ Parameters
61
+ ----------
62
+ delta : float
63
+ Shift to be applied to the trace
64
+ """
65
+ # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
66
+ object .__setattr__ (self , '_trace' , np .asarray (self ._trace .data ) + delta )
67
+ object .__setattr__ (self , '_trace_pos' , self ._trace_pos + delta )
68
+ self ._bound_trace ()
69
+
70
+ @property
71
+ def shape (self ):
72
+ return self ._trace .shape
73
+
74
+ @property
75
+ def trace (self ):
76
+ return self ._trace
80
77
81
- @dataclass
82
- class FlatTrace (Trace ):
78
+ @property
79
+ def trace_pos (self ):
80
+ return self ._trace_pos
81
+
82
+ @staticmethod
83
+ def _default_trace_attrs (image ):
84
+ """
85
+ Compute a default trace position and trace array using only
86
+ the image dimensions.
87
+ """
88
+ trace_pos = image .shape [0 ] / 2
89
+ trace = np .ones_like (image [0 ]) * trace_pos
90
+ return trace_pos , trace
91
+
92
+
93
+ @dataclass (init = False , frozen = True )
94
+ class Trace (BaseTrace ):
95
+ """
96
+ Basic tracing class that by default traces the middle of the image.
97
+
98
+ Parameters
99
+ ----------
100
+ image : `~astropy.nddata.CCDData`
101
+ Image to be traced
102
+ """
103
+ def __init__ (self , image ):
104
+ trace_pos , trace = self ._default_trace_attrs (image )
105
+ super ().__init__ (image , trace_pos , trace )
106
+
107
+
108
+ @dataclass (init = False , frozen = True )
109
+ class FlatTrace (BaseTrace ):
83
110
"""
84
111
Trace that is constant along the axis being traced
85
112
@@ -92,10 +119,11 @@ class FlatTrace(Trace):
92
119
trace_pos : float
93
120
Position of the trace
94
121
"""
95
- trace_pos : float
96
122
97
- def __post_init__ (self ):
98
- self .set_position (self .trace_pos )
123
+ def __init__ (self , image , trace_pos ):
124
+ _ , trace = self ._default_trace_attrs (image )
125
+ super ().__init__ (image , trace_pos , trace )
126
+ self .set_position (trace_pos )
99
127
100
128
def set_position (self , trace_pos ):
101
129
"""
@@ -106,13 +134,13 @@ def set_position(self, trace_pos):
106
134
trace_pos : float
107
135
Position of the trace
108
136
"""
109
- self . trace_pos = trace_pos
110
- self . trace = np .ones_like (self .image [0 ]) * self . trace_pos
137
+ object . __setattr__ ( self , '_trace_pos' , trace_pos )
138
+ object . __setattr__ ( self , '_trace' , np .ones_like (self .image [0 ]) * trace_pos )
111
139
self ._bound_trace ()
112
140
113
141
114
- @dataclass
115
- class ArrayTrace (Trace ):
142
+ @dataclass ( init = False , frozen = True )
143
+ class ArrayTrace (BaseTrace ):
116
144
"""
117
145
Define a trace given an array of trace positions
118
146
@@ -121,25 +149,27 @@ class ArrayTrace(Trace):
121
149
trace : `numpy.ndarray`
122
150
Array containing trace positions
123
151
"""
124
- trace : np .ndarray
152
+ def __init__ (self , image , trace ):
153
+ trace_pos , _ = self ._default_trace_attrs (image )
154
+ super ().__init__ (image , trace_pos , trace )
125
155
126
- def __post_init__ (self ):
127
156
nx = self .image .shape [1 ]
128
- nt = len (self . trace )
157
+ nt = len (trace )
129
158
if nt != nx :
130
159
if nt > nx :
131
160
# truncate trace to fit image
132
- self . trace = self . trace [0 :nx ]
161
+ trace = trace [0 :nx ]
133
162
else :
134
163
# assume trace starts at beginning of image and pad out trace to fit.
135
164
# padding will be the last value of the trace, but will be masked out.
136
- padding = np .ma .MaskedArray (np .ones (nx - nt ) * self .trace [- 1 ], mask = True )
137
- self .trace = np .ma .hstack ([self .trace , padding ])
165
+ padding = np .ma .MaskedArray (np .ones (nx - nt ) * trace [- 1 ], mask = True )
166
+ trace = np .ma .hstack ([trace , padding ])
167
+ object .__setattr__ (self , '_trace' , trace )
138
168
self ._bound_trace ()
139
169
140
170
141
- @dataclass
142
- class KosmosTrace (Trace ):
171
+ @dataclass ( init = False , frozen = True )
172
+ class KosmosTrace (BaseTrace ):
143
173
"""
144
174
Trace the spectrum aperture in an image.
145
175
@@ -192,14 +222,25 @@ class KosmosTrace(Trace):
192
222
4) add other interpolation modes besides spline, maybe via
193
223
specutils.manipulation methods?
194
224
"""
195
- bins : int = 20
196
- guess : float = None
197
- window : int = None
198
- peak_method : str = 'gaussian'
225
+ bins : int
226
+ guess : float
227
+ window : int
228
+ peak_method : str
199
229
_crossdisp_axis = 0
200
230
_disp_axis = 1
201
231
202
- def __post_init__ (self ):
232
+ def _process_init_kwargs (self , ** kwargs ):
233
+ for attr , value in kwargs .items ():
234
+ object .__setattr__ (self , attr , value )
235
+
236
+ def __init__ (self , image , bins = 20 , guess = None , window = None , peak_method = 'gaussian' ):
237
+ # This method will assign the user supplied value (or default) to the attrs:
238
+ self ._process_init_kwargs (
239
+ bins = bins , guess = guess , window = window , peak_method = peak_method
240
+ )
241
+ trace_pos , trace = self ._default_trace_attrs (image )
242
+ super ().__init__ (image , trace_pos , trace )
243
+
203
244
# handle multiple image types and mask uncaught invalid values
204
245
if isinstance (self .image , NDData ):
205
246
img = np .ma .masked_invalid (np .ma .masked_array (self .image .data ,
@@ -223,7 +264,7 @@ def __post_init__(self):
223
264
224
265
if not isinstance (self .bins , int ):
225
266
warnings .warn ('TRACE: Converting bins to int' )
226
- self . bins = int (self .bins )
267
+ object . __setattr__ ( self , ' bins' , int (self .bins ) )
227
268
228
269
if self .bins < 4 :
229
270
raise ValueError ('bins must be >= 4' )
@@ -240,7 +281,7 @@ def __post_init__(self):
240
281
"length of the image's spatial direction" )
241
282
elif self .window is not None and not isinstance (self .window , int ):
242
283
warnings .warn ('TRACE: Converting window to int' )
243
- self . window = int (self .window )
284
+ object . __setattr__ ( self , ' window' , int (self .window ) )
244
285
245
286
# set max peak location by user choice or wavelength with max avg flux
246
287
ztot = img .sum (axis = self ._disp_axis ) / img .shape [self ._disp_axis ]
@@ -343,4 +384,4 @@ def __post_init__(self):
343
384
warnings .warn ("TRACE ERROR: No valid points found in trace" )
344
385
trace_y = np .tile (np .nan , len (x_bins ))
345
386
346
- self . trace = np .ma .masked_invalid (trace_y )
387
+ object . __setattr__ ( self , '_trace' , np .ma .masked_invalid (trace_y ) )
0 commit comments