Skip to content

Commit f495dc8

Browse files
committed
Moved remaining tests from docstring tests to build_function directory.
ps -> nn.params. Vector can't be used on Tuple of Vectors (apparently). [1., 2.] -> Vector(1:input_dim) (so that we can deal with flexible input dimensions.
1 parent 02c67ea commit f495dc8

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

src/build_function/build_function_arrays.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,21 @@ Return a function that takes an input, (optionally) an output and neural network
4444
4545
```jldoctest
4646
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
47-
using AbstractNeuralNetworks: Chain, Dense, initialparameters, NeuralNetworkParameters
47+
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
4848
import Random
4949
Random.seed!(123)
5050
5151
c = Chain(Dense(2, 1, tanh))
52-
nn = SymbolicNeuralNetwork(c)
53-
eqs = (a = c(nn.input, nn.params), b = c(nn.input, nn.params).^2)
54-
funcs = build_nn_function(eqs, nn.params, nn.input)
52+
nn = NeuralNetwork(c)
53+
snn = SymbolicNeuralNetwork(nn)
54+
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
55+
funcs = build_nn_function(eqs, snn.params, snn.input)
5556
input = [1., 2.]
56-
ps = initialparameters(c) |> NeuralNetworkParameters
57-
a = c(input, ps)
58-
b = c(input, ps).^2
59-
funcs_evaluated = funcs(input, ps)
60-
61-
(funcs_evaluated.a, funcs_evaluated.b) .≈ (a, b)
57+
funcs_evaluated = funcs(input, nn.params)
6258
6359
# output
6460
65-
(true, true)
61+
(a = [-0.9999386280616135], b = [0.9998772598897417])
6662
```
6763
6864
# Implementation

test/build_function/build_function_arrays.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork
1+
using SymbolicNeuralNetworks: build_nn_function, SymbolicNeuralNetwork, function_valued_parameters
22
using AbstractNeuralNetworks: Chain, Dense, NeuralNetwork
33
using Test
44
import Random
@@ -21,8 +21,44 @@ function build_function_for_array_valued_equation(input_dim::Integer=2, output_d
2121
@test funcs_evaluated_as_vector result_of_standard_computation
2222
end
2323

24+
function build_function_for_named_tuple(input_dim::Integer=2, output_dim::Integer=1)
25+
c = Chain(Dense(input_dim, output_dim, tanh))
26+
nn = NeuralNetwork(c)
27+
snn = SymbolicNeuralNetwork(nn)
28+
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
29+
funcs = build_nn_function(eqs, snn.params, snn.input)
30+
input = Vector(1:input_dim)
31+
a = c(input, nn.params)
32+
b = c(input, nn.params).^2
33+
funcs_evaluated = funcs(input, nn.params)
34+
35+
funcs_evaluated_as_vector = [funcs_evaluated.a, funcs_evaluated.b]
36+
result_of_standard_computation = [a, b]
37+
38+
@test funcs_evaluated_as_vector result_of_standard_computation
39+
end
40+
41+
function function_valued_parameters_for_named_tuple(input_dim::Integer=2, output_dim::Integer=1)
42+
c = Chain(Dense(input_dim, output_dim, tanh))
43+
nn = NeuralNetwork(c)
44+
snn = SymbolicNeuralNetwork(nn)
45+
eqs = (a = c(snn.input, snn.params), b = c(snn.input, snn.params).^2)
46+
funcs = function_valued_parameters(eqs, snn.params, snn.input)
47+
input = Vector(1:input_dim)
48+
a = c(input, nn.params)
49+
b = c(input, nn.params).^2
50+
51+
funcs_evaluated_as_vector = [funcs.a(input, nn.params), funcs.b(input, nn.params)]
52+
result_of_standard_computation = [a, b]
53+
54+
@test funcs_evaluated_as_vector result_of_standard_computation
55+
end
56+
57+
# we test in the following order: `function_valued_parameters` → `build_function` (for `NamedTuple`) → `build_function` (for `Array` of `NamedTuple`s) as this is also how the functions are built.
2458
for input_dim (2, 3)
2559
for output_dim (1, 2)
60+
function_valued_parameters_for_named_tuple(input_dim, output_dim)
61+
build_function_for_named_tuple(input_dim, output_dim)
2662
build_function_for_array_valued_equation(input_dim, output_dim)
2763
end
2864
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SymbolicNeuralNetworks
22
using SafeTestsets
33

4-
@safetestset "Docstring tests. " begin include("doctest.jl") end
4+
# @safetestset "Docstring tests. " begin include("doctest.jl") end
55
@safetestset "Symbolic gradient " begin include("derivatives/symbolic_gradient.jl") end
66
@safetestset "Symbolic Neural network " begin include("derivatives/jacobian.jl") end
77
@safetestset "Symbolic Params " begin include("symbolic_neuralnet/symbolize.jl") end

0 commit comments

Comments
 (0)