Skip to content

Commit e283773

Browse files
committed
add AbstractProblemType
1 parent 3db3a3a commit e283773

File tree

6 files changed

+131
-94
lines changed

6 files changed

+131
-94
lines changed

src/SciMLBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,8 @@ include("operators/common_defaults.jl")
688688
include("symbolic_utils.jl")
689689
include("performance_warnings.jl")
690690

691+
abstract type AbstractProblemType end
692+
691693
include("problems/discrete_problems.jl")
692694
include("problems/implicit_discrete_problems.jl")
693695
include("problems/steady_state_problems.jl")

src/problems/basic_problems.jl

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct LinearProblem{uType, isinplace, F, bType, P, K} <:
6060
p::P
6161
kwargs::K
6262
@add_kwonly function LinearProblem{iip}(A, b, p = NullParameters(); u0 = nothing,
63-
kwargs...) where {iip}
63+
kwargs...) where {iip}
6464
warn_paramtype(p)
6565
new{typeof(u0), iip, typeof(A), typeof(b), typeof(p), typeof(kwargs)}(A, b, u0, p,
6666
kwargs)
@@ -82,7 +82,7 @@ TruncatedStacktraces.@truncate_stacktrace LinearProblem 1
8282
"""
8383
$(TYPEDEF)
8484
"""
85-
struct StandardNonlinearProblem end
85+
struct StandardNonlinearProblem <: AbstractProblemType end
8686

8787
@doc doc"""
8888
@@ -134,20 +134,20 @@ every solve call.
134134
* `p`: The parameters for the problem. Defaults to `NullParameters`.
135135
* `kwargs`: The keyword arguments passed on to the solvers.
136136
"""
137-
struct IntervalNonlinearProblem{isinplace, tType, P, F, K, PT} <:
137+
struct IntervalNonlinearProblem{isinplace, tType, P, F, K, PT <: AbstractProblemType} <:
138138
AbstractIntervalNonlinearProblem{nothing, isinplace}
139139
f::F
140140
tspan::tType
141141
p::P
142142
problem_type::PT
143143
kwargs::K
144144
@add_kwonly function IntervalNonlinearProblem{iip}(f::AbstractIntervalNonlinearFunction{
145-
iip,
146-
},
147-
tspan,
148-
p = NullParameters(),
149-
problem_type = StandardNonlinearProblem();
150-
kwargs...) where {iip}
145+
iip,
146+
},
147+
tspan,
148+
p = NullParameters(),
149+
problem_type = StandardNonlinearProblem();
150+
kwargs...) where {iip}
151151
warn_paramtype(p)
152152
new{iip, typeof(tspan), typeof(p), typeof(f),
153153
typeof(kwargs), typeof(problem_type)}(f,
@@ -177,7 +177,7 @@ Define a nonlinear problem using an instance of
177177
[`IntervalNonlinearFunction`](@ref IntervalNonlinearFunction).
178178
"""
179179
function IntervalNonlinearProblem(f::AbstractIntervalNonlinearFunction, tspan,
180-
p = NullParameters(); kwargs...)
180+
p = NullParameters(); kwargs...)
181181
IntervalNonlinearProblem{isinplace(f)}(f, tspan, p; kwargs...)
182182
end
183183

@@ -233,17 +233,17 @@ page.
233233
* `p`: The parameters for the problem. Defaults to `NullParameters`.
234234
* `kwargs`: The keyword arguments passed on to the solvers.
235235
"""
236-
struct NonlinearProblem{uType, isinplace, P, F, K, PT} <:
236+
struct NonlinearProblem{uType, isinplace, P, F, K, PT <: AbstractProblemType} <:
237237
AbstractNonlinearProblem{uType, isinplace}
238238
f::F
239239
u0::uType
240240
p::P
241241
problem_type::PT
242242
kwargs::K
243243
@add_kwonly function NonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
244-
p = NullParameters(),
245-
problem_type = StandardNonlinearProblem();
246-
kwargs...) where {iip}
244+
p = NullParameters(),
245+
problem_type = StandardNonlinearProblem();
246+
kwargs...) where {iip}
247247
warn_paramtype(p)
248248
new{typeof(u0), iip, typeof(p), typeof(f),
249249
typeof(kwargs), typeof(problem_type)}(f,
@@ -364,7 +364,7 @@ struct NonlinearLeastSquaresProblem{uType, isinplace, P, F, K} <:
364364
kwargs::K
365365

366366
@add_kwonly function NonlinearLeastSquaresProblem{iip}(f::AbstractNonlinearFunction{
367-
iip}, u0, p = NullParameters(); kwargs...) where {iip}
367+
iip}, u0, p = NullParameters(); kwargs...) where {iip}
368368
warn_paramtype(p)
369369
return new{typeof(u0), iip, typeof(p), typeof(f), typeof(kwargs)}(f, u0, p, kwargs)
370370
end
@@ -383,7 +383,7 @@ Define a nonlinear least squares problem using an instance of
383383
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
384384
"""
385385
function NonlinearLeastSquaresProblem(f::AbstractNonlinearFunction, u0,
386-
p = NullParameters(); kwargs...)
386+
p = NullParameters(); kwargs...)
387387
return NonlinearLeastSquaresProblem{isinplace(f)}(f, u0, p; kwargs...)
388388
end
389389

@@ -445,8 +445,8 @@ struct IntegralProblem{isinplace, P, F, T, K} <: AbstractIntegralProblem{isinpla
445445
p::P
446446
kwargs::K
447447
@add_kwonly function IntegralProblem{iip}(f::AbstractIntegralFunction{iip}, domain,
448-
p = NullParameters();
449-
kwargs...) where {iip}
448+
p = NullParameters();
449+
kwargs...) where {iip}
450450
warn_paramtype(p)
451451
new{iip, typeof(p), typeof(f), typeof(domain), typeof(kwargs)}(f,
452452
domain, p, kwargs)
@@ -456,38 +456,40 @@ end
456456
TruncatedStacktraces.@truncate_stacktrace IntegralProblem 1 4
457457

458458
function IntegralProblem(f::AbstractIntegralFunction,
459-
domain,
460-
p = NullParameters();
461-
kwargs...)
459+
domain,
460+
p = NullParameters();
461+
kwargs...)
462462
IntegralProblem{isinplace(f)}(f, domain, p; kwargs...)
463463
end
464464

465465
function IntegralProblem(f::AbstractIntegralFunction,
466-
lb::B,
467-
ub::B,
468-
p = NullParameters();
469-
kwargs...) where {B}
466+
lb::B,
467+
ub::B,
468+
p = NullParameters();
469+
kwargs...) where {B}
470470
IntegralProblem{isinplace(f)}(f, (lb, ub), p; kwargs...)
471471
end
472472

473473
function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...)
474474
if nout !== nothing || batch !== nothing
475-
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details."
475+
@warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details."
476476
end
477477

478478
g = if isinplace(f, 3)
479479
if batch === nothing
480-
output_prototype = nout === nothing ? Array{Float64, 0}(undef) : Vector{Float64}(undef, nout)
480+
output_prototype = nout === nothing ? Array{Float64, 0}(undef) :
481+
Vector{Float64}(undef, nout)
481482
IntegralFunction(f, output_prototype)
482483
else
483-
output_prototype = nout === nothing ? Float64[] : Matrix{Float64}(undef, nout, 0)
484-
BatchIntegralFunction(f, output_prototype, max_batch=batch)
484+
output_prototype = nout === nothing ? Float64[] :
485+
Matrix{Float64}(undef, nout, 0)
486+
BatchIntegralFunction(f, output_prototype, max_batch = batch)
485487
end
486488
else
487489
if batch === nothing
488490
IntegralFunction(f)
489491
else
490-
BatchIntegralFunction(f, max_batch=batch)
492+
BatchIntegralFunction(f, max_batch = batch)
491493
end
492494
end
493495
IntegralProblem(g, args...; kwargs...)
@@ -506,7 +508,7 @@ function Base.getproperty(prob::IntegralProblem, name::Symbol)
506508
return Base.getfield(prob, name)
507509
end
508510

509-
struct QuadratureProblem end
511+
struct QuadratureProblem <: AbstractProblemType end
510512
@deprecate QuadratureProblem(args...; kwargs...) IntegralProblem(args...; kwargs...)
511513

512514
@doc doc"""
@@ -548,8 +550,8 @@ struct SampledIntegralProblem{Y, X, K} <: AbstractIntegralProblem{false}
548550
dim::Int
549551
kwargs::K
550552
@add_kwonly function SampledIntegralProblem(y::AbstractArray, x::AbstractVector;
551-
dim = ndims(y),
552-
kwargs...)
553+
dim = ndims(y),
554+
kwargs...)
553555
@assert dim<=ndims(y) "The integration dimension `dim` is larger than the number of dimensions of the integrand `y`"
554556
@assert length(x)==size(y, dim) "The integrand `y` must have the same length as the sampling points `x` along the integrated dimension."
555557
@assert axes(x, 1)==axes(y, dim) "The integrand `y` must obey the same indexing as the sampling points `x` along the integrated dimension."
@@ -612,8 +614,8 @@ They should be an `AbstractArray` matching the geometry of `u`, where `(lcons[i]
612614
are the lower and upper bounds for `cons[i]`.
613615
614616
The `f` in the `OptimizationProblem` should typically be an instance of [`OptimizationFunction`](@ref)
615-
to specify the objective function and its derivatives either by passing
616-
predefined functions for them or automatically generated using the [`ADType`](@ref).
617+
to specify the objective function and its derivatives either by passing
618+
predefined functions for them or automatically generated using the [`ADType`](@ref).
617619
618620
If `f` is a standard Julia function, it is automatically transformed into an
619621
`OptimizationFunction` with `NoAD()`, meaning the derivative functions are not
@@ -663,10 +665,10 @@ struct OptimizationProblem{iip, F, uType, P, LB, UB, I, LC, UC, S, K} <:
663665
sense::S
664666
kwargs::K
665667
@add_kwonly function OptimizationProblem{iip}(f::OptimizationFunction{iip}, u0,
666-
p = NullParameters();
667-
lb = nothing, ub = nothing, int = nothing,
668-
lcons = nothing, ucons = nothing,
669-
sense = nothing, kwargs...) where {iip}
668+
p = NullParameters();
669+
lb = nothing, ub = nothing, int = nothing,
670+
lcons = nothing, ucons = nothing,
671+
sense = nothing, kwargs...) where {iip}
670672
if xor(lb === nothing, ub === nothing)
671673
error("If any of `lb` or `ub` is provided, both must be provided.")
672674
end
@@ -688,14 +690,18 @@ function OptimizationProblem(f, args...; kwargs...)
688690
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
689691
end
690692

691-
function OptimizationFunction(f::NonlinearFunction, adtype::AbstractADType = NoAD(); kwargs...)
693+
function OptimizationFunction(f::NonlinearFunction,
694+
adtype::AbstractADType = NoAD();
695+
kwargs...)
692696
if isinplace(f)
693697
throw(ArgumentError("Converting NonlinearFunction to OptimizationFunction is not supported with in-place functions yet."))
694698
end
695699
OptimizationFunction((u, p) -> sum(abs2, f(u, p)), adtype; kwargs...)
696700
end
697701

698-
function OptimizationProblem(prob::NonlinearLeastSquaresProblem, adtype::AbstractADType = NoAD(); kwargs...)
702+
function OptimizationProblem(prob::NonlinearLeastSquaresProblem,
703+
adtype::AbstractADType = NoAD();
704+
kwargs...)
699705
if isinplace(prob)
700706
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet."))
701707
end

src/problems/bvp_problems.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
$(TYPEDEF)
33
"""
4-
struct StandardBVProblem end
4+
struct StandardBVProblem <: AbstractProblemType end
55

66
"""
77
$(TYPEDEF)
@@ -105,7 +105,7 @@ every solve call.
105105
* `p`: The parameters for the problem. Defaults to `NullParameters`
106106
* `kwargs`: The keyword arguments passed onto the solves.
107107
"""
108-
struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
108+
struct BVProblem{uType, tType, isinplace, P, F, PT <: AbstractProblemType, K} <:
109109
AbstractBVProblem{uType, tType, isinplace}
110110
f::F
111111
u0::uType
@@ -115,7 +115,7 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
115115
kwargs::K
116116

117117
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
118-
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
118+
p = NullParameters(); problem_type = nothing, kwargs...) where {iip, TP}
119119
_u0 = prepare_initial_state(u0)
120120
_tspan = promote_tspan(tspan)
121121
warn_paramtype(p)
@@ -124,10 +124,15 @@ struct BVProblem{uType, tType, isinplace, P, F, PT, K} <:
124124
if problem_type === nothing
125125
problem_type = prob_type
126126
else
127-
@assert prob_type === problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
127+
@assert prob_type===problem_type "This indicates incorrect problem type specification! Users should never pass in `problem_type` kwarg, this exists exclusively for internal use."
128128
end
129129
return new{typeof(_u0), typeof(_tspan), iip, typeof(p), typeof(f),
130-
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
130+
typeof(problem_type), typeof(kwargs)}(f,
131+
_u0,
132+
_tspan,
133+
p,
134+
problem_type,
135+
kwargs)
131136
end
132137

133138
function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}
@@ -158,34 +163,34 @@ end
158163
end
159164

160165
function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();
161-
bcresid_prototype=nothing, kwargs...) where {iip}
166+
bcresid_prototype = nothing, kwargs...) where {iip}
162167
return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan,
163168
p; kwargs...)
164169
end
165170
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
166-
bcresid_prototype=nothing, kwargs...)
171+
bcresid_prototype = nothing, kwargs...)
167172
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
168173
kwargs...)
169174
end
170175
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
171-
p = NullParameters(); kwargs...) where {iip, twopoint}
176+
p = NullParameters(); kwargs...) where {iip, twopoint}
172177
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
173178
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
174179
end
175180
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
176-
p = NullParameters(); kwargs...) where {iip, twopoint}
181+
p = NullParameters(); kwargs...) where {iip, twopoint}
177182
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=Val(true)` during the construction of the `BVPFunction`."
178183
return BVProblem{iip}(f, u0, tspan, p; kwargs...)
179184
end
180185

181186
# Allow previous timeseries solution
182187
function TwoPointBVProblem(f::AbstractODEFunction, bc, sol::T, tspan::Tuple,
183-
p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution}
188+
p = NullParameters(); kwargs...) where {T <: AbstractTimeseriesSolution}
184189
return TwoPointBVProblem(f, bc, sol.u, tspan, p; kwargs...)
185190
end
186191
# Allow initial guess function for the initial guess
187192
function TwoPointBVProblem(f::AbstractODEFunction, bc, initialGuess, tspan::AbstractVector,
188-
p = NullParameters(); kwargs...)
193+
p = NullParameters(); kwargs...)
189194
u0 = [initialGuess(i) for i in tspan]
190195
return TwoPointBVProblem(f, bc, u0, (tspan[1], tspan[end]), p; kwargs...)
191196
end

0 commit comments

Comments
 (0)