Skip to content

Commit 987ca87

Browse files
authored
Clean up ReverseDiff type annotations (#498)
* Remove unneeded evaluation for ReverseDiff * Undo fix
1 parent 93c0659 commit 987ca87

File tree

2 files changed

+19
-57
lines changed

2 files changed

+19
-57
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

+18-54
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ struct ReverseDiffGradientPrep{T} <: GradientPrep
4747
tape::T
4848
end
4949

50-
function DI.prepare_gradient(
51-
f, ::AutoReverseDiff{Compile}, x::AbstractArray
52-
) where {Compile}
50+
function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile}
5351
tape = GradientTape(f, x)
5452
if Compile
5553
tape = compile(tape)
@@ -58,11 +56,7 @@ function DI.prepare_gradient(
5856
end
5957

6058
function DI.value_and_gradient!(
61-
f,
62-
grad::AbstractArray,
63-
prep::ReverseDiffGradientPrep,
64-
::AutoReverseDiff,
65-
x::AbstractArray,
59+
f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
6660
)
6761
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
6862
result = MutableDiffResult(y, (grad,))
@@ -71,23 +65,19 @@ function DI.value_and_gradient!(
7165
end
7266

7367
function DI.value_and_gradient(
74-
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x::AbstractArray
68+
f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x
7569
)
7670
grad = similar(x)
7771
return DI.value_and_gradient!(f, grad, prep, backend, x)
7872
end
7973

8074
function DI.gradient!(
81-
_f,
82-
grad::AbstractArray,
83-
prep::ReverseDiffGradientPrep,
84-
::AutoReverseDiff,
85-
x::AbstractArray,
75+
_f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray
8676
)
8777
return gradient!(grad, prep.tape, x)
8878
end
8979

90-
function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray)
80+
function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x)
9181
return gradient!(prep.tape, x)
9282
end
9383

@@ -97,9 +87,7 @@ struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
9787
tape::T
9888
end
9989

100-
function DI.prepare_jacobian(
101-
f, ::AutoReverseDiff{Compile}, x::AbstractArray
102-
) where {Compile}
90+
function DI.prepare_jacobian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
10391
tape = JacobianTape(f, x)
10492
if Compile
10593
tape = compile(tape)
@@ -108,37 +96,23 @@ function DI.prepare_jacobian(
10896
end
10997

11098
function DI.value_and_jacobian!(
111-
f,
112-
jac::AbstractMatrix,
113-
prep::ReverseDiffOneArgJacobianPrep,
114-
::AutoReverseDiff,
115-
x::AbstractArray,
99+
f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x
116100
)
117101
y = f(x)
118102
result = MutableDiffResult(y, (jac,))
119103
result = jacobian!(result, prep.tape, x)
120104
return DiffResults.value(result), DiffResults.derivative(result)
121105
end
122106

123-
function DI.value_and_jacobian(
124-
f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray
125-
)
107+
function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
126108
return f(x), jacobian!(prep.tape, x)
127109
end
128110

129-
function DI.jacobian!(
130-
_f,
131-
jac::AbstractMatrix,
132-
prep::ReverseDiffOneArgJacobianPrep,
133-
::AutoReverseDiff,
134-
x::AbstractArray,
135-
)
111+
function DI.jacobian!(_f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
136112
return jacobian!(jac, prep.tape, x)
137113
end
138114

139-
function DI.jacobian(
140-
f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray
141-
)
115+
function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
142116
return jacobian!(prep.tape, x)
143117
end
144118

@@ -148,35 +122,24 @@ struct ReverseDiffHessianPrep{T} <: HessianPrep
148122
tape::T
149123
end
150124

151-
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile}
125+
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
152126
tape = HessianTape(f, x)
153127
if Compile
154128
tape = compile(tape)
155129
end
156130
return ReverseDiffHessianPrep(tape)
157131
end
158132

159-
function DI.hessian!(
160-
_f,
161-
hess::AbstractMatrix,
162-
prep::ReverseDiffHessianPrep,
163-
::AutoReverseDiff,
164-
x::AbstractArray,
165-
)
133+
function DI.hessian!(_f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
166134
return hessian!(hess, prep.tape, x)
167135
end
168136

169-
function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray)
137+
function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
170138
return hessian!(prep.tape, x)
171139
end
172140

173141
function DI.value_gradient_and_hessian!(
174-
f,
175-
grad,
176-
hess::AbstractMatrix,
177-
prep::ReverseDiffHessianPrep,
178-
::AutoReverseDiff,
179-
x::AbstractArray,
142+
f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
180143
)
181144
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
182145
result = MutableDiffResult(y, (grad, hess))
@@ -187,10 +150,11 @@ function DI.value_gradient_and_hessian!(
187150
end
188151

189152
function DI.value_gradient_and_hessian(
190-
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray
153+
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
191154
)
192-
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
193-
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
155+
result = MutableDiffResult(
156+
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
157+
)
194158
result = hessian!(result, prep.tape, x)
195159
return (
196160
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ struct ReverseDiffTwoArgJacobianPrep{T} <: JacobianPrep
8181
tape::T
8282
end
8383

84-
function DI.prepare_jacobian(
85-
f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray
86-
) where {Compile}
84+
function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{Compile}, x) where {Compile}
8785
tape = JacobianTape(f!, y, x)
8886
if Compile
8987
tape = compile(tape)

0 commit comments

Comments
 (0)