Skip to content

Commit 9676006

Browse files
committed
feat: enhance SignalPulse with caching and streaming capabilities
1 parent 592518f commit 9676006

File tree

1 file changed

+173
-15
lines changed

1 file changed

+173
-15
lines changed

src/algorithms/digi/PulseGeneration.cc

Lines changed: 173 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <functional>
22+
#include <limits>
2223
#include <stdexcept>
2324
#include <string>
2425
#include <unordered_map>
@@ -31,11 +32,141 @@ namespace eicrecon {
3132
class SignalPulse {
3233

3334
public:
35+
// State for streaming fast-path iteration (when cache resolution == timestep)
36+
struct StreamState {
37+
int base_index;
38+
double frac; // constant interpolation fraction across iterations
39+
double v_floor_n; // normalized value at base_index
40+
double v_ceil_n; // normalized value at base_index+1
41+
};
42+
3443
virtual ~SignalPulse() = default; // Virtual destructor
3544

36-
virtual double operator()(double time, double charge) = 0;
45+
// Public capability query: caching is only valid for pulses linear in charge
46+
bool supportsCaching() const { return isLinearInCharge(); }
47+
48+
// Main interface - with caching and interpolation
49+
double operator()(double time, double charge) {
50+
// If caching is disabled, use direct evaluation
51+
if (m_cache_resolution <= 0.0) {
52+
return evaluate(time, charge);
53+
}
54+
55+
// Use caching with linear interpolation
56+
// Normalize time to cache resolution
57+
double normalized_time = time / m_cache_resolution;
58+
int index = static_cast<int>(std::floor(normalized_time));
59+
60+
// Get or compute cached values at floor and ceil indices
61+
double value_floor = getCachedValue(index, charge);
62+
double value_ceil = getCachedValue(index + 1, charge);
63+
64+
// Linear interpolation
65+
double fraction = normalized_time - index;
66+
return value_floor + fraction * (value_ceil - value_floor);
67+
}
3768

3869
virtual double getMaximumTime() const = 0;
70+
71+
// Enable caching with specified resolution (typically the timestep)
72+
void enableCache(double resolution) {
73+
// Validate linearity first to avoid allocating cache when unsupported
74+
validateLinearity();
75+
76+
m_cache_resolution = resolution;
77+
// Pre-size array-backed cache and reset filled flags
78+
const std::size_t size = 2 * MAX_CACHE_SPAN_BINS + 3; // allow indices in [-MAX, MAX+2]
79+
m_cache_values.assign(size, std::numeric_limits<double>::quiet_NaN());
80+
m_cache_size = size;
81+
m_cache_offset = static_cast<int>(MAX_CACHE_SPAN_BINS + 1);
82+
}
83+
84+
// Expert fast-path: fetch normalized cached sample at integer time index.
85+
// Returns value for charge=1; callers can scale by actual charge.
86+
// Falls back to direct evaluate if index is outside the preallocated span.
87+
double sampleNormalizedAtIndex(int time_index) const {
88+
const int vec_idx = time_index + m_cache_offset;
89+
if (vec_idx < 0 || static_cast<std::size_t>(vec_idx) >= m_cache_size) {
90+
return const_cast<SignalPulse*>(this)->evaluate(time_index * m_cache_resolution, 1.0);
91+
}
92+
double& slot = m_cache_values[static_cast<std::size_t>(vec_idx)];
93+
if (std::isnan(slot)) {
94+
const double time = time_index * m_cache_resolution;
95+
slot = const_cast<SignalPulse*>(this)->evaluate(time, 1.0);
96+
}
97+
return slot;
98+
}
99+
100+
// Prepare streaming state for a pulse starting at signal_time with hit at time
101+
// Assumes cache resolution equals the timestep for O(1) index streaming.
102+
StreamState prepareStreaming(double signal_time, double time, double timestep) const {
103+
const double nt0 = (signal_time - time) / timestep;
104+
const int base_index = static_cast<int>(std::floor(nt0));
105+
const double frac = nt0 - base_index; // constant across iterations
106+
const double v_floor_n = sampleNormalizedAtIndex(base_index);
107+
const double v_ceil_n = sampleNormalizedAtIndex(base_index + 1);
108+
return {base_index, frac, v_floor_n, v_ceil_n};
109+
}
110+
111+
protected:
112+
// Derived classes implement the actual pulse evaluation
113+
// IMPORTANT: evaluate() MUST be linear in charge for caching to work correctly!
114+
// That is: evaluate(t, a*q) must equal a * evaluate(t, q) for all t, q, a
115+
// The cache stores normalized values (charge=1) and scales by actual charge.
116+
virtual double evaluate(double time, double charge) = 0;
117+
118+
// Override this in derived classes if the pulse is NOT linear in charge
119+
// Default assumes linearity (which is true for LandauPulse and most physical pulses)
120+
virtual bool isLinearInCharge() const { return true; }
121+
122+
private:
123+
double m_cache_resolution = 0.0; // 0 means caching disabled
124+
// Set a maximum cache span of +/- MAX_CACHE_SPAN_BINS time indices, centered on 0.
125+
// This eliminates hash lookups by using O(1) array indexing.
126+
static constexpr std::size_t MAX_CACHE_SPAN_BINS = 10000;
127+
mutable std::vector<double> m_cache_values; // normalized values for charge=1
128+
std::size_t m_cache_size = 0; // cached to avoid size() calls in hot path
129+
int m_cache_offset = 0; // vector index = time_index + m_cache_offset
130+
131+
// Validate that the pulse function is linear in charge
132+
void validateLinearity() const {
133+
if (!isLinearInCharge()) {
134+
throw std::runtime_error(
135+
"SignalPulse caching was requested, but this pulse reports isLinearInCharge()==false. "
136+
"Caching only supports pulses that are linear in charge. Avoid calling enableCache() for "
137+
"non-linear pulses, or override isLinearInCharge() to return true if appropriate.");
138+
}
139+
// Runtime verification
140+
const double t_test = 1.0;
141+
const double q1 = 1.0, q2 = 2.0;
142+
const double v1 = const_cast<SignalPulse*>(this)->evaluate(t_test, q1);
143+
const double v2 = const_cast<SignalPulse*>(this)->evaluate(t_test, q2);
144+
const double ratio = std::abs(v2 / v1);
145+
const double expected_ratio = q2 / q1;
146+
if (std::abs(ratio - expected_ratio) > 0.01 * expected_ratio) {
147+
throw std::runtime_error("SignalPulse caching linearity check FAILED: the pulse reported "
148+
"linear (isLinearInCharge()==true) "
149+
"but evaluate(t, a*q) != a * evaluate(t, q). Fix evaluate() to be "
150+
"linear in charge or disable caching.");
151+
}
152+
}
153+
154+
// Get cached value or compute and cache it
155+
double getCachedValue(int time_index, double charge) const {
156+
// Fast O(1) array-backed cache lookup. We store values for charge=1 and scale by 'charge'.
157+
const int vec_idx = time_index + m_cache_offset;
158+
if (vec_idx < 0 || static_cast<std::size_t>(vec_idx) >= m_cache_size) {
159+
// Outside the preallocated cache span: fall back to direct evaluation (no caching).
160+
return const_cast<SignalPulse*>(this)->evaluate(time_index * m_cache_resolution, charge);
161+
}
162+
double& slot = m_cache_values[static_cast<std::size_t>(vec_idx)];
163+
if (std::isnan(slot)) {
164+
// Compute and cache normalized value (charge=1)
165+
const double time = time_index * m_cache_resolution;
166+
slot = const_cast<SignalPulse*>(this)->evaluate(time, 1.0);
167+
}
168+
return slot * charge;
169+
}
39170
};
40171

41172
// ----------------------------------------------------------------------------
@@ -58,13 +189,15 @@ class LandauPulse : public SignalPulse {
58189
}
59190
};
60191

61-
double operator()(double time, double charge) override {
192+
double getMaximumTime() const override { return m_hit_sigma_offset * m_sigma_analog; }
193+
194+
protected:
195+
// LandauPulse is linear in charge: evaluate(t, a*q) = a*q*gain*Landau(...) = a*evaluate(t, q)
196+
double evaluate(double time, double charge) override {
62197
return charge * m_gain *
63198
TMath::Landau(time, m_hit_sigma_offset * m_sigma_analog, m_sigma_analog, kTRUE);
64199
}
65200

66-
double getMaximumTime() const override { return m_hit_sigma_offset * m_sigma_analog; }
67-
68201
private:
69202
double m_gain = 1.0;
70203
double m_sigma_analog = 1.0;
@@ -79,7 +212,7 @@ class EvaluatorPulse : public SignalPulse {
79212
std::vector<std::string> keys = {"time", "charge"};
80213
for (std::size_t i = 0; i < params.size(); i++) {
81214
std::string p = "param" + std::to_string(i);
82-
//Check the expression contains the parameter
215+
// Check the expression contains the parameter
83216
if (expression.find(p) == std::string::npos) {
84217
throw std::runtime_error("Parameter " + p + " not found in expression");
85218
}
@@ -99,14 +232,17 @@ class EvaluatorPulse : public SignalPulse {
99232
m_evaluator = serviceSvc.service<EvaluatorSvc>("EvaluatorSvc")->_compile(expression, keys);
100233
};
101234

102-
double operator()(double time, double charge) override {
235+
double getMaximumTime() const override { return 0; }
236+
237+
protected:
238+
// EvaluatorPulse: Linearity depends on the expression provided by user
239+
// Most physical pulse shapes multiply by charge, making them linear
240+
double evaluate(double time, double charge) override {
103241
param_map["time"] = time;
104242
param_map["charge"] = charge;
105243
return m_evaluator(param_map);
106244
}
107245

108-
double getMaximumTime() const override { return 0; }
109-
110246
private:
111247
std::unordered_map<std::string, double> param_map;
112248
std::function<double(const std::unordered_map<std::string, double>&)> m_evaluator;
@@ -161,11 +297,17 @@ void HitAdapter<edm4hep::SimCalorimeterHit>::addRelations(MutablePulseType& puls
161297
#endif
162298

163299
template <typename HitT> void PulseGeneration<HitT>::init() {
164-
m_pulse =
300+
// Factory returns unique_ptr; construct shared_ptr to keep ownership simple across TUs
301+
auto uptr =
165302
PulseShapeFactory::createPulseShape(m_cfg.pulse_shape_function, m_cfg.pulse_shape_params);
166-
m_min_sampling_time = m_cfg.min_sampling_time;
303+
m_pulse = std::shared_ptr<SignalPulse>(std::move(uptr));
167304

168-
m_min_sampling_time = std::max<double>(m_pulse->getMaximumTime(), m_min_sampling_time);
305+
// Enable caching with the timestep as the resolution only if supported (linear in charge)
306+
if (m_pulse->supportsCaching()) {
307+
m_pulse->enableCache(m_cfg.timestep);
308+
}
309+
310+
m_min_sampling_time = std::max<double>(m_pulse->getMaximumTime(), m_cfg.min_sampling_time);
169311
}
170312

171313
template <typename HitT>
@@ -176,30 +318,46 @@ void PulseGeneration<HitT>::process(
176318
auto [rawPulses] = output;
177319

178320
for (const auto& hit : *simhits) {
179-
const auto [time, charge] = HitAdapter<HitT>::getPulseSources(hit);
321+
// Avoid repeated shared_ptr access overhead in hot path; keep read-only API here
322+
const SignalPulse* const pulsePtr = m_pulse.get();
323+
const auto [time, charge] = HitAdapter<HitT>::getPulseSources(hit);
180324
// Calculate nearest timestep to the hit time rounded down (assume clocks aligned with time 0)
181325
double signal_time = m_cfg.timestep * std::floor(time / m_cfg.timestep);
182326

183327
bool passed_threshold = false;
184328
std::uint32_t skip_bins = 0;
185329
float integral = 0;
330+
186331
std::vector<float> pulse;
332+
pulse.reserve(m_cfg.max_time_bins);
333+
334+
// Streaming fast-path: cache resolution equals timestep, indices advance by +1 per bin
335+
// Encapsulated precomputation for base_index, frac, and initial cached values
336+
auto state = pulsePtr->prepareStreaming(signal_time, time, m_cfg.timestep);
187337

188338
for (std::uint32_t i = 0; i < m_cfg.max_time_bins; i++) {
189-
double t = signal_time + i * m_cfg.timestep - time;
190-
auto signal = (*m_pulse)(t, charge);
191-
if (std::abs(signal) < m_cfg.ignore_thres) {
339+
// One cached fetch per iteration after the first
340+
double signal = charge * (state.v_floor_n + state.frac * (state.v_ceil_n - state.v_floor_n));
341+
if (std::fabs(signal) < m_cfg.ignore_thres) {
192342
if (!passed_threshold) {
193343
skip_bins = i;
194344
continue;
195345
}
346+
// t = (i * timestep) - (nt0 * timestep) = (i - nt0) * timestep
347+
const double nt0 = (signal_time - time) / m_cfg.timestep;
348+
const double t = (static_cast<double>(i) - nt0) * m_cfg.timestep;
196349
if (t > m_min_sampling_time) {
197350
break;
198351
}
199352
}
200353
passed_threshold = true;
201354
pulse.push_back(signal);
202355
integral += signal;
356+
357+
// Advance streaming cache state for next iteration
358+
state.base_index += 1;
359+
state.v_floor_n = state.v_ceil_n;
360+
state.v_ceil_n = pulsePtr->sampleNormalizedAtIndex(state.base_index + 1);
203361
}
204362

205363
if (!passed_threshold) {

0 commit comments

Comments
 (0)