Skip to content

Commit 16c90f9

Browse files
committed
verbose forward pass
1 parent 0cf9fc8 commit 16c90f9

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

src/forward_pass.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
function forward_pass!(p_data::PolicyData, m_data::ModelData, s_data::SolverData;
2-
linesearch = :armijo,
3-
α_min = 1.0e-5,
4-
c1 = 1.0e-4,
5-
c2 = 0.9,
6-
max_iter = 25)
2+
linesearch=:armijo,
3+
α_min=1.0e-5,
4+
c1=1.0e-4,
5+
c2=0.9,
6+
max_iter=25,
7+
verbose=false)
78

89
# reset solver status
910
s_data.status[1] = false
@@ -25,16 +26,18 @@ function forward_pass!(p_data::PolicyData, m_data::ModelData, s_data::SolverData
2526
s_data.α[1] = 1.0
2627
iter = 1
2728
while s_data.α[1] >= α_min
28-
iter > max_iter && (@warn "forward pass failure", break)
29+
iter > max_iter && (verbose && (@warn "forward pass failure"), break)
2930

3031
J = Inf
3132
#TODO: remove try-catch
3233
try
3334
rollout!(p_data, m_data, α=s_data.α[1])
3435
J = objective!(s_data, m_data, mode=:current)[1]
3536
catch
36-
@warn "rollout failure"
37-
@show norm(s_data.gradient)
37+
if verbose
38+
@warn "rollout failure"
39+
@show norm(s_data.gradient)
40+
end
3841
end
3942
if (J <= J_prev + c1 * s_data.α[1] * delta_grad_product)
4043
# update nominal
@@ -47,6 +50,6 @@ function forward_pass!(p_data::PolicyData, m_data::ModelData, s_data::SolverData
4750
iter += 1
4851
end
4952
end
50-
s_data.α[1] < α_min && (@warn "line search failure")
53+
s_data.α[1] < α_min && (verbose && (@warn "line search failure"))
5154
end
5255

src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ function ilqr_solve!(prob::ProblemData;
2626
for i = 1:max_iter
2727
forward_pass!(p_data, m_data, s_data,
2828
α_min=α_min,
29-
linesearch=linesearch)
29+
linesearch=linesearch,
30+
verbose=verbose)
3031
if linesearch != :none
3132
derivatives!(m_data, mode=:nominal)
3233
backward_pass!(p_data, m_data, mode=:nominal)

0 commit comments

Comments
 (0)