@@ -47,9 +47,7 @@ struct ReverseDiffGradientPrep{T} <: GradientPrep
47
47
tape:: T
48
48
end
49
49
50
- function DI. prepare_gradient (
51
- f, :: AutoReverseDiff{Compile} , x:: AbstractArray
52
- ) where {Compile}
50
+ function DI. prepare_gradient (f, :: AutoReverseDiff{Compile} , x) where {Compile}
53
51
tape = GradientTape (f, x)
54
52
if Compile
55
53
tape = compile (tape)
@@ -58,11 +56,7 @@ function DI.prepare_gradient(
58
56
end
59
57
60
58
function DI. value_and_gradient! (
61
- f,
62
- grad:: AbstractArray ,
63
- prep:: ReverseDiffGradientPrep ,
64
- :: AutoReverseDiff ,
65
- x:: AbstractArray ,
59
+ f, grad:: AbstractArray , prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x
66
60
)
67
61
y = f (x) # TODO : remove once ReverseDiff#251 is fixed
68
62
result = MutableDiffResult (y, (grad,))
@@ -71,23 +65,19 @@ function DI.value_and_gradient!(
71
65
end
72
66
73
67
function DI. value_and_gradient (
74
- f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff , x:: AbstractArray
68
+ f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff , x
75
69
)
76
70
grad = similar (x)
77
71
return DI. value_and_gradient! (f, grad, prep, backend, x)
78
72
end
79
73
80
74
function DI. gradient! (
81
- _f,
82
- grad:: AbstractArray ,
83
- prep:: ReverseDiffGradientPrep ,
84
- :: AutoReverseDiff ,
85
- x:: AbstractArray ,
75
+ _f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x:: AbstractArray
86
76
)
87
77
return gradient! (grad, prep. tape, x)
88
78
end
89
79
90
- function DI. gradient (_f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x:: AbstractArray )
80
+ function DI. gradient (_f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x)
91
81
return gradient! (prep. tape, x)
92
82
end
93
83
@@ -97,9 +87,7 @@ struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
97
87
tape:: T
98
88
end
99
89
100
- function DI. prepare_jacobian (
101
- f, :: AutoReverseDiff{Compile} , x:: AbstractArray
102
- ) where {Compile}
90
+ function DI. prepare_jacobian (f, :: AutoReverseDiff{Compile} , x) where {Compile}
103
91
tape = JacobianTape (f, x)
104
92
if Compile
105
93
tape = compile (tape)
@@ -108,37 +96,23 @@ function DI.prepare_jacobian(
108
96
end
109
97
110
98
function DI. value_and_jacobian! (
111
- f,
112
- jac:: AbstractMatrix ,
113
- prep:: ReverseDiffOneArgJacobianPrep ,
114
- :: AutoReverseDiff ,
115
- x:: AbstractArray ,
99
+ f, jac, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x
116
100
)
117
101
y = f (x)
118
102
result = MutableDiffResult (y, (jac,))
119
103
result = jacobian! (result, prep. tape, x)
120
104
return DiffResults. value (result), DiffResults. derivative (result)
121
105
end
122
106
123
- function DI. value_and_jacobian (
124
- f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x:: AbstractArray
125
- )
107
+ function DI. value_and_jacobian (f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
126
108
return f (x), jacobian! (prep. tape, x)
127
109
end
128
110
129
- function DI. jacobian! (
130
- _f,
131
- jac:: AbstractMatrix ,
132
- prep:: ReverseDiffOneArgJacobianPrep ,
133
- :: AutoReverseDiff ,
134
- x:: AbstractArray ,
135
- )
111
+ function DI. jacobian! (_f, jac, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
136
112
return jacobian! (jac, prep. tape, x)
137
113
end
138
114
139
- function DI. jacobian (
140
- f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x:: AbstractArray
141
- )
115
+ function DI. jacobian (f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
142
116
return jacobian! (prep. tape, x)
143
117
end
144
118
@@ -148,35 +122,24 @@ struct ReverseDiffHessianPrep{T} <: HessianPrep
148
122
tape:: T
149
123
end
150
124
151
- function DI. prepare_hessian (f, :: AutoReverseDiff{Compile} , x:: AbstractArray ) where {Compile}
125
+ function DI. prepare_hessian (f, :: AutoReverseDiff{Compile} , x) where {Compile}
152
126
tape = HessianTape (f, x)
153
127
if Compile
154
128
tape = compile (tape)
155
129
end
156
130
return ReverseDiffHessianPrep (tape)
157
131
end
158
132
159
- function DI. hessian! (
160
- _f,
161
- hess:: AbstractMatrix ,
162
- prep:: ReverseDiffHessianPrep ,
163
- :: AutoReverseDiff ,
164
- x:: AbstractArray ,
165
- )
133
+ function DI. hessian! (_f, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
166
134
return hessian! (hess, prep. tape, x)
167
135
end
168
136
169
- function DI. hessian (_f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x:: AbstractArray )
137
+ function DI. hessian (_f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
170
138
return hessian! (prep. tape, x)
171
139
end
172
140
173
141
function DI. value_gradient_and_hessian! (
174
- f,
175
- grad,
176
- hess:: AbstractMatrix ,
177
- prep:: ReverseDiffHessianPrep ,
178
- :: AutoReverseDiff ,
179
- x:: AbstractArray ,
142
+ f, grad, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
180
143
)
181
144
y = f (x) # TODO : remove once ReverseDiff#251 is fixed
182
145
result = MutableDiffResult (y, (grad, hess))
@@ -187,10 +150,11 @@ function DI.value_gradient_and_hessian!(
187
150
end
188
151
189
152
function DI. value_gradient_and_hessian (
190
- f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x:: AbstractArray
153
+ f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
191
154
)
192
- y = f (x) # TODO : remove once ReverseDiff#251 is fixed
193
- result = MutableDiffResult (y, (similar (x), similar (x, length (x), length (x))))
155
+ result = MutableDiffResult (
156
+ one (eltype (x)), (similar (x), similar (x, length (x), length (x)))
157
+ )
194
158
result = hessian! (result, prep. tape, x)
195
159
return (
196
160
DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
0 commit comments