@@ -4,18 +4,18 @@ function ilqr_solve!(solver::Solver)
4
4
# color=:red, bold=true)
5
5
6
6
# data
7
- policy = solver. policy
7
+ policy = solver. policy
8
8
problem = solver. problem
9
9
reset! (problem. model)
10
- reset! (problem. objective)
10
+ reset! (problem. objective)
11
11
data = solver. data
12
12
solver. options. reset_cache && reset! (data)
13
13
14
- cost! (data, problem,
14
+ cost! (data, problem,
15
15
mode= :nominal )
16
- gradients! (problem,
16
+ gradients! (problem,
17
17
mode= :nominal )
18
- backward_pass! (policy, problem,
18
+ backward_pass! (policy, problem,
19
19
mode= :nominal )
20
20
21
21
obj_prev = data. objective[1 ]
@@ -25,9 +25,9 @@ function ilqr_solve!(solver::Solver)
25
25
line_search= solver. options. line_search,
26
26
verbose= solver. options. verbose)
27
27
if solver. options. line_search != :none
28
- gradients! (problem,
28
+ gradients! (problem,
29
29
mode= :nominal )
30
- backward_pass! (policy, problem,
30
+ backward_pass! (policy, problem,
31
31
mode= :nominal )
32
32
lagrangian_gradient! (data, policy, problem)
33
33
end
@@ -54,8 +54,8 @@ function ilqr_solve!(solver::Solver)
54
54
end
55
55
56
56
function ilqr_solve! (solver:: Solver , states, actions; kwargs... )
57
- initialize_controls! (solver, actions)
58
- initialize_states! (solver, states)
57
+ initialize_controls! (solver, actions)
58
+ initialize_states! (solver, states)
59
59
ilqr_solve! (solver; kwargs... )
60
60
end
61
61
@@ -72,8 +72,8 @@ function lagrangian_gradient!(data::SolverData, policy::PolicyData, problem::Pro
72
72
73
73
for t = 1 : H- 1
74
74
Lx = @views data. gradient[data. indices_state[t]]
75
- Lx .= Qx[t]
76
- Lx .- = p[t]
75
+ Lx .= Qx[t]
76
+ Lx .- = p[t]
77
77
Lu = @views data. gradient[data. indices_action[t]]
78
78
Lu .= Qu[t]
79
79
# data.gradient[data.indices_state[t]] = Qx[t] - p[t] # should always be zero by construction
85
85
"""
86
86
augmented Lagrangian solve
87
87
"""
88
- function constrained_ilqr_solve! (solver:: Solver )
88
+ function constrained_ilqr_solve! (solver:: Solver ; augmented_lagrangian_callback! :: Function = x -> nothing )
89
89
90
90
# verbose && printstyled("Iterative LQR\n",
91
91
# color=:red, bold=true)
92
92
93
- # reset solver cache
94
- reset! (solver. data)
93
+ # reset solver cache
94
+ reset! (solver. data)
95
95
96
- # reset duals
96
+ # reset duals
97
97
for (t, λ) in enumerate (solver. problem. objective. costs. constraint_dual)
98
98
fill! (λ, 0.0 )
99
99
end
@@ -110,24 +110,27 @@ function constrained_ilqr_solve!(solver::Solver)
110
110
ilqr_solve! (solver)
111
111
112
112
# update trajectories
113
- cost! (solver. data, solver. problem,
113
+ cost! (solver. data, solver. problem,
114
114
mode= :nominal )
115
-
115
+
116
116
# constraint violation
117
117
solver. data. max_violation[1 ] <= solver. options. constraint_tolerance && break
118
118
119
119
# dual ascent
120
120
augmented_lagrangian_update! (solver. problem. objective. costs,
121
- scaling_penalty= solver. options. scaling_penalty,
121
+ scaling_penalty= solver. options. scaling_penalty,
122
122
max_penalty= solver. options. max_penalty)
123
+
124
+ # user-defined callback (continuation methods on the models etc.)
125
+ augmented_lagrangian_callback! (solver)
123
126
end
124
127
125
128
return nothing
126
129
end
127
130
128
131
function constrained_ilqr_solve! (solver:: Solver , states, actions; kwargs... )
129
- initialize_controls! (solver, actions)
130
- initialize_states! (solver, states)
132
+ initialize_controls! (solver, actions)
133
+ initialize_states! (solver, states)
131
134
constrained_ilqr_solve! (solver; kwargs... )
132
135
end
133
136
138
141
function solve! (solver:: Solver{T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O} , args... ; kwargs... ) where {T,N,M,NN,MM,MN,NNN,MNN,X,U,D,O<: AugmentedLagrangianCosts{T} }
139
142
constrained_ilqr_solve! (solver, args... ; kwargs... )
140
143
end
141
-
142
-
143
-
0 commit comments