Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 77a337a

Browse files
committed
Allow a kwarg to force use the correct ordering
1 parent 8254f9e commit 77a337a

File tree

3 files changed

+42
-25
lines changed

3 files changed

+42
-25
lines changed

src/differentiation/common.jl

+37-21
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,37 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop
3838
(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
3939
(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u)
4040

41-
function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
41+
# NOTE: `use_deprecated_ordering` is a way for external libraries to update to the correct
42+
# style. In the next release, we will drop the first check
43+
function JacFunctionWrapper(f::F, fu_, u, p, t;
44+
use_deprecated_ordering::Val{deporder} = Val(true)) where {F, deporder}
4245
# The warning instead of error ensures a non-breaking change for users relying on an
4346
# undefined / undocumented feature
4447
fu = fu_ === nothing ? copy(u) : copy(fu_)
4548

46-
# Check this first else we were breaking things
47-
# In the next breaking release, we will fix the ordering of the checks
48-
iip = static_hasmethod(f, typeof((fu, u)))
49-
oop = static_hasmethod(f, typeof((u,)))
50-
if iip || oop
51-
if p !== nothing || t !== nothing
52-
Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we
53-
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
49+
if deporder
50+
# Check this first else we were breaking things
51+
# In the next breaking release, we will fix the ordering of the checks
52+
iip = static_hasmethod(f, typeof((fu, u)))
53+
oop = static_hasmethod(f, typeof((u,)))
54+
if iip || oop
55+
if p !== nothing || t !== nothing
56+
Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we
57+
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
5458
55-
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not be
56-
supplied.
57-
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
59+
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not
60+
be supplied.
61+
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
5862
59-
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in future
60-
breaking releases, this check will be performed last, which means that if `t`
61-
is provided `f(du, u, p, t)`/`f(u, p, t)` will be given precedence, similarly
62-
if `p` is provided `f(du, u, p)`/`f(u, p)` will be given precedence.""",
63-
:JacFunctionWrapper)
63+
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in
64+
future breaking releases, this check will be performed last, which means
65+
that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given
66+
precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be
67+
given precedence.""", :JacFunctionWrapper)
68+
end
69+
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
70+
fu, p, t)
6471
end
65-
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
66-
fu, p, t)
6772
end
6873

6974
if t !== nothing
@@ -86,6 +91,17 @@ function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
8691
fu, p, t)
8792
end
8893

89-
throw(ArgumentError("""Couldn't determine the function signature of `f` to construct a
90-
JacobianWrapper!"""))
94+
if !deporder
95+
iip = static_hasmethod(f, typeof((fu, u)))
96+
oop = static_hasmethod(f, typeof((u,)))
97+
if !iip && !oop
98+
throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for
99+
`f`!"""))
100+
end
101+
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
102+
fu, p, t)
103+
else
104+
throw(ArgumentError("""Couldn't determine the function signature of `f` to
105+
construct a JacobianWrapper!"""))
106+
end
91107
end

src/differentiation/jaches_products.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ f(du, u) # Otherwise
263263
```
264264
"""
265265
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
266-
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
267-
ff = JacFunctionWrapper(f, fu, u, p, t)
266+
autodiff = AutoForwardDiff(), tag = DeivVecTag(),
267+
use_deprecated_ordering::Val = Val(true), kwargs...)
268+
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)
268269
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)
269270

270271
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff

src/differentiation/vecjac_products.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ f(du, u) # Otherwise
7272
```
7373
"""
7474
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
75-
autodiff = AutoFiniteDiff(), kwargs...)
76-
ff = JacFunctionWrapper(f, fu, u, p, t)
75+
autodiff = AutoFiniteDiff(), use_deprecated_ordering::Val = Val(true), kwargs...)
76+
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)
7777

7878
if !__internal_oop(ff) && autodiff isa AutoZygote
7979
msg = "Zygote requires an out of place method with signature f(u)."

0 commit comments

Comments
 (0)