Skip to content

Commit 0eb90b8

Browse files
author
Fengyu Zhang
committed
Improve comments and documentation for EnsembleProblem
1 parent cc59a15 commit 0eb90b8

File tree

1 file changed

+214
-10
lines changed

1 file changed

+214
-10
lines changed

src/ensemble/ensemble_problems.jl

Lines changed: 214 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,18 @@ EnsembleProblem(prob::AbstractSciMLProblem;
2828
`repeat` is the iteration of the repeat. At first, it is `1`, but if
2929
`rerun` was true this will be `2`, `3`, etc. counting the number of times
3030
problem `i` has been repeated.
31-
- `reduction`: This function determines how to reduce the data in each batch.
32-
Defaults to appending the `data` into `u`, initialised via `u_data`, from
33-
the batches. `I` is a range of indices giving the trajectories corresponding
34-
to the batches. The second part of the output determines whether the simulation
35-
has converged. If `true`, the simulation will exit early. By default, this is
36-
always `false`.
31+
- `reduction`: This function is used to aggregate the results in each simulation batch. By default, it appends the `data` from the batch to `u`, which is initialized via `u_data`. The `I` is a range of indices corresponding to the trajectories for the current batch.
32+
33+
### Arguments:
34+
- `u`: The solution from the current ensemble run. This is the accumulated data that gets updated in each batch.
35+
- `data`: The results from the current batch of simulations. This is typically some data (e.g., variable values, time steps) that is merged with `u`.
36+
- `I`: A range of indices corresponding to the simulations in the current batch. This provides the trajectory indices for the batch.
37+
38+
### Returns:
39+
- `(new_data, has_converged)`: A tuple where:
40+
- `new_data`: The updated accumulated data, typically the result of appending `data` to `u`.
41+
- `has_converged`: A boolean indicating whether the simulation has converged and should terminate early. If `true`, the simulation will stop early. If `false`, the simulation will continue. By default, this is `false`, meaning the simulation will not stop early.
42+
3743
- `u_init`: The initial form of the object that gets updated in-place inside the
3844
`reduction` function.
3945
- `safetycopy`: Determines whether a safety `deepcopy` is called on the `prob`
@@ -81,6 +87,61 @@ output_func(sol, i) = (sol[end, 2], false)
8187
Thus, the ensemble simulation would return as its data an array which is the
8288
end value of the 2nd dependent variable for each of the runs.
8389
"""
90+
91+
# Defines a structure to manage an ensemble (batch) of problems.
92+
# Each field controls how the ensemble behaves during simulation.
93+
94+
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
95+
prob::T # The original base problem to replicate or modify.
96+
prob_func::T2 # A function defining how to generate each subproblem (e.g., changing initial conditions).
97+
output_func::T3 # A function to post-process each individual simulation result.
98+
reduction::T4 # A function to combine results from all simulations.
99+
u_init::T5 # The initial container used to accumulate the results.
100+
safetycopy::Bool # Whether to copy the problem when creating subproblems (to avoid unintended modifications).
101+
end
102+
103+
# Returns the same problem without modification.
104+
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
105+
106+
# Returns the solution as-is, along with a flag (false) indicating no early termination.
107+
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
108+
109+
# Appends new data to the accumulated data, no early convergence.
110+
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
111+
112+
# Selects the i-th problem from a vector of problems.
113+
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
114+
115+
# Constructor: creates an EnsembleProblem when the input is a vector of problems (DEPRECATED).
116+
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
117+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
118+
Ensembles Simulations Interface page for more details", :EnsembleProblem)
119+
invoke(EnsembleProblem,
120+
Tuple{Any},
121+
prob;
122+
prob_func = DEFAULT_VECTOR_PROB_FUNC,
123+
kwargs...)
124+
end
125+
126+
# Main constructor: creates an EnsembleProblem with optional custom behavior.
127+
function EnsembleProblem(prob;
128+
prob_func = DEFAULT_PROB_FUNC,
129+
output_func = DEFAULT_OUTPUT_FUNC,
130+
reduction = DEFAULT_REDUCTIO"""
131+
$(TYPEDEF)
132+
133+
Defines a structure to manage an ensemble (batch) of problems.
134+
Each field controls how the ensemble behaves during simulation.
135+
136+
## Arguments:
137+
138+
- `prob`: The original base problem to replicate or modify.
139+
- `prob_func`: A function that defines how to generate each subproblem (e.g., changing initial conditions).
140+
- `output_func`: A function to post-process each individual simulation result.
141+
- `reduction`: A function to combine results from all simulations.
142+
- `u_init`: The initial container used to accumulate the results.
143+
- `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications).
144+
"""
84145
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
85146
prob::T
86147
prob_func::T2
@@ -90,19 +151,62 @@ struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
90151
safetycopy::Bool
91152
end
92153

154+
"""
155+
Returns the same problem without modification.
156+
157+
"""
93158
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
159+
160+
"""
161+
Returns the solution as-is, along with `false` indicating no rerun.
162+
163+
"""
94164
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
165+
166+
"""
167+
Appends new data to the accumulated data and returns `false` to indicate no early termination.
168+
169+
"""
95170
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
171+
172+
"""
173+
Selects the i-th problem from a vector of problems.
174+
175+
"""
96176
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
177+
178+
"""
179+
$(TYPEDEF)
180+
181+
Constructor that creates an EnsembleProblem when the input is a vector of problems.
182+
183+
!!! warning
184+
This constructor is deprecated. Use the standard ensemble syntax with `prob_func` instead.
185+
"""
97186
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
98-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
99-
Ensembles Simulations Interface page for more details", :EnsembleProblem)
187+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel Ensembles Simulations Interface page for more details", :EnsembleProblem)
100188
invoke(EnsembleProblem,
101189
Tuple{Any},
102190
prob;
103191
prob_func = DEFAULT_VECTOR_PROB_FUNC,
104192
kwargs...)
105193
end
194+
195+
"""
196+
$(TYPEDEF)
197+
198+
Main constructor for `EnsembleProblem`.
199+
200+
## Arguments:
201+
202+
- `prob`: The base problem.
203+
- `prob_func`: Function to modify the base problem per trajectory.
204+
- `output_func`: Function to extract output from a solution.
205+
- `reduction`: Function to aggregate results.
206+
- `u_init`: Initial value for aggregation.
207+
- `safetycopy`: Whether to deepcopy the problem before modifying.
208+
209+
"""
106210
function EnsembleProblem(prob;
107211
prob_func = DEFAULT_PROB_FUNC,
108212
output_func = DEFAULT_OUTPUT_FUNC,
@@ -116,6 +220,98 @@ function EnsembleProblem(prob;
116220
EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy)
117221
end
118222

223+
"""
224+
$(TYPEDEF)
225+
226+
Alternate constructor that uses only keyword arguments.
227+
228+
"""
229+
function EnsembleProblem(; prob,
230+
prob_func = DEFAULT_PROB_FUNC,
231+
output_func = DEFAULT_OUTPUT_FUNC,
232+
reduction = DEFAULT_REDUCTION,
233+
u_init = nothing, p = nothing,
234+
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
235+
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
236+
end
237+
238+
"""
239+
$(TYPEDEF)
240+
241+
Constructor for NonlinearProblem.
242+
243+
!!! warning
244+
This dispatch is deprecated. See the Parallel Ensembles Simulations Interface page.
245+
246+
"""
247+
function SciMLBase.EnsembleProblem(
248+
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
249+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel Ensembles Simulations Interface page for more details", :EnsembleProblem)
250+
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
251+
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
252+
end
253+
254+
"""
255+
$(TYPEDEF)
256+
257+
Defines a weighted version of an `EnsembleProblem`, where different simulations contribute unequally.
258+
259+
## Arguments:
260+
261+
- `ensembleprob`: The base ensemble problem.
262+
- `weights`: A vector of weights corresponding to each simulation.
263+
264+
"""
265+
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
266+
AbstractEnsembleProblem
267+
ensembleprob::T1
268+
weights::T2
269+
end
270+
271+
"""
272+
273+
Returns a list of all accessible properties, including those from the inner ensemble and `:weights`.
274+
275+
"""
276+
function Base.propertynames(e::WeightedEnsembleProblem)
277+
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
278+
end
279+
280+
"""
281+
Accesses properties of a `WeightedEnsembleProblem`.
282+
283+
Returns `weights` or delegates to the underlying ensemble.
284+
285+
"""
286+
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
287+
f === :weights && return getfield(e, :weights)
288+
f === :ensembleprob && return getfield(e, :ensembleprob)
289+
return getproperty(getfield(e, :ensembleprob), f)
290+
end
291+
292+
"""
293+
$(TYPEDEF)
294+
295+
Constructor for `WeightedEnsembleProblem`. Ensures weights sum to 1 and matches problem count.
296+
297+
"""
298+
function WeightedEnsembleProblem(args...; weights, kwargs...)
299+
@assert sum(weights) 1
300+
ep = EnsembleProblem(args...; kwargs...)
301+
@assert length(ep.prob) == length(weights)
302+
WeightedEnsembleProblem(ep, weights)
303+
end
304+
N,
305+
u_init = nothing,
306+
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
307+
_prob_func = prepare_function(prob_func)
308+
_output_func = prepare_function(output_func)
309+
_reduction = prepare_function(reduction)
310+
_u_init = prepare_initial_state(u_init)
311+
EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy)
312+
end
313+
314+
# Alternative constructor that accepts parameters through keyword arguments (especially used internally).
119315
function EnsembleProblem(; prob,
120316
prob_func = DEFAULT_PROB_FUNC,
121317
output_func = DEFAULT_OUTPUT_FUNC,
@@ -126,6 +322,7 @@ function EnsembleProblem(; prob,
126322
end
127323

128324
#since NonlinearProblem might want to use this dispatch as well
325+
#Special constructor used for creating an EnsembleProblem where initial states vary.
129326
function SciMLBase.EnsembleProblem(
130327
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
131328
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
@@ -134,19 +331,26 @@ function SciMLBase.EnsembleProblem(
134331
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
135332
end
136333

334+
# Defines a weighted version of an EnsembleProblem, where different simulations contribute unequally.
137335
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
138336
AbstractEnsembleProblem
139-
ensembleprob::T1
140-
weights::T2
337+
ensembleprob::T1 # The base ensemble problem.
338+
weights::T2 # A vector of weights corresponding to each simulation.
141339
end
340+
341+
# Allow accessing all properties from the base ensemble plus the new weights field.
142342
function Base.propertynames(e::WeightedEnsembleProblem)
143343
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
144344
end
345+
346+
# Getter for fields: either return weights, ensembleprob, or delegate to the underlying ensemble.
145347
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
146348
f === :weights && return getfield(e, :weights)
147349
f === :ensembleprob && return getfield(e, :ensembleprob)
148350
return getproperty(getfield(e, :ensembleprob), f)
149351
end
352+
353+
# Constructor for WeightedEnsembleProblem, checks that weights sum to ~1.
150354
function WeightedEnsembleProblem(args...; weights, kwargs...)
151355
# TODO: allow skipping checks?
152356
@assert sum(weights) 1

0 commit comments

Comments
 (0)