Skip to content

Commit 799ff49

Browse files
authored
Modify masking and NaN handling in HorneExtract (#163)
* Update NaN handling and masking in HorneExtract * Removed masking so NaNs propagate into 1D spectra
1 parent 7043f78 commit 799ff49

File tree

2 files changed

+52
-32
lines changed

2 files changed

+52
-32
lines changed

specreduce/extract.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def _parse_image(self, image,
359359
elif mask is not None:
360360
pass
361361
else:
362-
mask = ~np.isfinite(img)
362+
# if user provides no mask at all, don't mask anywhere
363+
mask = np.zeros_like(img)
363364

364365
if img.shape != mask.shape:
365366
raise ValueError('image and mask shapes must match.')
@@ -484,13 +485,15 @@ def __call__(self, image=None, trace_object=None,
484485
# parse image and replace optional arguments with updated values
485486
self.image = self._parse_image(image, variance, mask, unit, disp_axis)
486487
variance = self.image.uncertainty.array
488+
mask = self.image.mask
487489
unit = self.image.unit
488490

489-
# mask any previously uncaught invalid values
490-
or_mask = np.logical_or(mask,
491+
img = np.ma.masked_array(self.image.data, mask=mask)
492+
493+
# create separate mask including any previously uncaught non-finite
494+
# values for purposes of calculating fit
495+
or_mask = np.logical_or(img.mask,
491496
~np.isfinite(self.image.data))
492-
img = np.ma.masked_array(self.image.data, or_mask)
493-
mask = img.mask
494497

495498
# If the trace is not flat, shift the rows in each column
496499
# so the image is aligned along the trace:
@@ -510,26 +513,43 @@ def __call__(self, image=None, trace_object=None,
510513
)
511514

512515
# co-add signal in each image column
513-
ncols = img.shape[crossdisp_axis]
514-
xd_pixels = np.arange(ncols) # y plot dir / x spec dir
515-
coadd = img.sum(axis=disp_axis) / ncols
516+
nrows = img.shape[crossdisp_axis]
517+
xd_pixels = np.arange(nrows) # counted in y dir on plot (or x in spec)
518+
519+
row_mask = np.logical_or.reduce(or_mask, axis=disp_axis)
520+
coadd = np.ma.masked_array(np.sum(img, axis=disp_axis) / nrows,
521+
mask=row_mask)
522+
# (mask rows with non-finite sums for fit to work later on)
516523

517-
# fit source profile, using Gaussian model as a template
524+
# fit source profile to brightest row, using Gaussian model as template
518525
# NOTE: could add argument for users to provide their own model
519526
gauss_prof = models.Gaussian1D(amplitude=coadd.max(),
520527
mean=coadd.argmax(), stddev=2)
521528

522-
# Fit extraction kernel to column with combined gaussian/bkgrd model
529+
# Fit extraction kernel to column's finite values with combined model
530+
# (must exclude masked indices manually; LevMarLSQFitter does not)
523531
ext_prof = gauss_prof + bkgrd_prof
524532
fitter = fitting.LevMarLSQFitter()
525-
fit_ext_kernel = fitter(ext_prof, xd_pixels, coadd)
533+
fit_ext_kernel = fitter(ext_prof,
534+
xd_pixels[~row_mask], coadd[~row_mask])
526535

527-
# use compound model to fit a kernel to each image column
536+
# use compound model to fit a kernel to each fully finite image column
528537
# NOTE: infers Gaussian1D source profile; needs generalization for others
538+
col_mask = np.logical_or.reduce(or_mask, axis=crossdisp_axis)
539+
nonf_col = [np.nan] * img.shape[crossdisp_axis]
540+
529541
kernel_vals = []
530542
norms = []
531543
for col_pix in range(img.shape[disp_axis]):
532-
# set gaussian model's mean as column's corresponding trace value
544+
# for now, skip columns with any non-finite values
545+
# NOTE: fit and other kernel operations should support masking again
546+
# once a fix is in for renormalizing columns with non-finite values
547+
if col_mask[col_pix]:
548+
kernel_vals.append(nonf_col)
549+
norms.append(np.nan)
550+
continue
551+
552+
# else, set compound model's mean to column's matching trace value
533553
fit_ext_kernel.mean_0 = mean_init_guess[col_pix]
534554

535555
# NOTE: support for variable FWHMs forthcoming and would be here
@@ -543,15 +563,15 @@ def __call__(self, image=None, trace_object=None,
543563
* fit_ext_kernel.stddev_0 * np.sqrt(2*np.pi))
544564

545565
# transform fit-specific information
546-
kernel_vals = np.array(kernel_vals).T
566+
kernel_vals = np.vstack(kernel_vals).T
547567
norms = np.array(norms)
548568

549-
# calculate kernel normalization, masking NaNs
550-
g_x = np.ma.sum(kernel_vals**2 / variance, axis=crossdisp_axis)
569+
# calculate kernel normalization
570+
g_x = np.sum(kernel_vals**2 / variance, axis=crossdisp_axis)
551571

552572
# sum by column weights
553-
weighted_img = np.ma.divide(img * kernel_vals, variance)
554-
result = np.ma.sum(weighted_img, axis=crossdisp_axis) / g_x
573+
weighted_img = np.divide(img * kernel_vals, variance)
574+
result = np.sum(weighted_img, axis=crossdisp_axis) / g_x
555575

556576
# multiply kernel normalization into the extracted signal
557577
extraction = result * norms

specreduce/tests/test_extract.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import astropy.units as u
55
from astropy.nddata import CCDData, VarianceUncertainty, UnknownUncertainty
66
from astropy.tests.helper import assert_quantity_allclose
7-
from astropy.utils.exceptions import AstropyUserWarning
87

98
from specreduce.extract import (
109
BoxcarExtract, HorneExtract, OptimalExtract, _align_along_trace
@@ -128,6 +127,8 @@ def test_horne_image_validation():
128127
== np.arange(image.shape[extract.disp_axis]) * u.pix)
129128

130129

130+
# ignore Astropy warning for extractions that aren't best fit with a Gaussian:
131+
@pytest.mark.filterwarnings("ignore:The fit may be unsuccessful")
131132
def test_horne_variance_errors():
132133
trace = FlatTrace(image, 3.0)
133134

@@ -155,6 +156,7 @@ def test_horne_variance_errors():
155156
mask=image.mask, unit=u.Jy)
156157

157158

159+
@pytest.mark.filterwarnings("ignore:The fit may be unsuccessful")
158160
def test_horne_non_flat_trace():
159161
# create a synthetic "2D spectrum" and its non-flat trace
160162
n_rows, n_cols = (10, 50)
@@ -181,19 +183,17 @@ def test_horne_non_flat_trace():
181183
# ensure that mask is correctly unrolled back to its original alignment:
182184
np.testing.assert_allclose(unrolled, original)
183185

184-
# These synthetic extractions don't fit well with a Gaussian, so will pass warning:
185-
with pytest.warns(AstropyUserWarning, match="The fit may be unsuccessful"):
186-
# Extract the spectrum from the non-flat image+trace
187-
extract_non_flat = HorneExtract(
188-
rolled, ArrayTrace(rolled, exact_trace),
189-
variance=err, mask=mask, unit=u.Jy
190-
)()
191-
192-
# Also extract the spectrum from the image after alignment with a flat trace
193-
extract_flat = HorneExtract(
194-
unrolled, FlatTrace(unrolled, n_rows // 2),
195-
variance=err, mask=mask, unit=u.Jy
196-
)()
186+
# Extract the spectrum from the non-flat image+trace
187+
extract_non_flat = HorneExtract(
188+
rolled, ArrayTrace(rolled, exact_trace),
189+
variance=err, mask=mask, unit=u.Jy
190+
)()
191+
192+
# Also extract the spectrum from the image after alignment with a flat trace
193+
extract_flat = HorneExtract(
194+
unrolled, FlatTrace(unrolled, n_rows // 2),
195+
variance=err, mask=mask, unit=u.Jy
196+
)()
197197

198198
# ensure both extractions are equivalent:
199199
assert_quantity_allclose(extract_non_flat.flux, extract_flat.flux)

0 commit comments

Comments
 (0)