Skip to content

Commit df3fb57

Browse files
committed
fix: correct argument number and signatures
1 parent 1f14548 commit df3fb57

File tree

1 file changed

+35
-34
lines changed

1 file changed

+35
-34
lines changed

src/scimlfunctions.jl

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,24 +2097,25 @@ end
20972097
"""
20982098
$(TYPEDEF)
20992099
"""
2100-
abstract type AbstractControlFunction{iip} <: AbstractDiffEqFunction{iip} end
2100+
abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end
21012101

21022102
@doc doc"""
21032103
$(TYPEDEF)
21042104
2105-
A representation of a optimal control function `f`, defined by:
2105+
A representation of a ODE function `f` with inputs, defined by:
21062106
21072107
```math
21082108
\frac{dx}{dt} = f(x, u, p, t)
21092109
```
2110-
where `x` are the states of the system and `u` are the inputs (or control variables).
2110+
where `x` are the states of the system and `u` are the inputs (which may represent
2111+
different things in different contexts, such as control variables in optimal control).
21112112
21122113
Includes all of its related functions, such as the Jacobian of `f`, its gradient
21132114
with respect to time, and more. For all cases, `u0` is the initial condition,
21142115
`p` are the parameters, and `t` is the independent variable.
21152116
21162117
```julia
2117-
ControlFunction{iip, specialize}(f;
2118+
ODEInputFunction{iip, specialize}(f;
21182119
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
21192120
analytic = __has_analytic(f) ? f.analytic : nothing,
21202121
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
@@ -2139,11 +2140,11 @@ See the section on `iip` for more details on in-place vs out-of-place handling.
21392140
- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
21402141
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
21412142
if `M` is singular.
2142-
- `jac(J,dx,x,p,gamma,t)` or `J=jac(dx,x,p,gamma,t)`: returns ``\frac{df}{dx}``
2143-
- `control_jac(J,du,u,p,gamma,t)` or `J=control_jac(du,u,p,gamma,t)`: returns ``\frac{df}{du}``
2144-
- `jvp(Jv,v,du,u,p,gamma,t)` or `Jv=jvp(v,du,u,p,gamma,t)`: returns the directional
2143+
- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}``
2144+
- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}``
2145+
- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
21452146
derivative ``\frac{df}{du} v``
2146-
- `vjp(Jv,v,du,u,p,gamma,t)` or `Jv=vjp(v,du,u,p,gamma,t)`: returns the adjoint
2147+
- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
21472148
derivative ``\frac{df}{du}^\ast v``
21482149
- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
21492150
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
@@ -2155,7 +2156,7 @@ See the section on `iip` for more details on in-place vs out-of-place handling.
21552156
as the prototype and integrators will specialize on this structure where possible. Non-structured
21562157
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
21572158
The default is `nothing`, which means a dense Jacobian.
2158-
- `paramjac(pJ,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``.
2159+
- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``.
21592160
- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
21602161
pattern of the `jac_prototype`. This specializes the Jacobian construction when using
21612162
finite differences and automatic differentiation to be computed in an accelerated manner
@@ -2170,11 +2171,11 @@ For more details on this argument, see the ODEFunction documentation.
21702171
For more details on this argument, see the ODEFunction documentation.
21712172
21722173
## Fields
2173-
The fields of the ControlFunction type directly match the names of the inputs.
2174+
The fields of the ODEInputFunction type directly match the names of the inputs.
21742175
"""
2175-
struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2176+
struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
21762177
JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
2177-
SYS, ID} <: AbstractControlFunction{iip}
2178+
SYS, ID} <: AbstractODEInputFunction{iip}
21782179
f::F
21792180
mass_matrix::TMM
21802181
analytic::Ta
@@ -2595,7 +2596,7 @@ end
25952596
(f::ImplicitDiscreteFunction)(args...) = f.f(args...)
25962597
(f::DAEFunction)(args...) = f.f(args...)
25972598
(f::DDEFunction)(args...) = f.f(args...)
2598-
(f::ControlFunction)(args...) = f.f(args...)
2599+
(f::ODEInputFunction)(args...) = f.f(args...)
25992600

26002601
function (f::DynamicalDDEFunction)(u, h, p, t)
26012602
ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t))
@@ -4698,7 +4699,7 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
46984699
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
46994700
end
47004701

4701-
function ControlFunction{iip, specialize}(f;
4702+
function ODEInputFunction{iip, specialize}(f;
47024703
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
47034704
I,
47044705
analytic = __has_analytic(f) ? f.analytic : nothing,
@@ -4748,17 +4749,17 @@ function ControlFunction{iip, specialize}(f;
47484749

47494750
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
47504751
if iip
4751-
jac = update_coefficients! #(J,u,p,t)
4752+
jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t)
47524753
else
4753-
jac = (u, p, t) -> update_coefficients(deepcopy(jac_prototype), u, p, t)
4754+
jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t)
47544755
end
47554756
end
47564757

47574758
if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator)
47584759
if iip_bc
4759-
controljac = update_coefficients! #(J,u,p,t)
4760+
controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t)
47604761
else
4761-
controljac = (u, p, t) -> update_coefficients!(deepcopy(controljac_prototype), u, p, t)
4762+
controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t)
47624763
end
47634764
end
47644765

@@ -4769,14 +4770,14 @@ function ControlFunction{iip, specialize}(f;
47694770
_colorvec = colorvec
47704771
end
47714772

4772-
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
4773-
controljaciip = controljac !== nothing ? isinplace(controljac, 4, "controljac", iip) : iip
4774-
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
4775-
jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip
4776-
vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip
4777-
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip
4778-
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
4779-
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip
4773+
jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip
4774+
controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip
4775+
tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip
4776+
jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip
4777+
vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip
4778+
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip
4779+
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip
4780+
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip
47804781

47814782
nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
47824783
paramjaciip) .!= iip
@@ -4794,7 +4795,7 @@ function ControlFunction{iip, specialize}(f;
47944795
initializeprobmap, initializeprobpmap)
47954796

47964797
if specialize === NoSpecialize
4797-
ControlFunction{iip, specialize,
4798+
ODEInputFunction{iip, specialize,
47984799
Any, Any, Any, Any,
47994800
Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype),
48004801
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
@@ -4806,7 +4807,7 @@ function ControlFunction{iip, specialize}(f;
48064807
Wfact_t, W_prototype, paramjac,
48074808
observed, _colorvec, sys, initdata)
48084809
elseif specialize === false
4809-
ControlFunction{iip, FunctionWrapperSpecialize,
4810+
ODEInputFunction{iip, FunctionWrapperSpecialize,
48104811
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
48114812
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
48124813
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
@@ -4819,7 +4820,7 @@ function ControlFunction{iip, specialize}(f;
48194820
Wfact_t, W_prototype, paramjac,
48204821
observed, _colorvec, sys, initdata)
48214822
else
4822-
ControlFunction{iip, specialize,
4823+
ODEInputFunction{iip, specialize,
48234824
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
48244825
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
48254826
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
@@ -4834,12 +4835,12 @@ function ControlFunction{iip, specialize}(f;
48344835
end
48354836
end
48364837

4837-
function ControlFunction{iip}(f; kwargs...) where {iip}
4838-
ControlFunction{iip, FullSpecialize}(f; kwargs...)
4838+
function ODEInputFunction{iip}(f; kwargs...) where {iip}
4839+
ODEInputFunction{iip, FullSpecialize}(f; kwargs...)
48394840
end
4840-
ControlFunction{iip}(f::ControlFunction; kwargs...) where {iip} = f
4841-
ControlFunction(f; kwargs...) = ControlFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
4842-
ControlFunction(f::ControlFunction; kwargs...) = f
4841+
ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f
4842+
ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
4843+
ODEInputFunction(f::ODEInputFunction; kwargs...) = f
48434844

48444845
########## Utility functions
48454846

0 commit comments

Comments
 (0)