Skip to content

Commit 62cbf11

Browse files
committed
Implementing BaseTemplate to manage Trace attr mutability
1 parent 67b6b8b commit 62cbf11

File tree

3 files changed

+112
-71
lines changed

3 files changed

+112
-71
lines changed

specreduce/background.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from astropy import units as u
99

1010
from specreduce.extract import _ap_weight_image, _to_spectrum1d_pixels
11-
from specreduce.tracing import Trace, FlatTrace
11+
from specreduce.tracing import FlatTrace, BaseTrace
1212

1313
__all__ = ['Background']
1414

@@ -77,7 +77,7 @@ def __post_init__(self):
7777
cross-dispersion axis
7878
"""
7979
def _to_trace(trace):
80-
if not isinstance(trace, Trace):
80+
if not isinstance(trace, BaseTrace):
8181
trace = FlatTrace(self.image, trace)
8282

8383
# TODO: this check can be removed if/when implemented as a check in FlatTrace
@@ -93,7 +93,7 @@ def _to_trace(trace):
9393
self.bkg_array = np.zeros(self.image.shape[self.disp_axis])
9494
return
9595

96-
if isinstance(self.traces, Trace):
96+
if isinstance(self.traces, BaseTrace):
9797
self.traces = [self.traces]
9898

9999
bkg_wimage = np.zeros_like(self.image, dtype=np.float64)

specreduce/extract.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from astropy.nddata import NDData
1111

1212
from specreduce.core import SpecreduceOperation
13-
from specreduce.tracing import Trace, FlatTrace
13+
from specreduce.tracing import FlatTrace, BaseTrace
1414
from specutils import Spectrum1D
1515

1616
__all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract']
@@ -88,7 +88,7 @@ def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):
8888
8989
Parameters
9090
----------
91-
trace : `~specreduce.tracing.Trace`, required
91+
trace : `~specreduce.tracing.BaseTrace`, required
9292
trace object
9393
width : float, required
9494
width of extraction aperture in pixels
@@ -139,7 +139,7 @@ class BoxcarExtract(SpecreduceOperation):
139139
----------
140140
image : nddata-compatible image
141141
image with 2-D spectral image data
142-
trace_object : Trace
142+
trace_object : BaseTrace
143143
trace object
144144
width : float
145145
width of extraction aperture in pixels
@@ -154,7 +154,7 @@ class BoxcarExtract(SpecreduceOperation):
154154
The extracted 1d spectrum expressed in DN and pixel units
155155
"""
156156
image: NDData
157-
trace_object: Trace
157+
trace_object: BaseTrace
158158
width: float = 5
159159
disp_axis: int = 1
160160
crossdisp_axis: int = 0
@@ -173,7 +173,7 @@ def __call__(self, image=None, trace_object=None, width=None,
173173
----------
174174
image : nddata-compatible image
175175
image with 2-D spectral image data
176-
trace_object : Trace
176+
trace_object : BaseTrace
177177
trace object
178178
width : float
179179
width of extraction aperture in pixels [default: 5]
@@ -230,7 +230,7 @@ class HorneExtract(SpecreduceOperation):
230230
NDData object must specify uncertainty and a mask. An array
231231
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
232232
233-
trace_object : `~specreduce.tracing.Trace`, required
233+
trace_object : `~specreduce.tracing.BaseTrace`, required
234234
The associated 1D trace object created for the 2D image.
235235
236236
disp_axis : int, optional
@@ -264,7 +264,7 @@ class HorneExtract(SpecreduceOperation):
264264
265265
"""
266266
image: NDData
267-
trace_object: Trace
267+
trace_object: BaseTrace
268268
bkgrd_prof: Model = field(default=models.Polynomial1D(2))
269269
variance: np.ndarray = field(default=None)
270270
mask: np.ndarray = field(default=None)
@@ -293,7 +293,7 @@ def __call__(self, image=None, trace_object=None,
293293
NDData object must specify uncertainty and a mask. An array
294294
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
295295
296-
trace_object : `~specreduce.tracing.Trace`, required
296+
trace_object : `~specreduce.tracing.BaseTrace`, required
297297
The associated 1D trace object created for the 2D image.
298298
299299
disp_axis : int, optional

specreduce/tracing.py

Lines changed: 101 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Licensed under a 3-clause BSD style license - see LICENSE.rst
22

33
from copy import deepcopy
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
import warnings
66

77
from astropy.modeling import fitting, models
@@ -10,56 +10,32 @@
1010
from scipy.interpolate import UnivariateSpline
1111
import numpy as np
1212

13-
__all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
13+
__all__ = ['BaseTrace', 'Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
1414

1515

16-
@dataclass
17-
class Trace:
16+
@dataclass(frozen=True)
17+
class BaseTrace:
1818
"""
19-
Basic tracing class that by default traces the middle of the image.
20-
21-
Parameters
22-
----------
23-
image : `~astropy.nddata.CCDData`
24-
Image to be traced
25-
26-
Properties
27-
----------
28-
shape : tuple
29-
Shape of the array describing the trace
19+
A dataclass common to all Trace objects.
3020
"""
3121
image: CCDData
22+
_trace_pos: (float, np.ndarray) = field(repr=False)
23+
_trace: np.ndarray = field(repr=False)
3224

3325
def __post_init__(self):
34-
self.trace_pos = self.image.shape[0] / 2
35-
self.trace = np.ones_like(self.image[0]) * self.trace_pos
26+
# this class only exists to catch __post_init__ calls in its
27+
# subclasses, so that super().__post_init__ calls work correctly.
28+
pass
3629

3730
def __getitem__(self, i):
3831
return self.trace[i]
3932

40-
@property
41-
def shape(self):
42-
return self.trace.shape
43-
44-
def shift(self, delta):
45-
"""
46-
Shift the trace by delta pixels perpendicular to the axis being traced
47-
48-
Parameters
49-
----------
50-
delta : float
51-
Shift to be applied to the trace
52-
"""
53-
# act on self.trace.data to ignore the mask and then re-mask when calling _bound_trace
54-
self.trace = np.asarray(self.trace.data) + delta
55-
self._bound_trace()
56-
5733
def _bound_trace(self):
5834
"""
5935
Mask trace positions that are outside the upper/lower bounds of the image.
6036
"""
6137
ny = self.image.shape[0]
62-
self.trace = np.ma.masked_outside(self.trace, 0, ny-1)
38+
object.__setattr__(self, '_trace', np.ma.masked_outside(self._trace, 0, ny - 1))
6339

6440
def __add__(self, delta):
6541
"""
@@ -77,9 +53,60 @@ def __sub__(self, delta):
7753
"""
7854
return self.__add__(-delta)
7955

56+
def shift(self, delta):
57+
"""
58+
Shift the trace by delta pixels perpendicular to the axis being traced
59+
60+
Parameters
61+
----------
62+
delta : float
63+
Shift to be applied to the trace
64+
"""
65+
# act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
66+
object.__setattr__(self, '_trace', np.asarray(self._trace.data) + delta)
67+
object.__setattr__(self, '_trace_pos', self._trace_pos + delta)
68+
self._bound_trace()
69+
70+
@property
71+
def shape(self):
72+
return self._trace.shape
73+
74+
@property
75+
def trace(self):
76+
return self._trace
8077

81-
@dataclass
82-
class FlatTrace(Trace):
78+
@property
79+
def trace_pos(self):
80+
return self._trace_pos
81+
82+
@staticmethod
83+
def _default_trace_attrs(image):
84+
"""
85+
Compute a default trace position and trace array using only
86+
the image dimensions.
87+
"""
88+
trace_pos = image.shape[0] / 2
89+
trace = np.ones_like(image[0]) * trace_pos
90+
return trace_pos, trace
91+
92+
93+
@dataclass(init=False, frozen=True)
94+
class Trace(BaseTrace):
95+
"""
96+
Basic tracing class that by default traces the middle of the image.
97+
98+
Parameters
99+
----------
100+
image : `~astropy.nddata.CCDData`
101+
Image to be traced
102+
"""
103+
def __init__(self, image):
104+
trace_pos, trace = self._default_trace_attrs(image)
105+
super().__init__(image, trace_pos, trace)
106+
107+
108+
@dataclass(init=False, frozen=True)
109+
class FlatTrace(BaseTrace):
83110
"""
84111
Trace that is constant along the axis being traced
85112
@@ -92,10 +119,11 @@ class FlatTrace(Trace):
92119
trace_pos : float
93120
Position of the trace
94121
"""
95-
trace_pos: float
96122

97-
def __post_init__(self):
98-
self.set_position(self.trace_pos)
123+
def __init__(self, image, trace_pos):
124+
_, trace = self._default_trace_attrs(image)
125+
super().__init__(image, trace_pos, trace)
126+
self.set_position(trace_pos)
99127

100128
def set_position(self, trace_pos):
101129
"""
@@ -106,13 +134,13 @@ def set_position(self, trace_pos):
106134
trace_pos : float
107135
Position of the trace
108136
"""
109-
self.trace_pos = trace_pos
110-
self.trace = np.ones_like(self.image[0]) * self.trace_pos
137+
object.__setattr__(self, '_trace_pos', trace_pos)
138+
object.__setattr__(self, '_trace', np.ones_like(self.image[0]) * trace_pos)
111139
self._bound_trace()
112140

113141

114-
@dataclass
115-
class ArrayTrace(Trace):
142+
@dataclass(init=False, frozen=True)
143+
class ArrayTrace(BaseTrace):
116144
"""
117145
Define a trace given an array of trace positions
118146
@@ -121,25 +149,27 @@ class ArrayTrace(Trace):
121149
trace : `numpy.ndarray`
122150
Array containing trace positions
123151
"""
124-
trace: np.ndarray
152+
def __init__(self, image, trace):
153+
trace_pos, _ = self._default_trace_attrs(image)
154+
super().__init__(image, trace_pos, trace)
125155

126-
def __post_init__(self):
127156
nx = self.image.shape[1]
128-
nt = len(self.trace)
157+
nt = len(trace)
129158
if nt != nx:
130159
if nt > nx:
131160
# truncate trace to fit image
132-
self.trace = self.trace[0:nx]
161+
trace = trace[0:nx]
133162
else:
134163
# assume trace starts at beginning of image and pad out trace to fit.
135164
# padding will be the last value of the trace, but will be masked out.
136-
padding = np.ma.MaskedArray(np.ones(nx - nt) * self.trace[-1], mask=True)
137-
self.trace = np.ma.hstack([self.trace, padding])
165+
padding = np.ma.MaskedArray(np.ones(nx - nt) * trace[-1], mask=True)
166+
trace = np.ma.hstack([trace, padding])
167+
object.__setattr__(self, '_trace', trace)
138168
self._bound_trace()
139169

140170

141-
@dataclass
142-
class KosmosTrace(Trace):
171+
@dataclass(init=False, frozen=True)
172+
class KosmosTrace(BaseTrace):
143173
"""
144174
Trace the spectrum aperture in an image.
145175
@@ -192,14 +222,25 @@ class KosmosTrace(Trace):
192222
4) add other interpolation modes besides spline, maybe via
193223
specutils.manipulation methods?
194224
"""
195-
bins: int = 20
196-
guess: float = None
197-
window: int = None
198-
peak_method: str = 'gaussian'
225+
bins: int
226+
guess: float
227+
window: int
228+
peak_method: str
199229
_crossdisp_axis = 0
200230
_disp_axis = 1
201231

202-
def __post_init__(self):
232+
def _process_init_kwargs(self, **kwargs):
233+
for attr, value in kwargs.items():
234+
object.__setattr__(self, attr, value)
235+
236+
def __init__(self, image, bins=20, guess=None, window=None, peak_method='gaussian'):
237+
# This method will assign the user supplied value (or default) to the attrs:
238+
self._process_init_kwargs(
239+
bins=bins, guess=guess, window=window, peak_method=peak_method
240+
)
241+
trace_pos, trace = self._default_trace_attrs(image)
242+
super().__init__(image, trace_pos, trace)
243+
203244
# handle multiple image types and mask uncaught invalid values
204245
if isinstance(self.image, NDData):
205246
img = np.ma.masked_invalid(np.ma.masked_array(self.image.data,
@@ -223,7 +264,7 @@ def __post_init__(self):
223264

224265
if not isinstance(self.bins, int):
225266
warnings.warn('TRACE: Converting bins to int')
226-
self.bins = int(self.bins)
267+
object.__setattr__(self, 'bins', int(self.bins))
227268

228269
if self.bins < 4:
229270
raise ValueError('bins must be >= 4')
@@ -240,7 +281,7 @@ def __post_init__(self):
240281
"length of the image's spatial direction")
241282
elif self.window is not None and not isinstance(self.window, int):
242283
warnings.warn('TRACE: Converting window to int')
243-
self.window = int(self.window)
284+
object.__setattr__(self, 'window', int(self.window))
244285

245286
# set max peak location by user choice or wavelength with max avg flux
246287
ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis]
@@ -343,4 +384,4 @@ def __post_init__(self):
343384
warnings.warn("TRACE ERROR: No valid points found in trace")
344385
trace_y = np.tile(np.nan, len(x_bins))
345386

346-
self.trace = np.ma.masked_invalid(trace_y)
387+
object.__setattr__(self, '_trace', np.ma.masked_invalid(trace_y))

0 commit comments

Comments
 (0)