Skip to content

Commit 30947f1

Browse files
authored
Improve comments and documentation for EnsembleProblem.jl
1 parent 0eb90b8 commit 30947f1

File tree

1 file changed

+36
-131
lines changed

1 file changed

+36
-131
lines changed

src/ensemble/ensemble_problems.jl

Lines changed: 36 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,24 @@ EnsembleProblem(prob::AbstractSciMLProblem;
2828
`repeat` is the iteration of the repeat. At first, it is `1`, but if
2929
`rerun` was true this will be `2`, `3`, etc. counting the number of times
3030
problem `i` has been repeated.
31-
- `reduction`: This function is used to aggregate the results in each simulation batch. By default, it appends the `data` from the batch to `u`, which is initialized via `u_data`. The `I` is a range of indices corresponding to the trajectories for the current batch.
3231
33-
### Arguments:
34-
- `u`: The solution from the current ensemble run. This is the accumulated data that gets updated in each batch.
35-
- `data`: The results from the current batch of simulations. This is typically some data (e.g., variable values, time steps) that is merged with `u`.
36-
- `I`: A range of indices corresponding to the simulations in the current batch. This provides the trajectory indices for the batch.
37-
38-
### Returns:
39-
- `(new_data, has_converged)`: A tuple where:
40-
- `new_data`: The updated accumulated data, typically the result of appending `data` to `u`.
41-
- `has_converged`: A boolean indicating whether the simulation has converged and should terminate early. If `true`, the simulation will stop early. If `false`, the simulation will continue. By default, this is `false`, meaning the simulation will not stop early.
32+
- `reduction`: This function is used to aggregate the results in each simulation batch.
33+
By default, it appends the `data` from the batch to `u`, which is initialized via `u_data`.
34+
The `I` is a range of indices corresponding to the trajectories for the current batch.
35+
### Arguments:
36+
- `u`: The solution from the current ensemble run. This is the accumulated data that gets
37+
updated in each batch.
38+
- `data`: The results from the current batch of simulations. This is typically some data
39+
(e.g., variable values, time steps) that is merged with `u`.
40+
- `I`: A range of indices corresponding to the simulations in the current batch. This provides
41+
the trajectory indices for the batch.
42+
43+
### Returns:
44+
- `(new_data, has_converged)`: A tuple where:
45+
- `new_data`: The updated accumulated data, typically the result of appending `data` to `u`.
46+
- `has_converged`: A boolean indicating whether the simulation has converged and should terminate early.
47+
If `true`, the simulation will stop early. If `false`, the simulation will continue. By default, this is
48+
`false`, meaning the simulation will not stop early.
4249
4350
- `u_init`: The initial form of the object that gets updated in-place inside the
4451
`reduction` function.
@@ -88,59 +95,19 @@ Thus, the ensemble simulation would return as its data an array which is the
8895
end value of the 2nd dependent variable for each of the runs.
8996
"""
9097

91-
# Defines a structure to manage an ensemble (batch) of problems.
92-
# Each field controls how the ensemble behaves during simulation.
93-
94-
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
95-
prob::T # The original base problem to replicate or modify.
96-
prob_func::T2 # A function defining how to generate each subproblem (e.g., changing initial conditions).
97-
output_func::T3 # A function to post-process each individual simulation result.
98-
reduction::T4 # A function to combine results from all simulations.
99-
u_init::T5 # The initial container used to accumulate the results.
100-
safetycopy::Bool # Whether to copy the problem when creating subproblems (to avoid unintended modifications).
101-
end
102-
103-
# Returns the same problem without modification.
104-
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
105-
106-
# Returns the solution as-is, along with a flag (false) indicating no early termination.
107-
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
108-
109-
# Appends new data to the accumulated data, no early convergence.
110-
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
111-
112-
# Selects the i-th problem from a vector of problems.
113-
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
114-
115-
# Constructor: creates an EnsembleProblem when the input is a vector of problems (DEPRECATED).
116-
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
117-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
118-
Ensembles Simulations Interface page for more details", :EnsembleProblem)
119-
invoke(EnsembleProblem,
120-
Tuple{Any},
121-
prob;
122-
prob_func = DEFAULT_VECTOR_PROB_FUNC,
123-
kwargs...)
124-
end
125-
126-
# Main constructor: creates an EnsembleProblem with optional custom behavior.
127-
function EnsembleProblem(prob;
128-
prob_func = DEFAULT_PROB_FUNC,
129-
output_func = DEFAULT_OUTPUT_FUNC,
130-
reduction = DEFAULT_REDUCTIO"""
98+
"""
13199
$(TYPEDEF)
132100
133101
Defines a structure to manage an ensemble (batch) of problems.
134102
Each field controls how the ensemble behaves during simulation.
135103
136-
## Arguments:
137-
138-
- `prob`: The original base problem to replicate or modify.
139-
- `prob_func`: A function that defines how to generate each subproblem (e.g., changing initial conditions).
140-
- `output_func`: A function to post-process each individual simulation result.
141-
- `reduction`: A function to combine results from all simulations.
142-
- `u_init`: The initial container used to accumulate the results.
143-
- `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications).
104+
## Arguments
105+
- `prob`: The original base problem to replicate or modify.
106+
- `prob_func`: A function that defines how to generate each subproblem.
107+
- `output_func`: A function to post-process each individual simulation result.
108+
- `reduction`: A function to combine results from all simulations.
109+
- `u_init`: The initial container used to accumulate the results.
110+
- `safetycopy`: Whether to copy the problem when creating subproblems (to avoid unintended modifications).
144111
"""
145112
struct EnsembleProblem{T, T2, T3, T4, T5} <: AbstractEnsembleProblem
146113
prob::T
@@ -153,38 +120,35 @@ end
153120

154121
"""
155122
Returns the same problem without modification.
156-
157123
"""
158124
DEFAULT_PROB_FUNC(prob, i, repeat) = prob
159125

160126
"""
161127
Returns the solution as-is, along with `false` indicating no rerun.
162-
163128
"""
164129
DEFAULT_OUTPUT_FUNC(sol, i) = (sol, false)
165130

166131
"""
167132
Appends new data to the accumulated data and returns `false` to indicate no early termination.
168-
169133
"""
170134
DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
171135

172136
"""
173137
Selects the i-th problem from a vector of problems.
174-
175138
"""
176139
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
177140

178141
"""
179142
$(TYPEDEF)
180143
181-
Constructor that creates an EnsembleProblem when the input is a vector of problems.
144+
Constructor for deprecated usage where a vector of problems is passed directly.
182145
183146
!!! warning
184147
This constructor is deprecated. Use the standard ensemble syntax with `prob_func` instead.
185148
"""
186149
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
187-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel Ensembles Simulations Interface page for more details", :EnsembleProblem)
150+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
151+
Ensembles Simulations Interface page for more details", :EnsembleProblem)
188152
invoke(EnsembleProblem,
189153
Tuple{Any},
190154
prob;
@@ -197,15 +161,14 @@ $(TYPEDEF)
197161
198162
Main constructor for `EnsembleProblem`.
199163
200-
## Arguments:
164+
## Keyword Arguments
201165
202166
- `prob`: The base problem.
203167
- `prob_func`: Function to modify the base problem per trajectory.
204168
- `output_func`: Function to extract output from a solution.
205169
- `reduction`: Function to aggregate results.
206170
- `u_init`: Initial value for aggregation.
207171
- `safetycopy`: Whether to deepcopy the problem before modifying.
208-
209172
"""
210173
function EnsembleProblem(prob;
211174
prob_func = DEFAULT_PROB_FUNC,
@@ -224,7 +187,6 @@ end
224187
$(TYPEDEF)
225188
226189
Alternate constructor that uses only keyword arguments.
227-
228190
"""
229191
function EnsembleProblem(; prob,
230192
prob_func = DEFAULT_PROB_FUNC,
@@ -238,15 +200,15 @@ end
238200
"""
239201
$(TYPEDEF)
240202
241-
Constructor for NonlinearProblem.
203+
Constructor that is used for NOnlinearProblem.
242204
243205
!!! warning
244206
This dispatch is deprecated. See the Parallel Ensembles Simulations Interface page.
245-
246207
"""
247208
function SciMLBase.EnsembleProblem(
248209
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
249-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel Ensembles Simulations Interface page for more details", :EnsembleProblem)
210+
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel \
211+
Ensembles Simulations Interface page for more details", :EnsembleProblem)
250212
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
251213
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
252214
end
@@ -256,11 +218,10 @@ $(TYPEDEF)
256218
257219
Defines a weighted version of an `EnsembleProblem`, where different simulations contribute unequally.
258220
259-
## Arguments:
221+
## Fields
260222
261223
- `ensembleprob`: The base ensemble problem.
262224
- `weights`: A vector of weights corresponding to each simulation.
263-
264225
"""
265226
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
266227
AbstractEnsembleProblem
@@ -269,9 +230,7 @@ struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVect
269230
end
270231

271232
"""
272-
273233
Returns a list of all accessible properties, including those from the inner ensemble and `:weights`.
274-
275234
"""
276235
function Base.propertynames(e::WeightedEnsembleProblem)
277236
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
@@ -281,7 +240,6 @@ end
281240
Accesses properties of a `WeightedEnsembleProblem`.
282241
283242
Returns `weights` or delegates to the underlying ensemble.
284-
285243
"""
286244
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
287245
f === :weights && return getfield(e, :weights)
@@ -294,67 +252,14 @@ $(TYPEDEF)
294252
295253
Constructor for `WeightedEnsembleProblem`. Ensures weights sum to 1 and matches problem count.
296254
255+
## Keyword Arguments
256+
257+
- `weights`: A vector of weights for each trajectory.
297258
"""
298259
function WeightedEnsembleProblem(args...; weights, kwargs...)
299260
@assert sum(weights) 1
300261
ep = EnsembleProblem(args...; kwargs...)
301262
@assert length(ep.prob) == length(weights)
302263
WeightedEnsembleProblem(ep, weights)
303264
end
304-
N,
305-
u_init = nothing,
306-
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
307-
_prob_func = prepare_function(prob_func)
308-
_output_func = prepare_function(output_func)
309-
_reduction = prepare_function(reduction)
310-
_u_init = prepare_initial_state(u_init)
311-
EnsembleProblem(prob, _prob_func, _output_func, _reduction, _u_init, safetycopy)
312-
end
313-
314-
# Alternative constructor that accepts parameters through keyword arguments (especially used internally).
315-
function EnsembleProblem(; prob,
316-
prob_func = DEFAULT_PROB_FUNC,
317-
output_func = DEFAULT_OUTPUT_FUNC,
318-
reduction = DEFAULT_REDUCTION,
319-
u_init = nothing, p = nothing,
320-
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
321-
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
322-
end
323265

324-
#since NonlinearProblem might want to use this dispatch as well
325-
#Special constructor used for creating an EnsembleProblem where initial states vary.
326-
function SciMLBase.EnsembleProblem(
327-
prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
328-
Base.depwarn("This dispatch is deprecated for the standard ensemble syntax. See the Parallel
329-
Ensembles Simulations Interface page for more details", :EnsebleProblem)
330-
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
331-
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
332-
end
333-
334-
# Defines a weighted version of an EnsembleProblem, where different simulations contribute unequally.
335-
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
336-
AbstractEnsembleProblem
337-
ensembleprob::T1 # The base ensemble problem.
338-
weights::T2 # A vector of weights corresponding to each simulation.
339-
end
340-
341-
# Allow accessing all properties from the base ensemble plus the new weights field.
342-
function Base.propertynames(e::WeightedEnsembleProblem)
343-
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
344-
end
345-
346-
# Getter for fields: either return weights, ensembleprob, or delegate to the underlying ensemble.
347-
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
348-
f === :weights && return getfield(e, :weights)
349-
f === :ensembleprob && return getfield(e, :ensembleprob)
350-
return getproperty(getfield(e, :ensembleprob), f)
351-
end
352-
353-
# Constructor for WeightedEnsembleProblem, checks that weights sum to ~1.
354-
function WeightedEnsembleProblem(args...; weights, kwargs...)
355-
# TODO: allow skipping checks?
356-
@assert sum(weights) 1
357-
ep = EnsembleProblem(args...; kwargs...)
358-
@assert length(ep.prob) == length(weights)
359-
WeightedEnsembleProblem(ep, weights)
360-
end

0 commit comments

Comments
 (0)