Skip to content

Commit 5797dcb

Browse files
Merge pull request #34 from SciML/smc/macro
Compatibility with `@mtkmodel`
2 parents 9d4aa15 + a0c5512 commit 5797dcb

File tree

5 files changed

+52
-5
lines changed

5 files changed

+52
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkitNeuralNets"
22
uuid = "f162e290-f571-43a6-83d9-22ecc16da15f"
33
authors = ["Sebastian Micluța-Câmpeanu <sebastian.mc95@proton.me> and contributors"]
4-
version = "1.1.0"
4+
version = "1.2.0"
55

66
[deps]
77
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"

src/ModelingToolkitNeuralNets.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export NeuralNetworkBlock, multi_layer_feed_forward
1313
include("utils.jl")
1414

1515
"""
16-
NeuralNetworkBlock(n_input = 1, n_output = 1;
16+
NeuralNetworkBlock(; n_input = 1, n_output = 1,
1717
chain = multi_layer_feed_forward(n_input, n_output),
1818
rng = Xoshiro(0),
1919
init_params = Lux.initialparameters(rng, chain),
@@ -22,8 +22,7 @@ include("utils.jl")
2222
2323
Create an `ODESystem` with a neural network inside.
2424
"""
25-
function NeuralNetworkBlock(n_input = 1,
26-
n_output = 1;
25+
function NeuralNetworkBlock(; n_input = 1, n_output = 1,
2726
chain = multi_layer_feed_forward(n_input, n_output),
2827
rng = Xoshiro(0),
2928
init_params = Lux.initialparameters(rng, chain),
@@ -46,6 +45,12 @@ function NeuralNetworkBlock(n_input = 1,
4645
return ude_comp
4746
end
4847

48+
# added to avoid a breaking change from moving n_input & n_output in kwargs
49+
# https://github.com/SciML/ModelingToolkitNeuralNets.jl/issues/32
50+
function NeuralNetworkBlock(n_input, n_output = 1; kwargs...)
51+
NeuralNetworkBlock(; n_input, n_output, kwargs...)
52+
end
53+
4954
function lazyconvert(T, x::Symbolics.Arr)
5055
Symbolics.array_term(convert, T, x, size = size(x))
5156
end

test/lotka_volterra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114114

115115
@test res.objective < 1
116116

117-
res_p = SciMLStructures.replace(Tunable(), prob.p, res)
117+
res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
118118
res_prob = remake(prob, p = res_p)
119119
res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
120120

test/macro.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using ModelingToolkit, Symbolics
2+
using ModelingToolkit: t_nounits as t, D_nounits as D
3+
using OrdinaryDiffEq
4+
using ModelingToolkitNeuralNets
5+
using ModelingToolkitStandardLibrary.Blocks
6+
using Lux
7+
8+
@mtkmodel Friction_UDE begin
9+
@variables begin
10+
y(t) = 0.0
11+
end
12+
@parameters begin
13+
Fu
14+
end
15+
@components begin
16+
nn_in = RealInputArray(nin = 1)
17+
nn_out = RealOutputArray(nout = 1)
18+
end
19+
@equations begin
20+
D(y) ~ Fu - nn_in.u[1]
21+
y ~ nn_out.u[1]
22+
end
23+
end
24+
25+
@mtkmodel TestFriction_UDE begin
26+
@components begin
27+
friction_ude = Friction_UDE(Fu = 120.0)
28+
nn = NeuralNetworkBlock(n_input = 1, n_output = 1)
29+
end
30+
@equations begin
31+
connect(friction_ude.nn_in, nn.output)
32+
connect(friction_ude.nn_out, nn.input)
33+
end
34+
end
35+
36+
@mtkbuild sys = TestFriction_UDE()
37+
38+
prob = ODEProblem(sys, [], (0, 1.0), [])
39+
sol = solve(prob, Rodas4())
40+
41+
@test SciMLBase.successful_retcode(sol)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ using SafeTestsets
55
@testset verbose=true "ModelingToolkitNeuralNets.jl" begin
66
@safetestset "QA" include("qa.jl")
77
@safetestset "Basic" include("lotka_volterra.jl")
8+
@safetestset "MTK model macro compatibility" include("macro.jl")
89
end

0 commit comments

Comments
 (0)