Skip to content

Commit 9f285f7

Browse files
committed
Add problems tests
1 parent f8aadb2 commit 9f285f7

File tree

4 files changed

+141
-3
lines changed

4 files changed

+141
-3
lines changed

src/problem.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,20 @@ function Problem(model::L, obj::O, xf::AbstractVector, tf;
6262
x0=zero(xf), N::Int=length(obj),
6363
X0=[x0*NaN for k = 1:N],
6464
U0=[@SVector zeros(size(model)[2]) for k = 1:N-1],
65-
dt=fill((tf-t0)/(N-1),N),
65+
dt=fill((tf-t0)/(N-1),N-1),
6666
integration=DEFAULT_Q) where {L,O}
6767
n,m = size(model)
6868
if dt isa Real
6969
dt = fill(dt,N)
7070
end
71+
@assert sum(dt[1:N-1]) tf "Time steps are inconsistent with final time"
7172
if X0 isa AbstractMatrix
7273
X0 = [X0[:,k] for k = 1:size(X0,2)]
7374
end
7475
if U0 isa AbstractMatrix
7576
U0 = [U0[:,k] for k = 1:size(U0,2)]
7677
end
77-
t = range(t0, tf, length=N)
78+
t = pushfirst!(cumsum(dt), 0)
7879
Z = Traj(X0,U0,dt,t)
7980

8081
Problem{integration}(model, obj, constraints, SVector{n}(x0), SVector{n}(xf),
@@ -112,6 +113,13 @@ Get the state trajectory
112113
"
113114
states(prob::Problem) = states(prob.Z)
114115

116+
"""
117+
get_times(::Problem)
118+
119+
Get the times for all the knot points in the problem.
120+
"""
121+
@inline get_times(prob::Problem) = get_times(get_trajectory(prob))
122+
115123
"```julia
116124
initial_trajectory!(::Problem, Z)
117125
initial_trajectory!(::AbstractSolver, Z)

src/trajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161
function Traj(n::Int, m::Int, dt::AbstractFloat, N::Int; equal=false)
6262
x = NaN*@SVector ones(n)
6363
u = @SVector zeros(m)
64-
Traj(x,u,dt,N,equal)
64+
Traj(x,u,dt,N; equal=equal)
6565
end
6666

6767
function Traj(x::SVector, u::SVector, dt::AbstractFloat, N::Int; equal=false)

test/problems_tests.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Model and discretization
2+
model = Cartpole()
3+
n,m = size(model)
4+
N = 11
5+
tf = 5.
6+
dt = tf/(N-1)
7+
8+
# Initial and Final conditions
9+
x0 = @SVector zeros(n)
10+
xf = @SVector [0, pi, 0, 0]
11+
12+
# Objective
13+
Q = 1.0e-2*Diagonal(@SVector ones(n))
14+
Qf = 100.0*Diagonal(@SVector ones(n))
15+
R = 1.0e-1*Diagonal(@SVector ones(m))
16+
obj = LQRObjective(Q,R,Qf,xf,N)
17+
18+
# Constraints
19+
u_bnd = 3.0
20+
conSet = ConstraintList(n,m,N)
21+
bnd = BoundConstraint(n,m, u_min=-u_bnd, u_max=u_bnd)
22+
goal = GoalConstraint(xf)
23+
add_constraint!(conSet, bnd, 1:N-1)
24+
add_constraint!(conSet, goal, N:N)
25+
26+
# Initial conditions
27+
X0 = [@SVector fill(0.0,n) for k = 1:N]
28+
u0 = @SVector fill(0.01,m)
29+
U0 = [u0 for k = 1:N-1]
30+
Z = Traj(X0,U0, fill(dt, N))
31+
32+
# Inner constructor
33+
prob = Problem{RK3}(model, obj, conSet, x0, xf, Z, N, 0.0, tf)
34+
@test TO.integration(prob) == RK3
35+
@test prob.x0 == x0
36+
@test prob.xf == xf
37+
@test prob.constraints === conSet
38+
add_constraint!(conSet, goal, N-1)
39+
@test length(TO.get_constraints(prob)) == length(conSet)
40+
@test TO.num_constraints(TO.get_constraints(prob)) === conSet.p
41+
@test prob.obj === obj
42+
@test prob.tf tf
43+
@test prob.N == N
44+
@test states(prob) X0
45+
@test controls(prob) U0
46+
47+
# Alternate constructor
48+
prob = Problem(model, obj, xf, tf, x0=x0, constraints=conSet, X0=X0, U0=U0)
49+
@test TO.integration(prob) == RK3
50+
@test prob.x0 == x0
51+
@test prob.xf == xf
52+
@test prob.constraints === conSet
53+
add_constraint!(conSet, goal, N-1)
54+
@test length(TO.get_constraints(prob)) == length(conSet)
55+
@test TO.num_constraints(TO.get_constraints(prob)) === conSet.p
56+
@test prob.obj === obj
57+
@test prob.tf tf
58+
@test prob.N == N
59+
@test states(prob) X0
60+
@test controls(prob) U0
61+
62+
# Change integration
63+
prob = Problem(model, obj, xf, tf, x0=x0, constraints=conSet, integration=RK2)
64+
@test TO.integration(prob) == RK2
65+
66+
# Test defaults
67+
prob = Problem(model, obj, xf, tf)
68+
@test prob.x0 == zero(x0)
69+
@test prob.N == N
70+
@test all(all.(isnan, states(prob)))
71+
@test controls(prob) == [zeros(m) for k = 1:N-1]
72+
@test isempty(prob.constraints)
73+
@test TO.integration(prob) == RK3
74+
@test TO.get_times(prob) range(0, tf; step=dt)
75+
76+
# Set initial trajectories
77+
initial_states!(prob, 2 .* X0)
78+
@test states(prob) 2 .* X0
79+
initial_controls!(prob, 2 .* U0)
80+
@test controls(prob) 2 .* U0
81+
82+
initial_trajectory!(prob, Z)
83+
@test controls(prob) U0
84+
@test states(prob) X0
85+
86+
# Use 2D matrices
87+
X0_mat = hcat(X0...)
88+
U0_mat = hcat(U0...)
89+
initial_states!(prob, 3*X0)
90+
initial_controls!(prob, 4*U0)
91+
@test states(prob) 3 .* X0
92+
@test controls(prob) 4 .* U0
93+
94+
prob = Problem(model, obj, xf, tf, X0=X0_mat, U0=U0_mat)
95+
@test states(prob) X0
96+
@test controls(prob) U0
97+
98+
# Use variable time steps
99+
dts = rand(N-1)
100+
dts = dts / sum(dts) * tf
101+
prob = Problem(model, obj, xf, tf, dt=dts)
102+
times = TO.get_times(prob)
103+
@test times[end] tf
104+
@test times[1] 0
105+
@test times[2] dts[1]
106+
@test diff(times) dts
107+
108+
# Test initial and final conditions
109+
prob = Problem(model, obj, Vector(xf), tf, x0=Vector(x0))
110+
@test prob.x0 x0
111+
@test prob.xf xf
112+
prob = Problem(model, obj, MVector(xf), tf, x0=MVector(x0))
113+
@test prob.x0 x0
114+
@test prob.xf xf
115+
@test prob.x0 isa MVector
116+
@test prob.xf isa MVector
117+
118+
x0_ = rand(n)
119+
TO.set_initial_state!(prob, x0_)
120+
@test prob.x0 x0_
121+
@test_throws DimensionMismatch TO.set_initial_state!(prob, rand(2n))

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ end
2222
include("constraint_sets.jl")
2323
end
2424

25+
@testset "Problems" begin
26+
include("problems_tests.jl")
27+
end
28+
2529
@testset "Utils" begin
2630
include("trajectories.jl")
2731
end
32+
33+
@testset "NLP" begin
34+
include("nlp_tests.jl")
35+
include("moi_test.jl")
36+
end

0 commit comments

Comments
 (0)