Skip to content

Commit 5949417

Browse files
committed
Fixes and clean up throughout.
1 parent e5e357d commit 5949417

File tree

8 files changed

+32
-39
lines changed

8 files changed

+32
-39
lines changed

src/hamiltonian.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
#=
1+
"""
22
symbolic_hamitonian is a functions that creates a symbolic hamiltonian from any hamiltonian.
33
The output of the function is
44
- the symbolics parameters used to build the hamiltonian (i.e t, q, p),
55
- the symbolic expression of the hamiltonian,
66
- the function generated from the symbolic hamiltonian.
7-
=#
8-
7+
"""
98
function symbolic_hamiltonian(H::Base.Callable, dim::Int, params::Union{Tuple, NamedTuple, AbstractArray})
109

1110
RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -18,10 +17,10 @@ function symbolic_hamiltonian(H::Base.Callable, dim::Int, params::Union{Tuple, N
1817
sparams = symbolize(params; redundancy = false)[1] # for the parameters
1918

2019
# create the symbolic hamiltonian
21-
sH = H(p, st, q, sparams)
20+
sH = H(st, q, p, sparams)
2221

2322
# create the related code
24-
code_H = build_function(sH, p, st, q, develop(sparams)...)
23+
code_H = build_function(sH, st, q, p, develop(sparams)...)
2524

2625
# rewrite the code to take directly paramters
2726
rewrite_code_H = rewrite_hamiltonian(code_H, (q, p, st), sparams)
@@ -31,8 +30,3 @@ function symbolic_hamiltonian(H::Base.Callable, dim::Int, params::Union{Tuple, N
3130

3231
return st, q, p, sparams, sH, gH
3332
end
34-
35-
36-
37-
38-

src/lagrangian.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
#=
1+
"""
22
symbolic_lagrangian is a functions that creates a symbolic lagrangian from any lagrangian.
33
The output of the function is
44
- the symbolics parameters used to build the hamiltonian (i.e t, x, v),
55
- the symbolic expression of the lagrangian,
66
- the function generated from the symbolic lagrangian.
7-
=#
8-
7+
"""
98
function symbolic_lagrangian(L::Base.Callable, dim::Int, params::Union{Tuple, NamedTuple, AbstractArray})
109

1110
RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -31,4 +30,3 @@ function symbolic_lagrangian(L::Base.Callable, dim::Int, params::Union{Tuple, Na
3130

3231
return st, x, v, sparams, sL, gL
3332
end
34-

src/utils/build_function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function build_eval(f::Base.Callable, args...; params = params::Union{Tuple, Nam
1515
end
1616

1717

18-
function build_hamiltonien(H::Base.Callable, dim::Int, params::Union{Tuple, NamedTuple, AbstractArray})
18+
function build_hamiltonian(H::Base.Callable, dim::Int, params::Union{Tuple, NamedTuple, AbstractArray})
1919

2020
#compute the symplectic matrix
2121
sympmatrix = symplecticMatrix(dim)

src/utils/rewrite_code.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,40 @@
33
=#
44

55
function rewrite_code(expr, sargs, sparams = nothing, fun_name = "")
6-
expr = Meta.parse(replace(string(expr), "var" => ""))
7-
expr = Meta.parse(replace(string(expr), r"\""=>""))
6+
expr_str = string(expr)
7+
expr_str = replace(expr_str, "var" => "")
8+
expr_str = replace(expr_str, r"\"" => "")
89
for arg in sargs
9-
str_symbol = replace(string(arg), r"\[.*"=>"")
10+
str_symbol = replace(string(arg), r"\[.*" => "")
1011
if str_symbol[1] == '(' && str_symbol[end] == ')'
1112
str_symbol = str_symbol[2:end-1]
1213
end
1314
track = get_track(sargs, arg, "args")[2]
14-
expr = Meta.parse(replace(string(expr), str_symbol => track))
15+
expr_str = replace(expr_str, str_symbol => track)
1516
end
1617
for p in develop(sparams)
17-
sparams
18-
p
19-
str_symbol = replace(string(p), r"\[.*"=>"")
18+
str_symbol = replace(string(p), r"\[.*" => "")
2019
track = get_track(sparams, p, "params")[2]
21-
expr = Meta.parse(replace(string(expr), str_symbol => track))
20+
expr_str = replace(expr_str, str_symbol => track)
2221
end
23-
expr = Meta.parse(replace(string(expr), r"function.*" => "function "*fun_name*"(args, params)\n"))
22+
expr_str = replace(expr_str, r"function.*" => "function " * fun_name * "(args, params)\n")
23+
expr = Meta.parse(expr_str)
2424
end
2525

2626
function rewrite_hamiltonian(expr, args, sparams = nothing, fun_name = "shamiltonian")
2727
expr = rewrite_code(expr, args, sparams, fun_name)
28-
string_expr = replace(string(expr), "args[1]" => "q")
28+
string_expr = string(expr)
29+
string_expr = replace(string_expr, "args[1]" => "q")
2930
string_expr = replace(string_expr, "args[2]" => "p")
3031
string_expr = replace(string_expr, "args[3]" => "t")
31-
string_expr = replace(string_expr, "args" => "p, t, q")
32+
string_expr = replace(string_expr, "args" => "t, q, p")
3233
expr = Meta.parse(string_expr)
3334
end
3435

3536
function rewrite_lagrangian(expr, args, sparams = nothing, fun_name = "slagrangian")
3637
expr = rewrite_code(expr, args, sparams, fun_name)
37-
string_expr = replace(string(expr), "args[1]" => "x")
38+
string_expr = string(expr)
39+
string_expr = replace(string_expr, "args[1]" => "x")
3840
string_expr = replace(string_expr, "args[2]" => "v")
3941
string_expr = replace(string_expr, "args[3]" => "t")
4042
string_expr = replace(string_expr, "args" => "t, x, v")
@@ -43,7 +45,8 @@ end
4345

4446
function rewrite_neuralnetwork(expr, args, sparams)
4547
expr = rewrite_code(expr, args, sparams)
46-
string_expr = replace(string(expr), "args[1]" => "x")
48+
string_expr = string(expr)
49+
string_expr = replace(string_expr, "args[1]" => "x")
4750
string_expr = replace(string_expr, "args" => "x")
4851
expr = Meta.parse(string_expr)
4952
end

test/test_hamiltonian.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@ using Symbolics
33
using Test
44

55
parameters = (k=1, )
6-
function hamiltonian(p, t, q, params)
6+
7+
function hamiltonian(t, q, p, params)
78
p[1]^2 / 2 + params.k * q[1]^2 / 2
89
end
910

10-
symbolize(parameters)
11-
1211
st̃, sq̃, sp̃, sparams, shamiltonian, hamiltonian_function = symbolic_hamiltonian(hamiltonian, 2, parameters)
1312

1413
@variables st
@@ -24,5 +23,5 @@ st̃, sq̃, sp̃, sparams, shamiltonian, hamiltonian_function = symbolic_hamilto
2423
@test isequal(shamiltonian, Num((1//2)*(p[1]^2) + (1//2)*k*(q[1]^2)))
2524

2625
t, q, p = 2, 0.5, 0.7
27-
@test hamiltonian_function(p, t, q, parameters) == hamiltonian(p, t, q, parameters)
26+
@test hamiltonian_function(t, q, p, parameters) == hamiltonian(t, q, p, parameters)
2827

test/test_lagrangian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Symbolics
33
using Test
44

55
parameters = (k=1, )
6+
67
function lagrangian(t, q, v, params)
78
v[1]^2 / 2 - params.k * q[1]^2 / 2
89
end

test/test_params.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ sparams = symbolize(params)[1]
1111

1212
verified_sparams = ( (W = W_1, b = b_1), (W = (W_2 , W_3), c = c_1), M_1, (X_1, X_2))
1313

14-
@test sparams === verified_sparams
14+
@test sparams === verified_sparams

test/test_rewrite.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@ using Test
55
@variables W1[1:2,1:2] W2[1:2,1:2] b1[1:2] b2[1:2]
66
sparams = ((W = W1, b = b1), (W = W2, b = b2))
77

8-
@variables SX, SY
9-
sargs = [SX, SY]
8+
@variables st
9+
@variables sargs(st)[1:2]
1010

1111
output = sparams[2].W * tanh.(sparams[1].W * sargs + sparams[1].b) + sparams[2].b
1212

1313
code_output = build_function(output, sargs..., develop(sparams)...)[2]
14-
rewrite_ouput = eval(rewrite_code(code_output, Tuple(sargs), sparams, "OUTPUT"))
14+
rewrite_output = eval(rewrite_code(code_output, Tuple(sargs), sparams, "OUTPUT"))
1515

1616
params = ((W = [1 3; 2 2], b = [1, 0]), (W = [1 1; 0 2], b = [1, 0]))
1717
args = [1, 0.2]
1818

19-
@test_nowarn rewrite_ouput(args, params)
20-
21-
19+
@test_nowarn rewrite_output(args, params)

0 commit comments

Comments
 (0)