Skip to content

Commit 7a56c88

Browse files
code cleanup
new mcmc move strategy model.flux unit fix. RS: hardcut Gamma
1 parent 15e7b73 commit 7a56c88

File tree

12 files changed

+122
-90
lines changed

12 files changed

+122
-90
lines changed

VegasAfterglow/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def validate_parameters(self, param_defs: Sequence[ParamDef]) -> None:
159159

160160
# Check reverse shock parameters
161161
if self.config.rvs_shock:
162-
rvs_required = {"p_r", "eps_e_r", "eps_B_r"}
162+
rvs_required = {"p_r", "eps_e_r", "eps_B_r", "tau"}
163163
missing_rvs = rvs_required - param_names
164164
if missing_rvs:
165165
missing_params.extend(

VegasAfterglow/sampler.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import emcee
99
import numpy as np
10-
from emcee.moves import DEMove, DESnookerMove
1110

1211
from .types import FitResult, ModelParams, ObsData, Setups, VegasMC
1312

@@ -90,22 +89,23 @@ def run(
9089
"""
9190
Run coarse MCMC + optional stretch-move refinement at higher resolution.
9291
"""
93-
# 1) configure coarse grid
92+
9493
cfg = self._make_cfg(base_cfg, *resolution)
9594

96-
# 2) prepare log-prob
9795
log_prob = _log_prob(
9896
data, cfg, self.to_params, self.pl, self.pu, self.model_cls
9997
)
10098

101-
# 3) initialize walker positions
10299
spread = 0.05 * (self.pu - self.pl)
103100
pos = self.init + spread * np.random.randn(self.nwalkers, self.ndim)
104101
pos = np.clip(pos, self.pl + 1e-8, self.pu - 1e-8)
105102

106-
# 4) default moves
107103
if moves is None:
108-
moves = [(DEMove(), 0.8), (DESnookerMove(), 0.2)]
104+
moves = [
105+
(emcee.moves.StretchMove(a=2.0), 0.6),
106+
(emcee.moves.DEMove(gamma=None), 0.3),
107+
(emcee.moves.DESnookerMove(gamma=None), 0.1),
108+
]
109109

110110
logger.info(
111111
"Running coarse MCMC at resolution %s for %d steps",
@@ -118,18 +118,14 @@ def run(
118118
)
119119
sampler.run_mcmc(pos, total_steps, progress=True)
120120

121-
# 5) extract & filter
122121
burn = int(burn_frac * total_steps)
123122
chain = sampler.get_chain(discard=burn, thin=thin)
124123
logp = sampler.get_log_prob(discard=burn, thin=thin)
125124
chain, logp, _ = self._filter_bad_walkers(chain, logp)
126125

127-
# 6) flatten & find top k fits
128126
flat_chain = chain.reshape(-1, self.ndim)
129127
flat_logp = logp.reshape(-1)
130128

131-
# Find top k unique parameter combinations
132-
# Sort by log prob (descending)
133129
sorted_idx = np.argsort(flat_logp)[::-1]
134130

135131
# Round parameters to avoid floating point precision issues
@@ -151,7 +147,6 @@ def run(
151147
top_k_log_probs[-1],
152148
)
153149

154-
# 8) return FitResult
155150
return FitResult(
156151
samples=chain,
157152
log_probs=logp,

docs/source/examples.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Suppose you want to calculate the flux at specific time-frequency pairs (t_i, nu
9797
# Define observing frequencies (must be the same length as times)
9898
bands = np.logspace(9,17, 200)
9999
100-
results = model.flux_density(times, bands) #times array could be random order
100+
results = model.flux_density(times, bands) #times array must be in ascending order
101101
102102
# the returned results is a FluxDict object with arrays of the same shape as the input times and bands.
103103
@@ -222,7 +222,7 @@ User-Defined Medium
222222
# Define a custom density profile function
223223
def density(phi, theta, r):# r in cm, phi and theta in radians
224224
return mp # n_ism = 1 cm^-3
225-
#return whatever density profile (cm^-3) you want as a function of phi, theta, and r
225+
#return whatever density profile (g*cm^-3) you want as a function of phi, theta, and r
226226
227227
# Create a user-defined medium
228228
medium = Medium(rho=density)

docs/source/mcmc_fitting.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ For large datasets or densely sampled observations, using all available data poi
9393
flux_err_dense = 0.1 * flux_dense
9494
9595
# Subsample using logarithmic screening
96-
# This selects ~50-100 representative points across 5 decades in time
96+
# This selects ~5*5=25 representative points across 5 decades in time
9797
indices = ObsData.logscale_screen(t_dense, num_order=5)
9898
9999
# Add only the selected subset

include/jet.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ namespace math {
342342
if (theta <= theta_c) {
343343
return 0.;
344344
} else {
345-
return height / (1 + fast_pow(theta / theta_c, k));
345+
return height * fast_pow(theta / theta_c, -k);
346346
}
347347
};
348348
}
@@ -352,7 +352,7 @@ namespace math {
352352
if (theta <= theta_c) {
353353
return 1.;
354354
} else {
355-
return height / (1 + fast_pow(theta / theta_c, k)) + 1;
355+
return height * fast_pow(theta / theta_c, -k) + 1;
356356
}
357357
};
358358
}
@@ -362,7 +362,7 @@ namespace math {
362362
if (theta <= theta_c) {
363363
return height_c;
364364
} else {
365-
return height_w / (1 + fast_pow(theta / theta_c, k));
365+
return height_w * fast_pow(theta / theta_c, -k);
366366
}
367367
};
368368
}
@@ -372,7 +372,7 @@ namespace math {
372372
if (theta <= theta_c) {
373373
return height_c + 1;
374374
} else {
375-
return height_w / (1 + fast_pow(theta / theta_c, k)) + 1;
375+
return height_w * fast_pow(theta / theta_c, -k) + 1;
376376
}
377377
};
378378
}

include/observer.h

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ class Observer {
7070
* @return Array of specific flux values at each observation time
7171
* <!-- ************************************************************************************** -->
7272
*/
73-
template <typename... PhotonGrid>
74-
Array specific_flux(Array const& t_obs, Real nu_obs, PhotonGrid&... photons);
73+
template <typename PhotonGrid>
74+
Array specific_flux(Array const& t_obs, Real nu_obs, PhotonGrid& photons);
7575

7676
/**
7777
* <!-- ************************************************************************************** -->
@@ -86,8 +86,8 @@ class Observer {
8686
* relativistic beaming and cosmological effects.
8787
* <!-- ************************************************************************************** -->
8888
*/
89-
template <typename... PhotonGrid>
90-
MeshGrid specific_flux(Array const& t_obs, Array const& nu_obs, PhotonGrid&... photons);
89+
template <typename PhotonGrid>
90+
MeshGrid specific_flux(Array const& t_obs, Array const& nu_obs, PhotonGrid& photons);
9191

9292
/**
9393
* <!-- ************************************************************************************** -->
@@ -99,8 +99,8 @@ class Observer {
9999
* @return Array of specific flux values at each observation time for a single observed frequency
100100
* <!-- ************************************************************************************** -->
101101
*/
102-
template <typename... PhotonGrid>
103-
Array specific_flux_series(Array const& t_obs, Array const& nu_obs, PhotonGrid&... photons);
102+
template <typename PhotonGrid>
103+
Array specific_flux_series(Array const& t_obs, Array const& nu_obs, PhotonGrid& photons);
104104

105105
/**
106106
* <!-- ************************************************************************************** -->
@@ -115,8 +115,8 @@ class Observer {
115115
* @return Array of integrated flux values at each observation time
116116
* <!-- ************************************************************************************** -->
117117
*/
118-
template <typename... PhotonGrid>
119-
Array flux(Array const& t_obs, Array const& band_freq, PhotonGrid&... photons);
118+
template <typename PhotonGrid>
119+
Array flux(Array const& t_obs, Array const& band_freq, PhotonGrid& photons);
120120

121121
/**
122122
* <!-- ************************************************************************************** -->
@@ -130,8 +130,8 @@ class Observer {
130130
* @return 2D grid of spectra (frequency × time)
131131
* <!-- ************************************************************************************** -->
132132
*/
133-
template <typename... PhotonGrid>
134-
MeshGrid spectra(Array const& freqs, Array const& t_obs, PhotonGrid&... photons);
133+
template <typename PhotonGrid>
134+
MeshGrid spectra(Array const& freqs, Array const& t_obs, PhotonGrid& photons);
135135

136136
/**
137137
* <!-- ************************************************************************************** -->
@@ -145,8 +145,8 @@ class Observer {
145145
* @return Array containing the spectrum at the given time
146146
* <!-- ************************************************************************************** -->
147147
*/
148-
template <typename... PhotonGrid>
149-
Array spectrum(Array const& freqs, Real t_obs, PhotonGrid&... photons);
148+
template <typename PhotonGrid>
149+
Array spectrum(Array const& freqs, Real t_obs, PhotonGrid& photons);
150150

151151
/**
152152
* <!-- ************************************************************************************** -->
@@ -253,9 +253,8 @@ class Observer {
253253
* @return True if both lower and upper boundaries are valid for interpolation, false otherwise
254254
* <!-- ************************************************************************************** -->
255255
*/
256-
template <typename... PhotonGrid>
257-
bool set_boundaries(InterpState& state, size_t i, size_t j, size_t k, Real log2_nu,
258-
PhotonGrid&... photons) noexcept;
256+
template <typename PhotonGrid>
257+
bool set_boundaries(InterpState& state, size_t i, size_t j, size_t k, Real log2_nu, PhotonGrid& photons) noexcept;
259258
};
260259

261260
//========================================================================================================
@@ -277,9 +276,9 @@ inline void iterate_to(Real value, Array const& arr, size_t& it) noexcept {
277276
}
278277
}
279278

280-
template <typename... PhotonGrid>
279+
template <typename PhotonGrid>
281280
bool Observer::set_boundaries(InterpState& state, size_t i, size_t j, size_t k, Real lg2_nu_obs,
282-
PhotonGrid&... photons) noexcept {
281+
PhotonGrid& photons) noexcept {
283282
if (state.last_hi == k + 1 && state.last_lg2_nu == lg2_nu_obs) {
284283
if (!std::isfinite(state.slope)) {
285284
return false;
@@ -298,12 +297,12 @@ bool Observer::set_boundaries(InterpState& state, size_t i, size_t j, size_t k,
298297
state.lg2_I_nu_lo = state.lg2_I_nu_hi;
299298
} else {
300299
Real lg2_nu_lo = lg2_one_plus_z + lg2_nu_obs - lg2_doppler(i, j, k);
301-
state.lg2_I_nu_lo = 3 * lg2_doppler(i, j, k) + (photons(eff_i, j, k).compute_log2_I_nu(lg2_nu_lo) + ...) +
302-
lg2_emission_area(i, j, k);
300+
state.lg2_I_nu_lo =
301+
3 * lg2_doppler(i, j, k) + photons(eff_i, j, k).compute_log2_I_nu(lg2_nu_lo) + lg2_emission_area(i, j, k);
303302
}
304303

305304
Real lg2_nu_hi = lg2_one_plus_z + lg2_nu_obs - lg2_doppler(i, j, k + 1);
306-
state.lg2_I_nu_hi = 3 * lg2_doppler(i, j, k + 1) + (photons(eff_i, j, k + 1).compute_log2_I_nu(lg2_nu_hi) + ...) +
305+
state.lg2_I_nu_hi = 3 * lg2_doppler(i, j, k + 1) + photons(eff_i, j, k + 1).compute_log2_I_nu(lg2_nu_hi) +
307306
lg2_emission_area(i, j, k + 1);
308307

309308
state.slope = (state.lg2_I_nu_hi - state.lg2_I_nu_lo) / lg2_t_ratio;
@@ -317,8 +316,8 @@ bool Observer::set_boundaries(InterpState& state, size_t i, size_t j, size_t k,
317316
return true;
318317
}
319318

320-
template <typename... PhotonGrid>
321-
MeshGrid Observer::specific_flux(Array const& t_obs, Array const& nu_obs, PhotonGrid&... photons) {
319+
template <typename PhotonGrid>
320+
MeshGrid Observer::specific_flux(Array const& t_obs, Array const& nu_obs, PhotonGrid& photons) {
322321
size_t t_obs_len = t_obs.size();
323322
size_t nu_len = nu_obs.size();
324323

@@ -341,7 +340,7 @@ MeshGrid Observer::specific_flux(Array const& t_obs, Array const& nu_obs, Photon
341340
if (lg2_t_hi < lg2_t_obs(t_idx)) {
342341
continue;
343342
} else {
344-
if (set_boundaries(state, i, j, k, lg2_nu[l], photons...)) {
343+
if (set_boundaries(state, i, j, k, lg2_nu[l], photons)) {
345344
for (; t_idx < t_obs_len && lg2_t_obs(t_idx) <= lg2_t_hi; t_idx++) {
346345
F_nu(l, t_idx) += interpolate(state, i, j, k, lg2_t_obs(t_idx));
347346
}
@@ -361,8 +360,8 @@ MeshGrid Observer::specific_flux(Array const& t_obs, Array const& nu_obs, Photon
361360
return F_nu;
362361
}
363362

364-
template <typename... PhotonGrid>
365-
Array Observer::specific_flux_series(Array const& t_obs, Array const& nu_obs, PhotonGrid&... photons) {
363+
template <typename PhotonGrid>
364+
Array Observer::specific_flux_series(Array const& t_obs, Array const& nu_obs, PhotonGrid& photons) {
366365
size_t t_obs_len = t_obs.size();
367366
size_t nu_len = nu_obs.size();
368367

@@ -388,7 +387,7 @@ Array Observer::specific_flux_series(Array const& t_obs, Array const& nu_obs, Ph
388387
k++;
389388
continue;
390389
} else {
391-
if (set_boundaries(state, i, j, k, lg2_nu[t_idx], photons...)) {
390+
if (set_boundaries(state, i, j, k, lg2_nu(t_idx), photons)) {
392391
F_nu(t_idx) += interpolate(state, i, j, k, lg2_t_obs(t_idx));
393392
}
394393
t_idx++;
@@ -404,25 +403,25 @@ Array Observer::specific_flux_series(Array const& t_obs, Array const& nu_obs, Ph
404403
return F_nu;
405404
}
406405

407-
template <typename... PhotonGrid>
408-
Array Observer::specific_flux(Array const& t_obs, Real nu_obs, PhotonGrid&... photons) {
409-
return xt::view(specific_flux(t_obs, Array({nu_obs}), photons...), 0);
406+
template <typename PhotonGrid>
407+
Array Observer::specific_flux(Array const& t_obs, Real nu_obs, PhotonGrid& photons) {
408+
return xt::view(specific_flux(t_obs, Array({nu_obs}), photons), 0);
410409
}
411410

412-
template <typename... PhotonGrid>
413-
Array Observer::spectrum(Array const& freqs, Real t_obs, PhotonGrid&... photons) {
414-
return xt::view(spectra(freqs, Array({t_obs}), photons...), 0);
411+
template <typename PhotonGrid>
412+
Array Observer::spectrum(Array const& freqs, Real t_obs, PhotonGrid& photons) {
413+
return xt::view(spectra(freqs, Array({t_obs}), photons), 0);
415414
}
416415

417-
template <typename... PhotonGrid>
418-
MeshGrid Observer::spectra(Array const& freqs, Array const& t_obs, PhotonGrid&... photons) {
419-
return xt::transpose(specific_flux(t_obs, freqs, photons...));
416+
template <typename PhotonGrid>
417+
MeshGrid Observer::spectra(Array const& freqs, Array const& t_obs, PhotonGrid& photons) {
418+
return xt::transpose(specific_flux(t_obs, freqs, photons));
420419
}
421420

422-
template <typename... PhotonGrid>
423-
Array Observer::flux(Array const& t_obs, Array const& band_freq, PhotonGrid&... photons) {
421+
template <typename PhotonGrid>
422+
Array Observer::flux(Array const& t_obs, Array const& band_freq, PhotonGrid& photons) {
424423
Array nu_obs = boundary_to_center_log(band_freq);
425-
MeshGrid F_nu = specific_flux(t_obs, nu_obs, photons...);
424+
MeshGrid F_nu = specific_flux(t_obs, nu_obs, photons);
426425
Array flux({t_obs.size()}, 0);
427426
for (size_t i = 0; i < nu_obs.size(); ++i) {
428427
Real dnu = band_freq(i + 1) - band_freq(i);

pybind/mcmc.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ double FluxData::estimate_chi2() const {
8383
double chi_square = 0;
8484
for (size_t i = 0; i < t.size(); ++i) {
8585
double error = Fv_err(i);
86-
if (error == 0)
87-
continue;
86+
//if (error == 0)
87+
// continue;
8888
double diff = Fv_obs(i) - Fv_model(i);
8989
chi_square += weights(i) * (diff * diff) / (error * error);
9090
}
@@ -95,14 +95,15 @@ double MultiBandData::estimate_chi2() const {
9595
double chi_square = 0;
9696
for (size_t i = 0; i < times.size(); ++i) {
9797
double error = errors(i);
98-
if (error == 0)
99-
continue;
98+
//if (error == 0)
99+
// continue;
100100
double diff = fluxes(i) - model_fluxes(i);
101101
chi_square += weights(i) * (diff * diff) / (error * error);
102102
}
103103
for (auto& d : flux_data) {
104104
chi_square += d.estimate_chi2();
105105
}
106+
106107
return chi_square;
107108
}
108109

0 commit comments

Comments
 (0)