Skip to content

Commit ad63009

Browse files
authored
Merge pull request #27 from JuliaGNI/remove-reduce-code
Remove reduce code
2 parents 0adcf55 + b333eb3 commit ad63009

File tree

5 files changed

+11
-24
lines changed

5 files changed

+11
-24
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ SafeTestsets = "0.1"
2424
SymbolicUtils = "<3.8.0"
2525
Symbolics = "5, 6"
2626
Zygote = "0.6.73"
27+
julia = "1.10"
2728

2829
[extras]
2930
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

src/build_function/build_function.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,31 +69,16 @@ This first calls `Symbolics.build_function` with the keyword argument `expressio
6969
7070
See the docstrings for those functions for details on how the code is modified.
7171
"""
72-
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr)
72+
function _build_nn_function(eq::EqT, sparams::NeuralNetworkParameters, sinput::Symbolics.Arr)
7373
sc_eq = Symbolics.scalarize(eq)
74-
code = build_function(sc_eq, sinput, values(params)...; expression = Val{true}) |> _reduce_code
74+
code = build_function(sc_eq, sinput, values(sparams)...; expression = Val{true}) |> _reduce
7575
rewritten_code = fix_map_reduce(modify_input_arguments(rewrite_arguments(fix_create_array(code))))
7676
parallelized_code = make_kernel(rewritten_code)
7777
@RuntimeGeneratedFunction(parallelized_code)
7878
end
7979

80-
"""
81-
_reduce_code(code)
82-
83-
Reduce the code.
84-
85-
For some reason `Symbolics.build_function` sometimes returns a tuple and sometimes it doesn't.
86-
87-
This function takes care of this.
88-
If `build_function` returns a tuple `reduce_code` checks which of the expressions is in-place and then returns the other (not in-place) expression.
89-
"""
90-
function _reduce_code(code::Expr)
91-
code
92-
end
93-
94-
function _reduce_code(code::Tuple{Expr, Expr})
95-
contains(string(code[1]), "ˍ₋out") ? code[2] : code[1]
96-
end
80+
_reduce(a) = a
81+
_reduce(a::Tuple) = a[1]
9782

9883
"""
9984
rewrite_arguments(s)

src/build_function/build_function_double_input.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ See the docstrings for those functions for details on how the code is modified.
4242
"""
4343
function _build_nn_function(eq::EqT, params::NeuralNetworkParameters, sinput::Symbolics.Arr, soutput::Symbolics.Arr)
4444
sc_eq = Symbolics.scalarize(eq)
45-
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce_code
45+
code = build_function(sc_eq, sinput, soutput, values(params)...; expression = Val{true}) |> _reduce
4646
rewritten_code = fix_map_reduce(modify_input_arguments2(rewrite_arguments2(fix_create_array(code))))
4747
parallelized_code = make_kernel2(rewritten_code)
4848
@RuntimeGeneratedFunction(parallelized_code)

src/derivatives/gradient.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ using SymbolicNeuralNetworks: SymbolicNeuralNetwork, symbolic_pullback
8080
using AbstractNeuralNetworks
8181
using AbstractNeuralNetworks: params
8282
using LinearAlgebra: norm
83-
using Latexify: latexify
8483
8584
c = Chain(Dense(2, 1, tanh))
8685
nn = SymbolicNeuralNetwork(c)

test/derivatives/jacobian.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ function test_jacobian(n::Integer, T = Float32)
2828

2929
_params = params(NeuralNetwork(c, T))
3030
input = rand(T, n)
31-
@test build_nn_function(g.output, nn)(input, _params) c(input, _params)
32-
@test build_nn_function(derivative(g), nn)(input, _params) ForwardDiff.jacobian(input -> c(input, _params), input)
31+
f = build_nn_function(g.output, nn)
32+
∇f = build_nn_function(derivative(g), nn)
33+
@test f(input, params) c(input, params)
34+
@test ∇f(input, params) ForwardDiff.jacobian(input -> c(input, params), input)
3335
end
3436

35-
for n 1:10
37+
for n 10:1
3638
for T (Float32, Float64)
3739
test_jacobian(n, T)
3840
end

0 commit comments

Comments
 (0)