Skip to content

Commit 0cf9fc8

Browse files
committed
solve iterations
1 parent ed757cb commit 0cf9fc8

File tree

3 files changed

+24
-58
lines changed

3 files changed

+24
-58
lines changed

examples/test.jl

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/data.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ struct SolverData{T}
193193
α::Vector{T} # step length
194194
status::Vector{Bool} # solver status
195195

196-
cache::Dict{Symbol,Vector{T}} # solver stats
196+
iter::Vector{Int}
197+
198+
cache::Dict{Symbol,Vector{T}} # solver stats
197199
end
198200

199201
function solver_data(model::Model{T}; max_cache=1000) where T
@@ -216,11 +218,23 @@ function solver_data(model::Model{T}; max_cache=1000) where T
216218
α = [1.0]
217219
gradient = zeros(num_var(model))
218220
cache = Dict(:obj => zeros(max_cache),
219-
:gradient => zeros(max_cache),
221+
:grad => zeros(max_cache),
220222
:c_max => zeros(max_cache),
221223
=> zeros(max_cache))
222224

223-
SolverData(obj, gradient, c_max, idx_x, idx_u, α, [false], cache)
225+
SolverData(obj, gradient, c_max, idx_x, idx_u, α, [false], [0], cache)
226+
end
227+
228+
function reset!(data::SolverData)
229+
fill!(data.obj, 0.0)
230+
fill!(data.gradient, 0.0)
231+
fill!(data.c_max, 0.0)
232+
fill!(data.cache[:obj], 0.0)
233+
fill!(data.cache[:grad], 0.0)
234+
fill!(data.cache[:c_max], 0.0)
235+
fill!(data.cache[], 0.0)
236+
data.status[1] = false
237+
data.iter[1] = 0
224238
end
225239

226240
# TODO: fix iter

src/solve.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ function ilqr_solve!(prob::ProblemData;
44
grad_tol=1.0e-3,
55
α_min=1.0e-5,
66
linesearch=:armijo,
7+
reset_cache=true,
78
verbose=false)
89

910
# printstyled("Iterative LQR\n",
@@ -15,6 +16,7 @@ function ilqr_solve!(prob::ProblemData;
1516
reset!(m_data.model_deriv)
1617
reset!(m_data.obj_deriv)
1718
s_data = prob.s_data
19+
reset_cache && reset!(s_data)
1820

1921
objective!(s_data, m_data, mode=:nominal)
2022
derivatives!(m_data, mode=:nominal)
@@ -35,6 +37,7 @@ function ilqr_solve!(prob::ProblemData;
3537
grad_norm = norm(s_data.gradient, Inf)
3638

3739
# info
40+
s_data.iter[1] += 1
3841
verbose && println(" iter: $i
3942
cost: $(s_data.obj[1])
4043
grad_norm: $(grad_norm)
@@ -99,6 +102,9 @@ function constrained_ilqr_solve!(prob::ProblemData;
99102
# verbose && printstyled("Iterative LQR\n",
100103
# color=:red, bold=true)
101104

105+
# reset solver cache
106+
reset!(prob.s_data)
107+
102108
# reset duals
103109
for (t, λ) in enumerate(prob.m_data.obj.λ)
104110
fill!(λ, 0.0)
@@ -119,6 +125,7 @@ function constrained_ilqr_solve!(prob::ProblemData;
119125
max_iter=max_iter,
120126
obj_tol=obj_tol,
121127
grad_tol=grad_tol,
128+
reset_cache=false,
122129
verbose=verbose)
123130

124131
# update trajectories
@@ -145,16 +152,9 @@ function solve!(prob::ProblemData{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwar
145152
iterative_lqr!(prob, args...; kwargs...)
146153
end
147154

148-
# function solve!(prob::ProblemData{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, x, u; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:Objective{T}}
149-
# iterative_lqr!(prob, x, u; kwargs...)
150-
# end
151-
152155
function solve!(prob::ProblemData{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:AugmentedLagrangianCosts{T}}
153156
constrained_ilqr_solve!(prob, args...; kwargs...)
154157
end
155158

156-
# function solve!(prob::ProblemData{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, x, u; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:AugmentedLagrangianCosts{T}}
157-
# constrained_iterative_lqr!(prob, x, u; kwargs...)
158-
# end
159159

160160

0 commit comments

Comments
 (0)