Skip to content

Commit 4eb16db

Browse files
committed
adding augmented lagrangian callback
1 parent 57345b6 commit 4eb16db

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

src/solve.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ function ilqr_solve!(solver::Solver)
44
# color=:red, bold=true)
55

66
# data
7-
policy = solver.policy
7+
policy = solver.policy
88
problem = solver.problem
99
reset!(problem.model)
10-
reset!(problem.objective)
10+
reset!(problem.objective)
1111
data = solver.data
1212
solver.options.reset_cache && reset!(data)
1313

14-
cost!(data, problem,
14+
cost!(data, problem,
1515
mode=:nominal)
16-
gradients!(problem,
16+
gradients!(problem,
1717
mode=:nominal)
18-
backward_pass!(policy, problem,
18+
backward_pass!(policy, problem,
1919
mode=:nominal)
2020

2121
obj_prev = data.objective[1]
@@ -25,9 +25,9 @@ function ilqr_solve!(solver::Solver)
2525
line_search=solver.options.line_search,
2626
verbose=solver.options.verbose)
2727
if solver.options.line_search != :none
28-
gradients!(problem,
28+
gradients!(problem,
2929
mode=:nominal)
30-
backward_pass!(policy, problem,
30+
backward_pass!(policy, problem,
3131
mode=:nominal)
3232
lagrangian_gradient!(data, policy, problem)
3333
end
@@ -54,8 +54,8 @@ function ilqr_solve!(solver::Solver)
5454
end
5555

5656
function ilqr_solve!(solver::Solver, states, actions; kwargs...)
57-
initialize_controls!(solver, actions)
58-
initialize_states!(solver, states)
57+
initialize_controls!(solver, actions)
58+
initialize_states!(solver, states)
5959
ilqr_solve!(solver; kwargs...)
6060
end
6161

@@ -72,8 +72,8 @@ function lagrangian_gradient!(data::SolverData, policy::PolicyData, problem::Pro
7272

7373
for t = 1:H-1
7474
Lx = @views data.gradient[data.indices_state[t]]
75-
Lx .= Qx[t]
76-
Lx .-= p[t]
75+
Lx .= Qx[t]
76+
Lx .-= p[t]
7777
Lu = @views data.gradient[data.indices_action[t]]
7878
Lu .= Qu[t]
7979
# data.gradient[data.indices_state[t]] = Qx[t] - p[t] # should always be zero by construction
@@ -85,15 +85,15 @@ end
8585
"""
8686
augmented Lagrangian solve
8787
"""
88-
function constrained_ilqr_solve!(solver::Solver)
88+
function constrained_ilqr_solve!(solver::Solver; augmented_lagrangian_callback!::Function=x->nothing)
8989

9090
# verbose && printstyled("Iterative LQR\n",
9191
# color=:red, bold=true)
9292

93-
# reset solver cache
94-
reset!(solver.data)
93+
# reset solver cache
94+
reset!(solver.data)
9595

96-
# reset duals
96+
# reset duals
9797
for (t, λ) in enumerate(solver.problem.objective.costs.constraint_dual)
9898
fill!(λ, 0.0)
9999
end
@@ -110,24 +110,27 @@ function constrained_ilqr_solve!(solver::Solver)
110110
ilqr_solve!(solver)
111111

112112
# update trajectories
113-
cost!(solver.data, solver.problem,
113+
cost!(solver.data, solver.problem,
114114
mode=:nominal)
115-
115+
116116
# constraint violation
117117
solver.data.max_violation[1] <= solver.options.constraint_tolerance && break
118118

119119
# dual ascent
120120
augmented_lagrangian_update!(solver.problem.objective.costs,
121-
scaling_penalty=solver.options.scaling_penalty,
121+
scaling_penalty=solver.options.scaling_penalty,
122122
max_penalty=solver.options.max_penalty)
123+
124+
# user-defined callback (continuation methods on the models etc.)
125+
augmented_lagrangian_callback!(solver)
123126
end
124127

125128
return nothing
126129
end
127130

128131
function constrained_ilqr_solve!(solver::Solver, states, actions; kwargs...)
129-
initialize_controls!(solver, actions)
130-
initialize_states!(solver, states)
132+
initialize_controls!(solver, actions)
133+
initialize_states!(solver, states)
131134
constrained_ilqr_solve!(solver; kwargs...)
132135
end
133136

@@ -138,6 +141,3 @@ end
138141
function solve!(solver::Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O}, args...; kwargs...) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<:AugmentedLagrangianCosts{T}}
139142
constrained_ilqr_solve!(solver, args...; kwargs...)
140143
end
141-
142-
143-

0 commit comments

Comments
 (0)