Skip to content

Commit 1cd085f

Browse files
authored
fix performance bug in Hessian computation (#5)
1 parent 94facf3 commit 1cd085f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/forward_over_reverse.jl

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,16 +92,19 @@ function _eval_hessian_chunk(
9292
for s in 1:chunk
9393
# If `chunk < chunk_size`, leaves junk in the unused components
9494
d.input_ϵ[(idx-1)*chunk_size+s] = ex.seed_matrix[r, offset+s-1]
95+
# Ensure the output is clear in preparation for the chunk
96+
d.output_ϵ[(idx-1)*chunk_size+s] = 0.0
9597
end
9698
end
9799
_hessian_slice_inner(d, ex, chunk_size)
98-
fill!(d.input_ϵ, 0.0)
99100
# collect directional derivatives
100101
for r in eachindex(ex.rinfo.local_indices)
101102
@inbounds idx = ex.rinfo.local_indices[r]
102103
# load output_ϵ into ex.seed_matrix[r,k,k+1,...,k+remaining-1]
103104
for s in 1:chunk
104105
ex.seed_matrix[r, offset+s-1] = d.output_ϵ[(idx-1)*chunk_size+s]
106+
# Reset the input in preparation for the next chunk
107+
d.input_ϵ[(idx-1)*chunk_size+s] = 0.0
105108
end
106109
end
107110
return
@@ -122,7 +125,6 @@ end
122125
end
123126

124127
function _hessian_slice_inner(d, ex, ::Type{T}) where {T}
125-
fill!(d.output_ϵ, 0.0)
126128
output_ϵ = _reinterpret_unsafe(T, d.output_ϵ)
127129
subexpr_forward_values_ϵ =
128130
_reinterpret_unsafe(T, d.subexpression_forward_values_ϵ)

0 commit comments

Comments
 (0)