Skip to content

Commit f5222a1

Browse files
committed
options in solver run_ci
1 parent d59db75 commit f5222a1

File tree

6 files changed

+47
-72
lines changed

6 files changed

+47
-72
lines changed

src/solve.jl

Lines changed: 37 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,27 @@
1-
function ilqr_solve!(prob::Solver;
2-
max_iter=10,
3-
obj_tol=1.0e-3,
4-
grad_tol=1.0e-3,
5-
α_min=1.0e-5,
6-
linesearch=:armijo,
7-
reset_cache=true,
8-
verbose=false)
1+
function ilqr_solve!(solver::Solver)
92

103
# printstyled("Iterative LQR\n",
114
# color=:red, bold=true)
125

136
# data
14-
p_data = prob.p_data
15-
m_data = prob.m_data
7+
p_data = solver.p_data
8+
m_data = solver.m_data
169
reset!(m_data.model_deriv)
1710
reset!(m_data.obj_deriv)
18-
s_data = prob.s_data
19-
reset_cache && reset!(s_data)
11+
s_data = solver.s_data
12+
solver.options.reset_cache && reset!(s_data)
2013

2114
objective!(s_data, m_data, mode=:nominal)
2215
derivatives!(m_data, mode=:nominal)
2316
backward_pass!(p_data, m_data, mode=:nominal)
2417

2518
obj_prev = s_data.obj[1]
26-
for i = 1:max_iter
19+
for i = 1:solver.options.max_iter
2720
forward_pass!(p_data, m_data, s_data,
28-
α_min=α_min,
29-
linesearch=linesearch,
30-
verbose=verbose)
31-
if linesearch != :none
21+
α_min=solver.options.α_min,
22+
linesearch=solver.options.linesearch,
23+
verbose=solver.options.verbose)
24+
if solver.options.linesearch != :none
3225
derivatives!(m_data, mode=:nominal)
3326
backward_pass!(p_data, m_data, mode=:nominal)
3427
lagrangian_gradient!(s_data, p_data, m_data)
@@ -39,25 +32,25 @@ function ilqr_solve!(prob::Solver;
3932

4033
# info
4134
s_data.iter[1] += 1
42-
verbose && println(" iter: $i
35+
solver.options.verbose && println(" iter: $i
4336
cost: $(s_data.obj[1])
4437
grad_norm: $(grad_norm)
4538
c_max: $(s_data.c_max[1])
4639
α: $(s_data.α[1])")
4740

4841
# check convergence
49-
grad_norm < grad_tol && break
50-
abs(s_data.obj[1] - obj_prev) < obj_tol ? break : (obj_prev = s_data.obj[1])
42+
grad_norm < solver.options.grad_tol && break
43+
abs(s_data.obj[1] - obj_prev) < solver.options.obj_tol ? break : (obj_prev = s_data.obj[1])
5144
!s_data.status[1] && break
5245
end
5346

5447
return nothing
5548
end
5649

57-
function ilqr_solve!(prob::Solver, x, u; kwargs...)
58-
initialize_controls!(prob, u)
59-
initialize_states!(prob, x)
60-
ilqr_solve!(prob; kwargs...)
50+
function ilqr_solve!(solver::Solver, x, u; kwargs...)
51+
initialize_controls!(solver, u)
52+
initialize_states!(solver, x)
53+
ilqr_solve!(solver; kwargs...)
6154
end
6255

6356

@@ -86,75 +79,56 @@ end
8679
"""
8780
augmented Lagrangian solve
8881
"""
89-
function constrained_ilqr_solve!(prob::Solver;
90-
linesearch=:armijo,
91-
max_iter=10,
92-
max_al_iter=10,
93-
α_min=1.0e-5,
94-
obj_tol=1.0e-3,
95-
grad_tol=1.0e-3,
96-
con_tol=1.0e-3,
97-
con_norm_type=Inf,
98-
ρ_init=1.0,
99-
ρ_scale=10.0,
100-
ρ_max=1.0e8,
101-
verbose=false)
82+
function constrained_ilqr_solve!(solver::Solver)
10283

10384
# verbose && printstyled("Iterative LQR\n",
10485
# color=:red, bold=true)
10586

10687
# reset solver cache
107-
reset!(prob.s_data)
88+
reset!(solver.s_data)
10889

10990
# reset duals
110-
for (t, λ) in enumerate(prob.m_data.obj.λ)
91+
for (t, λ) in enumerate(solver.m_data.obj.λ)
11192
fill!(λ, 0.0)
11293
end
11394

11495
# initialize penalty
115-
for (t, ρ) in enumerate(prob.m_data.obj.ρ)
116-
fill!(ρ, ρ_init)
96+
for (t, ρ) in enumerate(solver.m_data.obj.ρ)
97+
fill!(ρ, solver.options.ρ_init)
11798
end
11899

119-
for i = 1:max_al_iter
120-
verbose && println(" al iter: $i")
100+
for i = 1:solver.options.max_al_iter
101+
solver.options.verbose && println(" al iter: $i")
121102

122103
# primal minimization
123-
ilqr_solve!(prob,
124-
linesearch=linesearch,
125-
α_min=α_min,
126-
max_iter=max_iter,
127-
obj_tol=obj_tol,
128-
grad_tol=grad_tol,
129-
reset_cache=false,
130-
verbose=verbose)
104+
ilqr_solve!(solver)
131105

132106
# update trajectories
133-
objective!(prob.s_data, prob.m_data, mode=:nominal)
107+
objective!(solver.s_data, solver.m_data, mode=:nominal)
134108

135109
# constraint violation
136-
prob.s_data.c_max[1] <= con_tol && break
110+
solver.s_data.c_max[1] <= solver.options.con_tol && break
137111

138112
# dual ascent
139-
augmented_lagrangian_update!(prob.m_data.obj,
140-
s=ρ_scale, max_penalty=ρ_max)
113+
augmented_lagrangian_update!(solver.m_data.obj,
114+
s=solver.options.ρ_scale, max_penalty=solver.options.ρ_max)
141115
end
142116

143117
return nothing
144118
end
145119

146-
function constrained_ilqr_solve!(prob::Solver, x, u; kwargs...)
147-
initialize_controls!(prob, u)
148-
initialize_states!(prob, x)
149-
constrained_ilqr_solve!(prob; kwargs...)
120+
function constrained_ilqr_solve!(solver::Solver, x, u; kwargs...)
121+
initialize_controls!(solver, u)
122+
initialize_states!(solver, x)
123+
constrained_ilqr_solve!(solver; kwargs...)
150124
end
151125

152-
function solve!(prob::Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:Objective{T}}
153-
iterative_lqr!(prob, args...; kwargs...)
126+
function solve!(solver::Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:Objective{T}}
127+
iterative_lqr!(solver, args...; kwargs...)
154128
end
155129

156-
function solve!(prob::Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:AugmentedLagrangianCosts{T}}
157-
constrained_ilqr_solve!(prob, args...; kwargs...)
130+
function solve!(solver::Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:AugmentedLagrangianCosts{T}}
131+
constrained_ilqr_solve!(solver, args...; kwargs...)
158132
end
159133

160134

src/solver.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ρ_init::T=1.0
1111
ρ_scale::T=10.0
1212
ρ_max::T=1.0e8
13+
reset_cache::Bool=false
1314
verbose=true
1415
end
1516

test/acrobot.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
end
7272

7373
# ## model
74-
dyn = Dynamics(midpoint_explicit, nx, nu, nw)
74+
dyn = Dynamics(midpoint_explicit, nx, nu, nw=nw)
7575
model = [dyn for t = 1:T-1]
7676

7777
# ## initialization
@@ -84,8 +84,8 @@
8484
# ## objective
8585
ot = (x, u, w) -> 0.1 * dot(x[3:4], x[3:4]) + 0.1 * dot(u, u)
8686
oT = (x, u, w) -> 0.1 * dot(x[3:4], x[3:4])
87-
ct = Cost(ot, nx, nu, nw)
88-
cT = Cost(oT, nx, 0, nw)
87+
ct = Cost(ot, nx, nu, nw=nw)
88+
cT = Cost(oT, nx, 0, nw=nw)
8989
obj = [[ct for t = 1:T-1]..., cT]
9090

9191
# ## constraints

test/car.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
end
1818

1919
# ## model
20-
dyn = Dynamics(midpoint_explicit, nx, nu, nw)
20+
dyn = Dynamics(midpoint_explicit, nx, nu, nw=nw)
2121
model = [dyn for t = 1:T-1]
2222

2323
# ## initialization
@@ -32,8 +32,8 @@
3232
# ## objective
3333
ot = (x, u, w) -> 1.0 * dot(x - xT, x - xT) + 1.0e-2 * dot(u, u)
3434
oT = (x, u, w) -> 1000.0 * dot(x - xT, x - xT)
35-
ct = Cost(ot, nx, nu, nw)
36-
cT = Cost(oT, nx, 0, 0)
35+
ct = Cost(ot, nx, nu, nw=nw)
36+
cT = Cost(oT, nx, 0, nw=0)
3737
obj = [[ct for t = 1:T-1]..., cT]
3838

3939
# ## constraints

test/dynamics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
x + h * pendulum(x, u, w)
1919
end
2020

21-
dyn = Dynamics(euler_explicit, nx, nu, nw)
21+
dyn = Dynamics(euler_explicit, nx, nu, nw=nw)
2222
model = [dyn for t = 1:T-1]
2323

2424
x1 = ones(nx)

test/objective.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
nw = 0
66
ot = (x, u, w) -> dot(x, x) + 0.1 * dot(u, u)
77
oT = (x, u, w) -> 10.0 * dot(x, x)
8-
ct = Cost(ot, nx, nu, nw)
9-
cT = Cost(oT, nx, 0, nw)
8+
ct = Cost(ot, nx, nu, nw=nw)
9+
cT = Cost(oT, nx, 0, nw=nw)
1010
obj = [[ct for t = 1:T-1]..., cT]
1111

1212
J = [0.0]

0 commit comments

Comments
 (0)