Skip to content

Commit 8368761

Browse files
committed
Added test for reshape and create_array routine.
1 parent 8d610cc commit 8368761

File tree

6 files changed

+69
-8
lines changed

6 files changed

+69
-8
lines changed

src/SymbolicNeuralNetworks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ module SymbolicNeuralNetworks
2020
export symbolize
2121
include("utils/symbolize.jl")
2222

23+
include("utils/create_array.jl")
24+
2325
export AbstractSymbolicNeuralNetwork
2426
export SymbolicNeuralNetwork, SymbolicModel
2527
export HamiltonianSymbolicNeuralNetwork, HNNLoss

src/utils/build_function.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,16 @@ end
2525
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
2626
gen_fun = _build_nn_function(eq, sparams, sinput)
2727
gen_fun_returned(x, ps) = mapreduce(k -> gen_fun(x, ps, k), hcat, axes(x, 2))
28-
gen_fun_returned(x::Union{AbstractVector, Symbolics.Arr}, ps) = gen_fun_returned(reshape(x, length(x), 1), ps)
28+
function gen_fun_returned(x::Union{AbstractVector, Symbolics.Arr}, ps)
29+
output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), ps)
30+
# for vectors we do not reshape, as the output may be a matrix
31+
output_not_reshaped
32+
end
2933
# check this! (definitely not correct in all cases!)
30-
gen_fun_returned(x::AbstractArray{<:Number, 3}, ps) = reshape(gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), ps), size(x, 1), size(x, 2), size(x, 3))
34+
function gen_fun_returned(x::AbstractArray{<:Number, 3}, ps)
35+
output_not_reshaped = gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), ps)
36+
reshape(output_not_reshaped, size(output_not_reshaped, 1), size(x, 2), size(x, 3))
37+
end
3138
gen_fun_returned
3239
end
3340

src/utils/build_function2.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,16 @@ end
1919
function build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
2020
gen_fun = _build_nn_function(eq, sparams, sinput, soutput)
2121
gen_fun_returned(input, output, ps) = mapreduce(k -> gen_fun(input, output, ps, k), +, axes(input, 2))
22-
gen_fun_returned(input::AT, output::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}} = gen_fun_returned(reshape(input, length(input), 1), reshape(output, length(output), 1), ps)
23-
gen_fun_returned(input::AT, output::AT, ps) where {T, AT <: AbstractArray{T, 3}} = gen_fun_returned(reshape(input, size(input, 1), size(input, 2) * size(input, 3)), reshape(output, size(output, 1), size(output, 2) * size(output, 3)), ps)
22+
function gen_fun_returned(x::AT, y::AT, ps) where {AT <: Union{AbstractVector, Symbolics.Arr}}
23+
output_not_reshaped = gen_fun_returned(reshape(x, length(x), 1), reshape(y, length(y), 1), ps)
24+
# for vectors we do not reshape, as the output may be a matrix
25+
output_not_reshaped
26+
end
27+
# check this! (definitely not correct in all cases!)
28+
function gen_fun_returned(x::AT, y::AT, ps) where {AT <: AbstractArray{<:Number, 3}}
29+
output_not_reshaped = gen_fun_returned(reshape(x, size(x, 1), size(x, 2) * size(x, 3)), reshape(y, size(y, 1), size(y, 2) * size(y, 3)), ps)
30+
reshape(output_not_reshaped, size(output_not_reshaped, 1), size(x, 2), size(x, 3))
31+
end
2432
gen_fun_returned
2533
end
2634

src/utils/create_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
function Symbolics.SymbolicUtils.Code.create_array(::Type{<:Base.ReshapedArray{T, N, P}}, S, nd::Val, d::Val, elems...) where {T, N, P}
2+
Symbolics.SymbolicUtils.Code.create_array(P, S, nd, d, elems...)
3+
end

test/reshape_test.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using SymbolicNeuralNetworks
2+
using AbstractNeuralNetworks
3+
using Symbolics
4+
using Test
5+
6+
function set_up_network()
7+
c = Chain(Dense(2, 3))
8+
nn = SymbolicNeuralNetwork(c)
9+
soutput = nn.model(nn.input, nn.params)
10+
nn_cpu = NeuralNetwork(c)
11+
nn, soutput, nn_cpu
12+
end
13+
14+
function test_for_input()
15+
nn, soutput, nn_cpu = set_up_network()
16+
input = rand(2, 5)
17+
input2 = reshape((@view input[:, 1:2]), 2, 1, 2)
18+
built_function = build_nn_function(soutput, nn.params, nn.input)
19+
outputs = built_function(input2, nn_cpu.params)
20+
for i in 1:2
21+
@test nn.model(input[:, i], nn_cpu.params) outputs[:, 1, i]
22+
end
23+
end
24+
25+
function test_for_input_and_output()
26+
nn, soutput2, nn_cpu = set_up_network()
27+
input = rand(2, 5)
28+
output = rand(3, 5)
29+
input2 = reshape((@view input[:, 1:2]), 2, 1, 2)
30+
output2 = reshape((@view input[:, 1:3]), 3, 1, 2)
31+
@variables soutput[1:3]
32+
built_function = build_nn_function((soutput - soutput2).^2, nn.params, nn.input, soutput)
33+
outputs = built_function(input2, output2, nn_cpu.params)
34+
for i in 1:2
35+
@test (nn.model(input[:, i], nn_cpu.params) - output[:, i]).^2 outputs[:, 1, i]
36+
end
37+
end
38+
39+
test_for_input()
40+
test_for_input_and_output()

test/runtests.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using SafeTestsets
33
using Test
44

55
@safetestset "Docstring tests. " begin include("doctest.jl") end
6-
@safetestset "Symbolic gradient " begin include("symbolic_gradient.jl") end
7-
@safetestset "Symbolic Neural network " begin include("neural_network_derivative.jl") end
8-
@safetestset "Symbolic Params " begin include("test_params.jl") end
9-
# @safetestset "HNN Loss " begin include("test_hnn_loss_pullback.jl") end
6+
@safetestset "Symbolic gradient " begin include("symbolic_gradient.jl") end
7+
@safetestset "Symbolic Neural network " begin include("neural_network_derivative.jl") end
8+
@safetestset "Symbolic Params " begin include("test_params.jl") end
9+
# @safetestset "HNN Loss " begin include("test_hnn_loss_pullback.jl") end
10+
@safetestset "Check if reshape works in the correct way with the generated functions. " begin include("reshape_test.jl") end

0 commit comments

Comments
 (0)