Skip to content

Commit 907457e

Browse files
committed
Implementing BaseTemplate to manage Trace attr mutability
1 parent 67b6b8b commit 907457e

File tree

3 files changed

+112
-54
lines changed

3 files changed

+112
-54
lines changed

specreduce/background.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from astropy import units as u
99

1010
from specreduce.extract import _ap_weight_image, _to_spectrum1d_pixels
11-
from specreduce.tracing import Trace, FlatTrace
11+
from specreduce.tracing import FlatTrace, BaseTrace
1212

1313
__all__ = ['Background']
1414

@@ -77,7 +77,7 @@ def __post_init__(self):
7777
cross-dispersion axis
7878
"""
7979
def _to_trace(trace):
80-
if not isinstance(trace, Trace):
80+
if not isinstance(trace, BaseTrace):
8181
trace = FlatTrace(self.image, trace)
8282

8383
# TODO: this check can be removed if/when implemented as a check in FlatTrace
@@ -93,7 +93,7 @@ def _to_trace(trace):
9393
self.bkg_array = np.zeros(self.image.shape[self.disp_axis])
9494
return
9595

96-
if isinstance(self.traces, Trace):
96+
if isinstance(self.traces, BaseTrace):
9797
self.traces = [self.traces]
9898

9999
bkg_wimage = np.zeros_like(self.image, dtype=np.float64)

specreduce/extract.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from astropy.nddata import NDData
1111

1212
from specreduce.core import SpecreduceOperation
13-
from specreduce.tracing import Trace, FlatTrace
13+
from specreduce.tracing import FlatTrace, BaseTrace
1414
from specutils import Spectrum1D
1515

1616
__all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract']
@@ -88,7 +88,7 @@ def _ap_weight_image(trace, width, disp_axis, crossdisp_axis, image_shape):
8888
8989
Parameters
9090
----------
91-
trace : `~specreduce.tracing.Trace`, required
91+
trace : `~specreduce.tracing.BaseTrace`, required
9292
trace object
9393
width : float, required
9494
width of extraction aperture in pixels
@@ -139,7 +139,7 @@ class BoxcarExtract(SpecreduceOperation):
139139
----------
140140
image : nddata-compatible image
141141
image with 2-D spectral image data
142-
trace_object : Trace
142+
trace_object : BaseTrace
143143
trace object
144144
width : float
145145
width of extraction aperture in pixels
@@ -154,7 +154,7 @@ class BoxcarExtract(SpecreduceOperation):
154154
The extracted 1d spectrum expressed in DN and pixel units
155155
"""
156156
image: NDData
157-
trace_object: Trace
157+
trace_object: BaseTrace
158158
width: float = 5
159159
disp_axis: int = 1
160160
crossdisp_axis: int = 0
@@ -173,7 +173,7 @@ def __call__(self, image=None, trace_object=None, width=None,
173173
----------
174174
image : nddata-compatible image
175175
image with 2-D spectral image data
176-
trace_object : Trace
176+
trace_object : BaseTrace
177177
trace object
178178
width : float
179179
width of extraction aperture in pixels [default: 5]
@@ -230,7 +230,7 @@ class HorneExtract(SpecreduceOperation):
230230
NDData object must specify uncertainty and a mask. An array
231231
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
232232
233-
trace_object : `~specreduce.tracing.Trace`, required
233+
trace_object : `~specreduce.tracing.BaseTrace`, required
234234
The associated 1D trace object created for the 2D image.
235235
236236
disp_axis : int, optional
@@ -264,7 +264,7 @@ class HorneExtract(SpecreduceOperation):
264264
265265
"""
266266
image: NDData
267-
trace_object: Trace
267+
trace_object: BaseTrace
268268
bkgrd_prof: Model = field(default=models.Polynomial1D(2))
269269
variance: np.ndarray = field(default=None)
270270
mask: np.ndarray = field(default=None)
@@ -293,7 +293,7 @@ def __call__(self, image=None, trace_object=None,
293293
NDData object must specify uncertainty and a mask. An array
294294
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
295295
296-
trace_object : `~specreduce.tracing.Trace`, required
296+
trace_object : `~specreduce.tracing.BaseTrace`, required
297297
The associated 1D trace object created for the 2D image.
298298
299299
disp_axis : int, optional

specreduce/tracing.py

Lines changed: 101 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Licensed under a 3-clause BSD style license - see LICENSE.rst
22

33
from copy import deepcopy
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
import warnings
66

77
from astropy.modeling import fitting, models
@@ -10,29 +10,22 @@
1010
from scipy.interpolate import UnivariateSpline
1111
import numpy as np
1212

13-
__all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
13+
__all__ = ['BaseTrace', 'Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
1414

1515

16-
@dataclass
17-
class Trace:
16+
@dataclass(frozen=True)
17+
class BaseTrace:
1818
"""
19-
Basic tracing class that by default traces the middle of the image.
20-
21-
Parameters
22-
----------
23-
image : `~astropy.nddata.CCDData`
24-
Image to be traced
25-
26-
Properties
27-
----------
28-
shape : tuple
29-
Shape of the array describing the trace
19+
A dataclass common to all Trace objects.
3020
"""
3121
image: CCDData
22+
_trace_pos: (float, np.ndarray) = field(repr=False)
23+
_trace: np.ndarray = field(repr=False)
3224

3325
def __post_init__(self):
34-
self.trace_pos = self.image.shape[0] / 2
35-
self.trace = np.ones_like(self.image[0]) * self.trace_pos
26+
# this class only exists to catch __post_init__ calls in its
27+
# subclasses, so that super().__post_init__ calls work correctly.
28+
pass
3629

3730
def __getitem__(self, i):
3831
return self.trace[i]
@@ -59,7 +52,7 @@ def _bound_trace(self):
5952
Mask trace positions that are outside the upper/lower bounds of the image.
6053
"""
6154
ny = self.image.shape[0]
62-
self.trace = np.ma.masked_outside(self.trace, 0, ny-1)
55+
object.__setattr__(self, '_trace', np.ma.masked_outside(self._trace, 0, ny - 1))
6356

6457
def __add__(self, delta):
6558
"""
@@ -77,9 +70,60 @@ def __sub__(self, delta):
7770
"""
7871
return self.__add__(-delta)
7972

73+
def shift(self, delta):
74+
"""
75+
Shift the trace by delta pixels perpendicular to the axis being traced
76+
77+
Parameters
78+
----------
79+
delta : float
80+
Shift to be applied to the trace
81+
"""
82+
# act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
83+
object.__setattr__(self, '_trace', np.asarray(self._trace.data) + delta)
84+
object.__setattr__(self, '_trace_pos', self._trace_pos + delta)
85+
self._bound_trace()
86+
87+
@property
88+
def shape(self):
89+
return self._trace.shape
90+
91+
@property
92+
def trace(self):
93+
return self._trace
94+
95+
@property
96+
def trace_pos(self):
97+
return self._trace_pos
98+
99+
@staticmethod
100+
def _default_trace_attrs(image):
101+
"""
102+
Compute a default trace position and trace array using only
103+
the image dimensions.
104+
"""
105+
trace_pos = image.shape[0] / 2
106+
trace = np.ones_like(image[0]) * trace_pos
107+
return trace_pos, trace
108+
109+
110+
@dataclass(init=False, frozen=True)
111+
class Trace(BaseTrace):
112+
"""
113+
Basic tracing class that by default traces the middle of the image.
114+
115+
Parameters
116+
----------
117+
image : `~astropy.nddata.CCDData`
118+
Image to be traced
119+
"""
120+
def __init__(self, image):
121+
trace_pos, trace = self._default_trace_attrs(image)
122+
super().__init__(image, trace_pos, trace)
123+
80124

81-
@dataclass
82-
class FlatTrace(Trace):
125+
@dataclass(init=False, frozen=True)
126+
class FlatTrace(BaseTrace):
83127
"""
84128
Trace that is constant along the axis being traced
85129
@@ -92,10 +136,11 @@ class FlatTrace(Trace):
92136
trace_pos : float
93137
Position of the trace
94138
"""
95-
trace_pos: float
96139

97-
def __post_init__(self):
98-
self.set_position(self.trace_pos)
140+
def __init__(self, image, trace_pos):
141+
_, trace = self._default_trace_attrs(image)
142+
super().__init__(image, trace_pos, trace)
143+
self.set_position(trace_pos)
99144

100145
def set_position(self, trace_pos):
101146
"""
@@ -106,13 +151,13 @@ def set_position(self, trace_pos):
106151
trace_pos : float
107152
Position of the trace
108153
"""
109-
self.trace_pos = trace_pos
110-
self.trace = np.ones_like(self.image[0]) * self.trace_pos
154+
object.__setattr__(self, '_trace_pos', trace_pos)
155+
object.__setattr__(self, '_trace', np.ones_like(self.image[0]) * trace_pos)
111156
self._bound_trace()
112157

113158

114-
@dataclass
115-
class ArrayTrace(Trace):
159+
@dataclass(init=False, frozen=True)
160+
class ArrayTrace(BaseTrace):
116161
"""
117162
Define a trace given an array of trace positions
118163
@@ -121,25 +166,27 @@ class ArrayTrace(Trace):
121166
trace : `numpy.ndarray`
122167
Array containing trace positions
123168
"""
124-
trace: np.ndarray
169+
def __init__(self, image, trace):
170+
trace_pos, _ = self._default_trace_attrs(image)
171+
super().__init__(image, trace_pos, trace)
125172

126-
def __post_init__(self):
127173
nx = self.image.shape[1]
128-
nt = len(self.trace)
174+
nt = len(trace)
129175
if nt != nx:
130176
if nt > nx:
131177
# truncate trace to fit image
132-
self.trace = self.trace[0:nx]
178+
trace = trace[0:nx]
133179
else:
134180
# assume trace starts at beginning of image and pad out trace to fit.
135181
# padding will be the last value of the trace, but will be masked out.
136-
padding = np.ma.MaskedArray(np.ones(nx - nt) * self.trace[-1], mask=True)
137-
self.trace = np.ma.hstack([self.trace, padding])
182+
padding = np.ma.MaskedArray(np.ones(nx - nt) * trace[-1], mask=True)
183+
trace = np.ma.hstack([trace, padding])
184+
object.__setattr__(self, '_trace', trace)
138185
self._bound_trace()
139186

140187

141-
@dataclass
142-
class KosmosTrace(Trace):
188+
@dataclass(init=False, frozen=True)
189+
class KosmosTrace(BaseTrace):
143190
"""
144191
Trace the spectrum aperture in an image.
145192
@@ -192,14 +239,25 @@ class KosmosTrace(Trace):
192239
4) add other interpolation modes besides spline, maybe via
193240
specutils.manipulation methods?
194241
"""
195-
bins: int = 20
196-
guess: float = None
197-
window: int = None
198-
peak_method: str = 'gaussian'
242+
bins: int
243+
guess: float
244+
window: int
245+
peak_method: str
199246
_crossdisp_axis = 0
200247
_disp_axis = 1
201248

202-
def __post_init__(self):
249+
def _process_init_kwargs(self, **kwargs):
250+
for attr, value in kwargs.items():
251+
object.__setattr__(self, attr, value)
252+
253+
def __init__(self, image, bins=20, guess=None, window=None, peak_method='gaussian'):
254+
# This method will assign the user supplied value (or default) to the attrs:
255+
self._process_init_kwargs(
256+
bins=bins, guess=guess, window=window, peak_method=peak_method
257+
)
258+
trace_pos, trace = self._default_trace_attrs(image)
259+
super().__init__(image, trace_pos, trace)
260+
203261
# handle multiple image types and mask uncaught invalid values
204262
if isinstance(self.image, NDData):
205263
img = np.ma.masked_invalid(np.ma.masked_array(self.image.data,
@@ -223,7 +281,7 @@ def __post_init__(self):
223281

224282
if not isinstance(self.bins, int):
225283
warnings.warn('TRACE: Converting bins to int')
226-
self.bins = int(self.bins)
284+
object.__setattr__(self, 'bins', int(self.bins))
227285

228286
if self.bins < 4:
229287
raise ValueError('bins must be >= 4')
@@ -240,7 +298,7 @@ def __post_init__(self):
240298
"length of the image's spatial direction")
241299
elif self.window is not None and not isinstance(self.window, int):
242300
warnings.warn('TRACE: Converting window to int')
243-
self.window = int(self.window)
301+
object.__setattr__(self, 'window', int(self.window))
244302

245303
# set max peak location by user choice or wavelength with max avg flux
246304
ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis]
@@ -343,4 +401,4 @@ def __post_init__(self):
343401
warnings.warn("TRACE ERROR: No valid points found in trace")
344402
trace_y = np.tile(np.nan, len(x_bins))
345403

346-
self.trace = np.ma.masked_invalid(trace_y)
404+
object.__setattr__(self, '_trace', np.ma.masked_invalid(trace_y))

0 commit comments

Comments
 (0)