Skip to content

Commit c680ab6

Browse files
committed
add support for general 2D prior mask via mask_function argument
1 parent 032f06b commit c680ab6

File tree

7 files changed

+353
-159
lines changed

7 files changed

+353
-159
lines changed

docs/plot_gallery.html

Lines changed: 138 additions & 84 deletions
Large diffs are not rendered by default.

docs/plot_gallery.ipynb

Lines changed: 147 additions & 47 deletions
Large diffs are not rendered by default.

getdist/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__author__ = 'Antony Lewis'
2-
__version__ = "1.5.4"
2+
__version__ = "1.5.5"
33
__url__ = "https://getdist.readthedocs.io"
44

55
import os

getdist/analysis_defaults.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ smooth_scale_1D =-1
4040

4141
#0 is basic normalization correction
4242
#1 is linear boundary kernel (should get gradient correct)
43-
#2 is a higher order kernel, that also affects estimates way from the boundary (1D only)
43+
#2 is a higher order kernel, that also affects estimates away from the boundary (1D only)
4444
boundary_correction_order=1
4545

4646
#Correct for (over-smoothing) biases using multiplicative bias correction

getdist/densities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,17 +253,19 @@ class Density2D(GridDensity):
253253
You can call it like a :class:`~scipy:scipy.interpolate.RectBivariateSpline` object to get interpolated values.
254254
"""
255255

256-
def __init__(self, x, y, P=None, view_ranges=None):
256+
def __init__(self, x, y, P=None, view_ranges=None, mask=None):
257257
"""
258258
:param x: array of x values
259259
:param y: array of y values
260260
:param P: 2D array of density values at x, y
261261
:param view_ranges: optional ranges for viewing density
262+
:param mask: optional 2D boolean array for non-trivial mask
262263
"""
263264
self.x = x
264265
self.y = y
265266
self.axes = [y, x]
266267
self.view_ranges = view_ranges
268+
self.mask = mask
267269
self.spacing = (self.x[1] - self.x[0]) * (self.y[1] - self.y[0])
268270
self.setP(P)
269271

getdist/mcsamples.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,7 +1601,7 @@ def get1DDensityGridData(self, j, paramConfid=None, meanlikes=False, **kwargs):
16011601

16021602
return density1D
16031603

1604-
def _setEdgeMask2D(self, parx, pary, prior_mask, winw, alledge=False):
1604+
def _setEdgeMask2D(self, parx, pary, prior_mask, winw):
16051605
if parx.has_limits_bot:
16061606
prior_mask[:, winw] /= 2
16071607
prior_mask[:, :winw] = 0
@@ -1614,11 +1614,12 @@ def _setEdgeMask2D(self, parx, pary, prior_mask, winw, alledge=False):
16141614
if pary.has_limits_top:
16151615
prior_mask[-(winw + 1), :] /= 2
16161616
prior_mask[-winw:, :] = 0
1617-
if alledge:
1618-
prior_mask[:, :winw] = 0
1619-
prior_mask[:, -winw:] = 0
1620-
prior_mask[:winw:] = 0
1621-
prior_mask[-winw:, :] = 0
1617+
1618+
def _setAllEdgeMask2D(self, prior_mask, winw):
1619+
prior_mask[:, :winw] = 0
1620+
prior_mask[:, -winw:] = 0
1621+
prior_mask[:winw:] = 0
1622+
prior_mask[-winw:, :] = 0
16221623

16231624
def _getScaleForParam(self, par):
16241625
# Also ensures that the 1D limits are initialized
@@ -1655,7 +1656,8 @@ def get2DDensity(self, x, y, normalized=False, **kwargs):
16551656
return density
16561657

16571658
# noinspection PyUnboundLocalVariable
1658-
def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False, meanlikes=False, **kwargs):
1659+
def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False, meanlikes=False,
1660+
mask_function: callable = None, **kwargs):
16591661
"""
16601662
Low-level function to get 2D plot marginalized density and optional additional plot data.
16611663
@@ -1665,6 +1667,9 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
16651667
:param get_density: only get the 2D marginalized density, don't calculate confidence level members
16661668
:param meanlikes: calculate mean likelihoods as well as marginalized density
16671669
(returned as array in density.likes)
1670+
:param mask_function: optional function, mask_function(minx, miny, stepx, stepy, mask),
1671+
which which sets mask to zero for values of parameters that are excluded by prior. Note this is not
1672+
needed for standard min, max bounds aligned with axes, as they are handled by default.
16681673
:param kwargs: optional settings to override instance settings of the same name (see `analysis_settings`):
16691674
16701675
- **fine_bins_2D**
@@ -1689,7 +1694,7 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
16891694
mult_bias_correction_order = kwargs.get('mult_bias_correction_order', self.mult_bias_correction_order)
16901695
smooth_scale_2D = float(kwargs.get('smooth_scale_2D', self.smooth_scale_2D))
16911696

1692-
has_prior = parx.has_limits or pary.has_limits
1697+
has_prior = parx.has_limits or pary.has_limits or mask_function
16931698

16941699
corr = self.getCorrelationMatrix()[j2][j]
16951700
actual_corr = corr
@@ -1761,7 +1766,7 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
17611766
logging.debug('time 2D binning and bandwidth: %s ; bins: %s', time.time() - start, fine_bins_2D)
17621767
start = time.time()
17631768
cache = {}
1764-
convolvesize = xsize + 2 * winw + Win.shape[0]
1769+
convolvesize = xsize + 2 * winw + Win.shape[0] # larger than needed for selecting fft pixel count
17651770
bins2D = convolve2D(histbins, Win, 'same', largest_size=convolvesize, cache=cache)
17661771

17671772
if meanlikes:
@@ -1779,15 +1784,24 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
17791784
else:
17801785
bin2Dlikes = None
17811786

1787+
if has_prior and boundary_correction_order >= 0 or mult_bias_correction_order or mask_function:
1788+
prior_mask = np.ones((ysize + 2 * winw, xsize + 2 * winw))
1789+
if mask_function:
1790+
mask_function(xbinmin - winw * finewidthx, ybinmin - winw * finewidthy, finewidthx, finewidthy,
1791+
prior_mask)
1792+
bool_mask = prior_mask[winw:-winw, winw:-winw] < 1e-8
1793+
17821794
if has_prior and boundary_correction_order >= 0:
17831795
# Correct for edge effects
1784-
prior_mask = np.ones((ysize + 2 * winw, xsize + 2 * winw))
17851796
self._setEdgeMask2D(parx, pary, prior_mask, winw)
17861797
a00 = convolve2D(prior_mask, Win, 'valid', largest_size=convolvesize, cache=cache)
17871798
ix = a00 * bins2D > np.max(bins2D) * 1e-8
17881799
a00 = a00[ix]
17891800
normed = bins2D[ix] / a00
1790-
if boundary_correction_order == 1:
1801+
if boundary_correction_order == 0:
1802+
# simple boundary correction by normalization
1803+
bins2D[ix] = normed
1804+
elif boundary_correction_order == 1:
17911805
# linear boundary correction
17921806
indexes = np.arange(-winw, winw + 1)
17931807
y = np.empty(Win.shape)
@@ -1811,26 +1825,28 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
18111825
Ay = a01 * a20 - a10 * a11
18121826
corrected = (bins2D[ix] * A + xP * Ax + yP * Ay) / denom
18131827
bins2D[ix] = normed * np.exp(np.minimum(corrected / normed, 4) - 1)
1814-
elif boundary_correction_order == 0:
1815-
# simple boundary correction by normalization
1816-
bins2D[ix] = normed
18171828
else:
18181829
raise SettingError('unknown boundary_correction_order (expected 0 or 1)')
18191830

18201831
if mult_bias_correction_order:
1821-
prior_mask = np.ones((ysize + 2 * winw, xsize + 2 * winw))
1822-
self._setEdgeMask2D(parx, pary, prior_mask, winw, alledge=True)
1832+
self._setAllEdgeMask2D(prior_mask, winw)
18231833
a00 = convolve2D(prior_mask, Win, 'valid', largest_size=convolvesize, cache=cache, cache_args=[2])
18241834
for _ in range(mult_bias_correction_order):
18251835
box = histbins.copy()
18261836
ix2 = bins2D > np.max(bins2D) * 1e-8
18271837
box[ix2] /= bins2D[ix2]
18281838
bins2D *= convolve2D(box, Win, 'same', largest_size=convolvesize, cache=cache, cache_args=[2])
1829-
bins2D /= a00
1839+
if mask_function:
1840+
bins2D[~bool_mask] /= a00[~bool_mask]
1841+
else:
1842+
bins2D /= a00
1843+
1844+
if mask_function:
1845+
bins2D[bool_mask] = 0
18301846

18311847
x = np.linspace(xbinmin, xbinmax, xsize)
18321848
y = np.linspace(ybinmin, ybinmax, ysize)
1833-
density = Density2D(x, y, bins2D,
1849+
density = Density2D(x, y, bins2D, mask=None if not mask_function else np.asarray(bool_mask),
18341850
view_ranges=[(parx.range_min, parx.range_max), (pary.range_min, pary.range_max)])
18351851
density.normalize('max', in_place=True)
18361852
if get_density:

getdist/plots.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,8 @@ def _is_color_like(self, color):
10241024
return False
10251025

10261026
def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, cols=None, contour_levels=None,
1027-
add_legend_proxy=True, param_pair=None, density=None, alpha=None, ax=None, **kwargs):
1027+
add_legend_proxy=True, param_pair=None, density=None, alpha=None, ax=None,
1028+
mask_function: callable = None, **kwargs):
10281029
"""
10291030
Low-level function to add 2D contours to plot for samples with given root name and parameters
10301031
@@ -1043,6 +1044,9 @@ def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, col
10431044
:param alpha: alpha for the contours added
10441045
:param ax: optional :class:`~matplotlib:matplotlib.axes.Axes` instance (or y,x subplot coordinate)
10451046
to add to (defaults to current plot or the first/main plot if none)
1047+
:param mask_function: optional function, mask_function(minx, miny, stepx, stepy, mask),
1048+
which which sets mask to zero for values of parameter name that are excluded by prior.
1049+
See the example in the plot gallery.
10461050
:param kwargs: optional keyword arguments:
10471051
10481052
- **filled**: True to make filled contours
@@ -1055,7 +1059,13 @@ def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, col
10551059
if density is None:
10561060
param1, param2 = self.get_param_array(root, param_pair or [param1, param2])
10571061
ax.getdist_params = (param1, param2)
1058-
if isinstance(root, MixtureND):
1062+
if mask_function is not None:
1063+
samples = self.samples_for_root(root)
1064+
density = samples.get2DDensityGridData(param1.name, param2.name,
1065+
mask_function=mask_function,
1066+
num_plot_contours=self.settings.num_plot_contours,
1067+
meanlikes=self.settings.shade_meanlikes)
1068+
elif isinstance(root, MixtureND):
10591069
density = root.marginalizedMixture(params=[param1, param2]).density2D()
10601070
else:
10611071
density = self.sample_analyser.get_density_grid(root, param1, param2,
@@ -1086,6 +1096,7 @@ def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, col
10861096
def clean_args(_args):
10871097
return {k: v for k, v in _args.items() if k not in ('color', 'ls', 'lw')}
10881098

1099+
z = density.P if density.mask is None else np.ma.masked_where(density.mask, density.P)
10891100
if kwargs.get('filled'):
10901101
if cols is None:
10911102
color = kwargs.get('color')
@@ -1098,13 +1109,13 @@ def clean_args(_args):
10981109
else:
10991110
cols = color
11001111
levels = sorted(np.append([density.P.max() + 1], contour_levels))
1101-
cs = ax.contourf(density.x, density.y, density.P, levels, colors=cols, alpha=alpha, **clean_args(kwargs))
1112+
cs = ax.contourf(density.x, density.y, z, levels, colors=cols, alpha=alpha, **clean_args(kwargs))
11021113

11031114
fc = tuple(cs.to_rgba(cs.cvalues[-1], cs.alpha))
11041115
if proxy_ix >= 0:
11051116
self.contours_added[proxy_ix] = (
11061117
matplotlib.patches.Rectangle((0, 0), 1, 1, fc=fc))
1107-
ax.contour(density.x, density.y, density.P, levels[:1], colors=(fc,),
1118+
ax.contour(density.x, density.y, z, levels[:1], colors=(fc,),
11081119
linewidths=self._scaled_linewidth(self.settings.linewidth_contour
11091120
if kwargs.get('lw') is None else kwargs['lw']),
11101121
linestyles=kwargs.get('ls'),
@@ -1116,7 +1127,7 @@ def clean_args(_args):
11161127
lws = args['lw'] # note linewidth_contour is only used for filled contours
11171128
kwargs = self._get_plot_args(plotno, **kwargs)
11181129
kwargs['alpha'] = alpha
1119-
cs = ax.contour(density.x, density.y, density.P, sorted(contour_levels), colors=cols, linestyles=linestyles,
1130+
cs = ax.contour(density.x, density.y, z, sorted(contour_levels), colors=cols, linestyles=linestyles,
11201131
linewidths=lws, **clean_args(kwargs))
11211132
dashes = args.get('dashes')
11221133
if dashes:
@@ -1658,14 +1669,15 @@ def plot_1d(self, roots, param, marker=None, marker_color=None, label_right=Fals
16581669
self.finish_plot()
16591670

16601671
def plot_2d(self, roots, param1=None, param2=None, param_pair=None, shaded=False,
1661-
add_legend_proxy=True, line_offset=0, proxy_root_exclude=(), ax=None, **kwargs):
1672+
add_legend_proxy=True, line_offset=0, proxy_root_exclude=(), ax=None,
1673+
mask_function: callable = None, **kwargs):
16621674
"""
16631675
Create a single 2D line, contour or filled plot.
16641676
16651677
:param roots: root name or :class:`~.mcsamples.MCSamples` instance (or list of any of either of these) for
16661678
the samples to plot
16671679
:param param1: x parameter name
1668-
:param param2: y parameter name
1680+
:param param2: y parameter name
16691681
:param param_pair: An [x,y] pair of params; can be set instead of param1 and param2
16701682
:param shaded: True or integer if plot should be a shaded density plot, where the integer specifies
16711683
the index of which contour is shaded (first samples shaded if True provided instead
@@ -1675,6 +1687,15 @@ def plot_2d(self, roots, param1=None, param2=None, param_pair=None, shaded=False
16751687
:param proxy_root_exclude: any root names not to include when adding to the legend proxy
16761688
:param ax: optional :class:`~matplotlib:matplotlib.axes.Axes` instance (or y,x subplot coordinate)
16771689
to add to (defaults to current plot or the first/main plot if none)
1690+
:param mask_function: Function that defines regions in the 2D parameter space to exclude from the plot.
1691+
Must have signature mask_function(minx, miny, stepx, stepy, mask), where:
1692+
- minx, miny: minimum values of x and y parameters
1693+
- stepx, stepy: step sizes in x and y directions
1694+
- mask: 2D boolean numpy array (modified in-place)
1695+
The function should set mask values to 0 where points should be excluded by the prior.
1696+
Useful for implementing non-rectangular prior boundaries not aligned with parameter axes,
1697+
- see the example in the plot gallery.
1698+
Note it should not include simple axis-aligned range priors that are accounted for automatically.
16781699
:param kwargs: additional optional arguments:
16791700
16801701
* **filled**: True for filled contours
@@ -1711,6 +1732,7 @@ def plot_2d(self, roots, param1=None, param2=None, param_pair=None, shaded=False
17111732
contour_args = self._make_contour_args(len(roots), **kwargs)
17121733
for i, root in enumerate(roots):
17131734
res = self.add_2d_contours(root, param_pair[0], param_pair[1], line_offset + i, of=len(roots), ax=ax,
1735+
mask_function=mask_function,
17141736
add_legend_proxy=add_legend_proxy and root not in proxy_root_exclude,
17151737
**contour_args[i])
17161738
xbounds, ybounds = self._update_limits(res, xbounds, ybounds)

0 commit comments

Comments
 (0)