|
| 1 | +using Flux, Fluxperimental, Reactant, Enzyme |
| 2 | +using Test |
| 3 | +@testset "Reactant + Flux" begin |
| 4 | + |
| 5 | +@testset "simple forwards" begin |
| 6 | + img = rand32(28, 28, 1, 2) |
| 7 | + mlp = Chain(Flux.flatten, Dense(28^2 => 32, tanh), Dense(32 => 10)) |
| 8 | + y1 = mlp(img) |
| 9 | + @test y1 isa Matrix |
| 10 | + |
| 11 | + re_mlp = Reactor(mlp) # signal to use Reactant |
| 12 | + y2 = re_mlp(img) |
| 13 | + @test y2 isa ConcreteRArray |
| 14 | + @test y1 ≈ Array(y2) |
| 15 | + |
| 16 | + y3 = re_mlp(img) |
| 17 | + @test y1 ≈ Array(y3) |
| 18 | + @test re_mlp.fwd_count == 2 # re-used without recompilation |
| 19 | + |
| 20 | + img10 = rand32(28, 28, 1, 10) |
| 21 | + y10 = mlp(img10) |
| 22 | + y11 = re_mlp(img10) # re-compiles for the new size |
| 23 | + @test y10 ≈ Array(y11) |
| 24 | +end |
| 25 | + |
| 26 | +@testset "simple gradient" begin |
| 27 | + img = rand32(28, 28, 1, 2) |
| 28 | + mlp = Chain(Flux.flatten, Dense(28^2 => 32, tanh), Dense(32 => 10)) |
| 29 | + loss1(m, x) = sum(abs2, m(x)) |
| 30 | + |
| 31 | + g1 = Flux.gradient(loss1, mlp, img)[1].layers[2].bias |
| 32 | + @test g1 isa Vector |
| 33 | + |
| 34 | + re_mlp = Reactor(mlp) |
| 35 | + dup_mlp = Duplicated(mlp); |
| 36 | + g2 = Flux.gradient(loss1, dup_mlp, Const(img))[1].layers[2].bias # Enzyme |
| 37 | + @test g2 ≈ g1 |
| 38 | + @test g2 isa Vector |
| 39 | + |
| 40 | + re_mlp = Reactor(mlp); |
| 41 | + g3 = Flux.gradient(loss1, re_mlp, Const(img))[1].layers[2].bias |
| 42 | + @test Array(g3) ≈ g1 |
| 43 | + g4 = Flux.gradient(loss1, re_mlp, Const(img))[1].layers[2].bias |
| 44 | + @test Array(g4) ≈ g1 |
| 45 | + @test re_mlp.grad_count == 2 # re-used without recompilation |
| 46 | +end |
| 47 | + |
| 48 | +#= |
| 49 | +
|
| 50 | +simple gradient: Error During Test at REPL[59]:1 |
| 51 | + Got exception outside of a @test |
| 52 | + Constant memory is stored (or returned) to a differentiable variable. |
| 53 | + As a result, Enzyme cannot provably ensure correctness and throws this error. |
| 54 | + This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity). |
| 55 | + If Enzyme should be able to prove this use non-differentable, open an issue! |
| 56 | + To work around this issue, either: |
| 57 | + a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or |
| 58 | + b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance. |
| 59 | + Mismatched activity for: store i8* %17, i8* addrspace(11)* %.repack, align 8, !dbg !165, !tbaa !118, !alias.scope !121, !noalias !166 const val: %17 = load i8*, i8* addrspace(11)* %16, align 8, !dbg !112, !tbaa !118, !alias.scope !121, !noalias !122, !enzyme_type !123, !enzymejl_byref_BITS_VALUE !0, !enzymejl_source_type_Ptr\7BFloat32\7D !0 |
| 60 | + value=Unknown object of type Ptr{Float32} |
| 61 | + llvalue= %17 = load i8*, i8* addrspace(11)* %16, align 8, !dbg !112, !tbaa !118, !alias.scope !121, !noalias !122, !enzyme_type !123, !enzymejl_byref_BITS_VALUE !0, !enzymejl_source_type_Ptr\7BFloat32\7D !0 |
| 62 | +
|
| 63 | + Stacktrace: |
| 64 | + [1] reshape |
| 65 | + @ ./reshapedarray.jl:60 |
| 66 | + [2] reshape |
| 67 | + @ ./reshapedarray.jl:129 |
| 68 | + [3] reshape |
| 69 | + @ ./reshapedarray.jl:128 |
| 70 | + [4] flatten |
| 71 | + @ ~/.julia/packages/MLUtils/LmmaQ/src/utils.jl:504 |
| 72 | + [5] flatten |
| 73 | + @ ~/.julia/dev/Flux/src/layers/stateless.jl:105 |
| 74 | + [6] macro expansion |
| 75 | + @ ~/.julia/dev/Flux/src/layers/basic.jl:68 |
| 76 | + [7] _applychain |
| 77 | + @ ~/.julia/dev/Flux/src/layers/basic.jl:68 |
| 78 | +
|
| 79 | + Stacktrace: |
| 80 | + [1] reshape |
| 81 | + @ ./reshapedarray.jl:60 [inlined] |
| 82 | + [2] reshape |
| 83 | + @ ./reshapedarray.jl:129 [inlined] |
| 84 | + [3] reshape |
| 85 | + @ ./reshapedarray.jl:128 [inlined] |
| 86 | + [4] flatten |
| 87 | + @ ~/.julia/packages/MLUtils/LmmaQ/src/utils.jl:504 [inlined] |
| 88 | + [5] flatten |
| 89 | + @ ~/.julia/dev/Flux/src/layers/stateless.jl:105 [inlined] |
| 90 | + [6] macro expansion |
| 91 | + @ ~/.julia/dev/Flux/src/layers/basic.jl:68 [inlined] |
| 92 | + [7] _applychain |
| 93 | + @ ~/.julia/dev/Flux/src/layers/basic.jl:68 |
| 94 | + [8] Chain |
| 95 | + @ ~/.julia/dev/Flux/src/layers/basic.jl:65 [inlined] |
| 96 | + [9] loss1 |
| 97 | + @ ./REPL[59]:4 [inlined] |
| 98 | + [10] loss1 |
| 99 | + @ ./REPL[59]:0 [inlined] |
| 100 | + [11] diffejulia_loss1_99632_inner_54wrap |
| 101 | + @ ./REPL[59]:0 |
| 102 | + [12] macro expansion |
| 103 | + @ ~/.julia/packages/Enzyme/haqjK/src/compiler.jl:5204 [inlined] |
| 104 | + [13] enzyme_call |
| 105 | + @ ~/.julia/packages/Enzyme/haqjK/src/compiler.jl:4750 [inlined] |
| 106 | + [14] CombinedAdjointThunk |
| 107 | + @ ~/.julia/packages/Enzyme/haqjK/src/compiler.jl:4622 [inlined] |
| 108 | + [15] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::Const{var"#loss1#11"}, ::Type{Active}, ::Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, ::Const{Array{Float32, 4}}) |
| 109 | + @ Enzyme ~/.julia/packages/Enzyme/haqjK/src/Enzyme.jl:503 |
| 110 | + [16] _enzyme_gradient(::Function, ::Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool) |
| 111 | + @ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:49 |
| 112 | + [17] _enzyme_gradient |
| 113 | + @ ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:44 [inlined] |
| 114 | + [18] gradient(::Function, ::Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, ::Const{Array{Float32, 4}}) |
| 115 | + @ Flux ~/.julia/dev/Flux/src/gradient.jl:122 |
| 116 | + [19] macro expansion |
| 117 | + @ REPL[59]:11 [inlined] |
| 118 | + [20] macro expansion |
| 119 | + @ /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/Test/src/Test.jl:1700 [inlined] |
| 120 | + [21] top-level scope |
| 121 | + @ REPL[59]:2 |
| 122 | + [22] eval |
| 123 | + @ ./boot.jl:430 [inlined] |
| 124 | + [23] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module) |
| 125 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:226 |
| 126 | + [24] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function) |
| 127 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:323 |
| 128 | + [25] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function) |
| 129 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:308 |
| 130 | + [26] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any) |
| 131 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:464 |
| 132 | + [27] run_repl(repl::REPL.AbstractREPL, consumer::Any) |
| 133 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:450 |
| 134 | + [28] (::Base.var"#1138#1140"{Bool, Symbol, Bool})(REPL::Module) |
| 135 | + @ Base ./client.jl:446 |
| 136 | + [29] #invokelatest#2 |
| 137 | + @ ./essentials.jl:1054 [inlined] |
| 138 | + [30] invokelatest |
| 139 | + @ ./essentials.jl:1051 [inlined] |
| 140 | + [31] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool) |
| 141 | + @ Base ./client.jl:430 |
| 142 | + [32] repl_main |
| 143 | + @ ./client.jl:567 [inlined] |
| 144 | + [33] _start() |
| 145 | + @ Base ./client.jl:541 |
| 146 | +Test Summary: | Pass Error Total Time |
| 147 | +simple gradient | 1 1 2 1.3s |
| 148 | +ERROR: Some tests did not pass: 1 passed, 0 failed, 1 errored, 0 broken. |
| 149 | +
|
| 150 | +=# |
| 151 | + |
| 152 | +@testset "simple train!" begin |
| 153 | + X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32) |
| 154 | + Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1) |
| 155 | + # data = Flux.DataLoader((X, Y); batchsize=16, shuffle=true) |
| 156 | + data = Flux.DataLoader((X .+ 0f0, Y .+ 0f0); batchsize=16, shuffle=true) # this avoids some erros from conversion |
| 157 | + |
| 158 | + model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) |> Reactor |
| 159 | + state = Flux.setup(Adam(0.1, (0.7, 0.95)), model) # Note that I'm doing this after |> Reactor, ideally before would work too? |
| 160 | + |
| 161 | + Flux.train!(model, data, state; epochs=100) do m, x, y |
| 162 | + Flux.logitcrossentropy(m(x), y) |
| 163 | + end |
| 164 | + |
| 165 | + @test all((softmax(model(X)) .> 0.5) .== Y) |
| 166 | +end |
| 167 | + |
| 168 | +#= |
| 169 | +
|
| 170 | +[ Info: compiling |
| 171 | +simple train!: Error During Test at REPL[57]:1 |
| 172 | + Got exception outside of a @test |
| 173 | + type Array has no field data |
| 174 | + Stacktrace: |
| 175 | + [1] getproperty |
| 176 | + @ ./Base.jl:49 [inlined] |
| 177 | + [2] macro expansion |
| 178 | + @ ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:771 [inlined] |
| 179 | + [3] (::Reactant.Compiler.Thunk{Symbol("##_step!_reactant#1017757")})(::var"#9#10", ::Duplicated{ConcreteRArray{Float32, 1}}, ::Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, ::Tuple{Matrix{Float32}, Matrix{Float32}}, ::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}) |
| 180 | + @ Reactant.Compiler ~/.julia/packages/Reactant/sIJRJ/src/Compiler.jl:787 |
| 181 | + [4] macro expansion |
| 182 | + @ ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:332 [inlined] |
| 183 | + [5] macro expansion |
| 184 | + @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined] |
| 185 | + [6] train!(loss::Function, m::Reactor{Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random.TaskLocalRNG, Val{nothing}}, opt_state::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}; epochs::Int64) |
| 186 | + @ FluxReactantExt ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:324 |
| 187 | + [7] macro expansion |
| 188 | + @ REPL[57]:10 [inlined] |
| 189 | + [8] macro expansion |
| 190 | + @ /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/Test/src/Test.jl:1700 [inlined] |
| 191 | + [9] top-level scope |
| 192 | + @ REPL[57]:2 |
| 193 | + [10] eval |
| 194 | + @ ./boot.jl:430 [inlined] |
| 195 | + [11] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module) |
| 196 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:226 |
| 197 | + [12] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function) |
| 198 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:323 |
| 199 | + [13] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function) |
| 200 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:308 |
| 201 | + [14] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any) |
| 202 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:464 |
| 203 | + [15] run_repl(repl::REPL.AbstractREPL, consumer::Any) |
| 204 | + @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:450 |
| 205 | + [16] (::Base.var"#1138#1140"{Bool, Symbol, Bool})(REPL::Module) |
| 206 | + @ Base ./client.jl:446 |
| 207 | + [17] #invokelatest#2 |
| 208 | + @ ./essentials.jl:1054 [inlined] |
| 209 | + [18] invokelatest |
| 210 | + @ ./essentials.jl:1051 [inlined] |
| 211 | + [19] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool) |
| 212 | + @ Base ./client.jl:430 |
| 213 | + [20] repl_main |
| 214 | + @ ./client.jl:567 [inlined] |
| 215 | + [21] _start() |
| 216 | + @ Base ./client.jl:541 |
| 217 | +Test Summary: | Error Total Time |
| 218 | +simple train! | 1 1 14.3s |
| 219 | +ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken. |
| 220 | +
|
| 221 | +(jl_smIYmq) pkg> st |
| 222 | +Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_smIYmq/Project.toml` |
| 223 | + [587475ba] Flux v0.16.1 `~/.julia/dev/Flux` |
| 224 | + [3102ee7a] Fluxperimental v0.2.3 `~/.julia/dev/Fluxperimental` |
| 225 | + [3c362404] Reactant v0.2.10 |
| 226 | +
|
| 227 | +=# |
| 228 | + |
| 229 | +end # @testset "Reactant + Flux" |
0 commit comments