Skip to content

Commit 2bd236f

Browse files
authored
[Nonlinear] Merge forward_storage_ϵ with reverse_storage_ϵ (#2732)
1 parent e28ecf9 commit 2bd236f

File tree

3 files changed: 20 additions and 27 deletions.

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
126126
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
127127
d,
128128
subexpr,
129-
_reinterpret_unsafe(T, subexpr.forward_storage_ϵ),
129+
_reinterpret_unsafe(T, d.storage_ϵ),
130130
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
131131
input_ϵ,
132132
subexpr_forward_values_ϵ,
@@ -136,7 +136,7 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
136136
_forward_eval_ϵ(
137137
d,
138138
ex,
139-
_reinterpret_unsafe(T, d.forward_storage_ϵ),
139+
_reinterpret_unsafe(T, d.storage_ϵ),
140140
_reinterpret_unsafe(T, d.partials_storage_ϵ),
141141
input_ϵ,
142142
subexpr_forward_values_ϵ,
@@ -152,7 +152,7 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
152152
_reverse_eval_ϵ(
153153
output_ϵ,
154154
ex,
155-
_reinterpret_unsafe(T, d.reverse_storage_ϵ),
155+
_reinterpret_unsafe(T, d.storage_ϵ),
156156
_reinterpret_unsafe(T, d.partials_storage_ϵ),
157157
d.subexpression_reverse_values,
158158
subexpr_reverse_values_ϵ,
@@ -165,7 +165,7 @@ function _hessian_slice_inner(d, ex, ::Val{CHUNK}) where {CHUNK}
165165
_reverse_eval_ϵ(
166166
output_ϵ,
167167
subexpr,
168-
_reinterpret_unsafe(T, subexpr.reverse_storage_ϵ),
168+
_reinterpret_unsafe(T, d.storage_ϵ),
169169
_reinterpret_unsafe(T, subexpr.partials_storage_ϵ),
170170
d.subexpression_reverse_values,
171171
subexpr_reverse_values_ϵ,

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
4444
want_hess_storage = (:HessVec in requested_features) || d.want_hess
4545
coloring_storage = Coloring.IndexedSet(N)
4646
max_expr_length = 0
47+
max_expr_with_sub_length = 0
4748
#
4849
main_expressions = [c.expression.nodes for (_, c) in d.data.constraints]
4950
if d.data.objective !== nothing
@@ -71,6 +72,8 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
7172
)
7273
d.subexpressions[k] = subex
7374
d.subexpression_linearity[k] = subex.linearity
75+
max_expr_with_sub_length =
76+
max(max_expr_with_sub_length, length(subex.nodes))
7477
if d.want_hess
7578
empty!(coloring_storage)
7679
_compute_gradient_sparsity!(coloring_storage, subex.nodes)
@@ -138,27 +141,22 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
138141
end
139142
# 10 is hardcoded upper bound to avoid excess memory allocation
140143
max_chunk = min(max_chunk, 10)
144+
max_expr_with_sub_length = max(max_expr_with_sub_length, max_expr_length)
141145
if d.want_hess || want_hess_storage
142146
d.input_ϵ = zeros(max_chunk * N)
143147
d.output_ϵ = zeros(max_chunk * N)
144148
#
145-
len = max_chunk * max_expr_length
146-
d.forward_storage_ϵ = zeros(len)
147-
d.partials_storage_ϵ = zeros(len)
148-
d.reverse_storage_ϵ = zeros(len)
149+
d.partials_storage_ϵ = zeros(max_chunk * max_expr_length)
150+
d.storage_ϵ = zeros(max_chunk * max_expr_with_sub_length)
149151
#
150152
len = max_chunk * length(d.subexpressions)
151153
d.subexpression_forward_values_ϵ = zeros(len)
152154
d.subexpression_reverse_values_ϵ = zeros(len)
153155
#
154156
for k in d.subexpression_order
155157
len = max_chunk * length(d.subexpressions[k].nodes)
156-
resize!(d.subexpressions[k].forward_storage_ϵ, len)
157-
fill!(d.subexpressions[k].forward_storage_ϵ, 0.0)
158158
resize!(d.subexpressions[k].partials_storage_ϵ, len)
159159
fill!(d.subexpressions[k].partials_storage_ϵ, 0.0)
160-
resize!(d.subexpressions[k].reverse_storage_ϵ, len)
161-
fill!(d.subexpressions[k].reverse_storage_ϵ, 0.0)
162160
end
163161
d.max_chunk = max_chunk
164162
if d.want_hess
@@ -350,7 +348,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
350348
subexpr_forward_values_ϵ[i] = _forward_eval_ϵ(
351349
d,
352350
subexpr,
353-
reinterpret(T, subexpr.forward_storage_ϵ),
351+
reinterpret(T, d.storage_ϵ),
354352
reinterpret(T, subexpr.partials_storage_ϵ),
355353
input_ϵ,
356354
subexpr_forward_values_ϵ,
@@ -361,13 +359,13 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
361359
subexpr_reverse_values_ϵ = reinterpret(T, d.subexpression_reverse_values_ϵ)
362360
fill!(subexpr_reverse_values_ϵ, zero(T))
363361
fill!(d.subexpression_reverse_values, 0.0)
364-
fill!(d.reverse_storage_ϵ, 0.0)
362+
fill!(d.storage_ϵ, 0.0)
365363
fill!(output_ϵ, zero(T))
366364
if d.objective !== nothing
367365
_forward_eval_ϵ(
368366
d,
369367
something(d.objective),
370-
reinterpret(T, d.forward_storage_ϵ),
368+
reinterpret(T, d.storage_ϵ),
371369
reinterpret(T, d.partials_storage_ϵ),
372370
input_ϵ,
373371
subexpr_forward_values_ϵ,
@@ -376,7 +374,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
376374
_reverse_eval_ϵ(
377375
output_ϵ,
378376
something(d.objective),
379-
reinterpret(T, d.reverse_storage_ϵ),
377+
reinterpret(T, d.storage_ϵ),
380378
reinterpret(T, d.partials_storage_ϵ),
381379
d.subexpression_reverse_values,
382380
subexpr_reverse_values_ϵ,
@@ -388,7 +386,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
388386
_forward_eval_ϵ(
389387
d,
390388
con,
391-
reinterpret(T, d.forward_storage_ϵ),
389+
reinterpret(T, d.storage_ϵ),
392390
reinterpret(T, d.partials_storage_ϵ),
393391
input_ϵ,
394392
subexpr_forward_values_ϵ,
@@ -397,7 +395,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
397395
_reverse_eval_ϵ(
398396
output_ϵ,
399397
con,
400-
reinterpret(T, d.reverse_storage_ϵ),
398+
reinterpret(T, d.storage_ϵ),
401399
reinterpret(T, d.partials_storage_ϵ),
402400
d.subexpression_reverse_values,
403401
subexpr_reverse_values_ϵ,
@@ -411,7 +409,7 @@ function MOI.eval_hessian_lagrangian_product(d::NLPEvaluator, h, x, v, σ, μ)
411409
_reverse_eval_ϵ(
412410
output_ϵ,
413411
subexpr,
414-
reinterpret(T, subexpr.reverse_storage_ϵ),
412+
reinterpret(T, d.storage_ϵ),
415413
reinterpret(T, subexpr.partials_storage_ϵ),
416414
d.subexpression_reverse_values,
417415
subexpr_reverse_values_ϵ,

src/Nonlinear/ReverseAD/types.jl

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,7 @@ struct _SubexpressionStorage
1111
forward_storage::Vector{Float64}
1212
partials_storage::Vector{Float64}
1313
reverse_storage::Vector{Float64}
14-
forward_storage_ϵ::Vector{Float64}
1514
partials_storage_ϵ::Vector{Float64}
16-
reverse_storage_ϵ::Vector{Float64}
1715
linearity::Linearity
1816

1917
function _SubexpressionStorage(
@@ -34,8 +32,6 @@ struct _SubexpressionStorage
3432
zeros(N), # partials_storage,
3533
zeros(N), # reverse_storage,
3634
Float64[],
37-
Float64[],
38-
Float64[],
3935
linearity[1],
4036
)
4137
end
@@ -175,11 +171,10 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
175171
# so the length should be multiplied by the maximum number of epsilon components
176172
disable_2ndorder::Bool # don't offer Hess or HessVec
177173
want_hess::Bool
178-
forward_storage_ϵ::Vector{Float64} # (longest expression)
179-
partials_storage_ϵ::Vector{Float64} # (longest expression)
180-
reverse_storage_ϵ::Vector{Float64} # (longest expression)
174+
partials_storage_ϵ::Vector{Float64} # (longest expression excluding subexpressions)
175+
storage_ϵ::Vector{Float64} # (longest expression including subexpressions)
181176
input_ϵ::Vector{Float64} # (number of variables)
182-
output_ϵ::Vector{Float64}# (number of variables)
177+
output_ϵ::Vector{Float64} # (number of variables)
183178
subexpression_forward_values_ϵ::Vector{Float64} # (number of subexpressions)
184179
subexpression_reverse_values_ϵ::Vector{Float64} # (number of subexpressions)
185180
hessian_sparsity::Vector{Tuple{Int64,Int64}}

0 commit comments

Comments
 (0)