-
-
Notifications
You must be signed in to change notification settings - Fork 104
fix isinplace
inference and add inference tests
#1019
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,157 +42,6 @@ function num_types_in_tuple(sig::UnionAll) | |
length(Base.unwrap_unionall(sig).parameters) | ||
end | ||
|
||
const NO_METHODS_ERROR_MESSAGE = """ | ||
No methods were found for the model function passed to the equation solver. | ||
The function `f` needs to have dispatches, for example, for an ODEProblem | ||
`f` must define either `f(u,p,t)` or `f(du,u,p,t)`. For more information | ||
on how the model function `f` should be defined, consult the docstring for | ||
the appropriate `AbstractSciMLFunction`. | ||
""" | ||
|
||
struct NoMethodsError <: Exception | ||
fname::String | ||
end | ||
|
||
function Base.showerror(io::IO, e::NoMethodsError) | ||
println(io, NO_METHODS_ERROR_MESSAGE) | ||
print(io, "Offending function: ") | ||
printstyled(io, e.fname; bold = true, color = :red) | ||
end | ||
|
||
const TOO_MANY_ARGUMENTS_ERROR_MESSAGE = """ | ||
All methods for the model function `f` had too many arguments. For example, | ||
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error | ||
can be thrown if you define an ODE model for example as `f(du,u,p1,p2,t)`. | ||
For more information on the required number of arguments for the function | ||
you were defining, consult the documentation for the `SciMLProblem` or | ||
`SciMLFunction` type that was being constructed. | ||
|
||
A common reason for this occurrence is due to following the MATLAB or SciPy | ||
convention for parameter passing, i.e. to add each parameter as an argument. | ||
In the SciML convention, if you wish to pass multiple parameters, use a | ||
struct or other collection to hold the parameters. For example, here is the | ||
parameterized Lorenz equation: | ||
|
||
```julia | ||
function lorenz(du,u,p,t) | ||
du[1] = p[1]*(u[2]-u[1]) | ||
du[2] = u[1]*(p[2]-u[3]) - u[2] | ||
du[3] = u[1]*u[2] - p[3]*u[3] | ||
end | ||
u0 = [1.0;0.0;0.0] | ||
p = [10.0,28.0,8/3] | ||
tspan = (0.0,100.0) | ||
prob = ODEProblem(lorenz,u0,tspan,p) | ||
``` | ||
|
||
Notice that `f` is defined with a single `p`, an array which matches the definition | ||
of the `p` in the `ODEProblem`. Note that `p` can be any Julia struct. | ||
""" | ||
|
||
struct TooManyArgumentsError <: Exception | ||
fname::String | ||
f::Any | ||
end | ||
|
||
function Base.showerror(io::IO, e::TooManyArgumentsError) | ||
println(io, TOO_MANY_ARGUMENTS_ERROR_MESSAGE) | ||
print(io, "Offending function: ") | ||
printstyled(io, e.fname; bold = true, color = :red) | ||
println(io, "\nMethods:") | ||
println(io, methods(e.f)) | ||
end | ||
|
||
const TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION = """ | ||
All methods for the model function `f` had too few arguments. For example, | ||
an OptimizationProblem `f` must define `f(u,p)` where `u` is the optimization | ||
state and `p` are the parameters of the optimization (commonly, the hyperparameters | ||
of the simulation). | ||
|
||
A common reason for this error is from defining a single-input loss function | ||
`f(u)`. While parameters are not required, a loss function which takes parameters | ||
is required, i.e. `f(u,p)`. If you have a function `f(u)`, ignored parameters | ||
can be easily added using a closure, i.e. `OptimizationProblem((u,_)->f(u),...)`. | ||
|
||
For example, here is a parameterized optimization problem: | ||
|
||
```julia | ||
using Optimization, OptimizationOptimJL | ||
rosenbrock(u,p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2 | ||
u0 = zeros(2) | ||
p = [1.0,100.0] | ||
|
||
prob = OptimizationProblem(rosenbrock,u0,p) | ||
sol = solve(prob,NelderMead()) | ||
``` | ||
|
||
and a parameter-less example: | ||
|
||
```julia | ||
using Optimization, OptimizationOptimJL | ||
rosenbrock(u,p) = (1 - u[1])^2 + (u[2] - u[1]^2)^2 | ||
u0 = zeros(2) | ||
|
||
prob = OptimizationProblem(rosenbrock,u0) | ||
sol = solve(prob,NelderMead()) | ||
``` | ||
""" | ||
|
||
const TOO_FEW_ARGUMENTS_ERROR_MESSAGE = """ | ||
All methods for the model function `f` had too few arguments. For example, | ||
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error | ||
can be thrown if you define an ODE model for example as `f(u,t)`. The parameters | ||
`p` are not optional in the definition of `f`! For more information on the required | ||
number of arguments for the function you were defining, consult the documentation | ||
for the `SciMLProblem` or `SciMLFunction` type that was being constructed. | ||
|
||
For example, here is the no parameter Lorenz equation. The two valid versions | ||
are out of place: | ||
|
||
```julia | ||
function lorenz(u,p,t) | ||
du1 = 10.0*(u[2]-u[1]) | ||
du2 = u[1]*(28.0-u[3]) - u[2] | ||
du3 = u[1]*u[2] - 8/3*u[3] | ||
[du1,du2,du3] | ||
end | ||
u0 = [1.0;0.0;0.0] | ||
tspan = (0.0,100.0) | ||
prob = ODEProblem(lorenz,u0,tspan) | ||
``` | ||
|
||
and in-place: | ||
|
||
```julia | ||
function lorenz!(du,u,p,t) | ||
du[1] = 10.0*(u[2]-u[1]) | ||
du[2] = u[1]*(28.0-u[3]) - u[2] | ||
du[3] = u[1]*u[2] - 8/3*u[3] | ||
end | ||
u0 = [1.0;0.0;0.0] | ||
tspan = (0.0,100.0) | ||
prob = ODEProblem(lorenz!,u0,tspan) | ||
``` | ||
""" | ||
|
||
struct TooFewArgumentsError <: Exception | ||
fname::String | ||
f::Any | ||
isoptimization::Bool | ||
end | ||
|
||
function Base.showerror(io::IO, e::TooFewArgumentsError) | ||
if e.isoptimization | ||
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION) | ||
else | ||
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE) | ||
end | ||
print(io, "Offending function: ") | ||
printstyled(io, e.fname; bold = true, color = :red) | ||
println(io, "\nMethods:") | ||
println(io, methods(e.f)) | ||
end | ||
|
||
const ARGUMENTS_ERROR_MESSAGE = """ | ||
Methods dispatches for the model function `f` do not match the required number. | ||
For example, an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. | ||
|
@@ -207,6 +56,12 @@ struct FunctionArgumentsError <: Exception | |
f::Any | ||
end | ||
|
||
# backward compat in case anyone is using these. | ||
# TODO: remove at next major version | ||
const TooManyArgumentsError = FunctionArgumentsError | ||
const TooFewArgumentsError = FunctionArgumentsError | ||
const NoMethodsError = FunctionArgumentsError | ||
|
||
function Base.showerror(io::IO, e::FunctionArgumentsError) | ||
println(io, ARGUMENTS_ERROR_MESSAGE) | ||
print(io, "Offending function: ") | ||
|
@@ -246,66 +101,14 @@ form is disabled and the 2-argument signature is ensured to be matched. | |
function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true; | ||
has_two_dispatches = true, isoptimization = false, | ||
outofplace_param_number = inplace_param_number - 1) | ||
nargs = numargs(f) | ||
iip_dispatch = any(x -> x == inplace_param_number, nargs) | ||
oop_dispatch = any(x -> x == outofplace_param_number, nargs) | ||
|
||
if length(nargs) == 0 | ||
throw(NoMethodsError(fname)) | ||
end | ||
|
||
if !iip_dispatch && !oop_dispatch && !isoptimization | ||
if all(>(inplace_param_number), nargs) | ||
throw(TooManyArgumentsError(fname, f)) | ||
elseif all(<(outofplace_param_number), nargs) && has_two_dispatches | ||
# Possible extra safety? | ||
# Find if there's a `f(args...)` dispatch | ||
# If so, no error | ||
_parameters = if methods(f).ms[1].sig isa UnionAll | ||
Base.unwrap_unionall(methods(f).ms[1].sig).parameters | ||
else | ||
methods(f).ms[1].sig.parameters | ||
end | ||
|
||
for i in 1:length(nargs) | ||
if nargs[i] < inplace_param_number && | ||
any(isequal(Vararg{Any}), _parameters) | ||
# If varargs, assume iip | ||
return iip_preferred | ||
end | ||
end | ||
|
||
# No varargs detected, error that there are dispatches but not the right ones | ||
|
||
throw(TooFewArgumentsError(fname, f, isoptimization)) | ||
else | ||
throw(FunctionArgumentsError(fname, f)) | ||
end | ||
elseif oop_dispatch && !iip_dispatch && !has_two_dispatches | ||
|
||
# Possible extra safety? | ||
# Find if there's a `f(args...)` dispatch | ||
# If so, no error | ||
for i in 1:length(nargs) | ||
if nargs[i] < inplace_param_number && | ||
any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters) | ||
# If varargs, assume iip | ||
return iip_preferred | ||
end | ||
end | ||
|
||
throw(TooFewArgumentsError(fname, f, isoptimization)) | ||
if iip_preferred | ||
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true | ||
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false | ||
else | ||
if iip_preferred | ||
# Equivalent to, if iip_dispatch exists, treat as iip | ||
# Otherwise, it's oop | ||
iip_dispatch | ||
else | ||
# Equivalent to, if oop_dispatch exists, treat as oop | ||
# Otherwise, it's iip | ||
!oop_dispatch | ||
end | ||
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false | ||
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true | ||
end | ||
throw(FunctionArgumentsError(fname, f)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This error message loses a lot of information. Can we in the error path do a method check and throw the more informative error message? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO the better method is to just put the expected |
||
end | ||
|
||
isinplace(f::AbstractSciMLFunction{iip}) where {iip} = iip | ||
|
Uh oh!
There was an error while loading. Please reload this page.