diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a067034cee..94df9ddec4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,7 +65,7 @@ jobs: - uses: codecov/codecov-action@v5 if: matrix.version == '1' && matrix.os == 'ubuntu-latest' with: - file: lcov.info + files: lcov.info docs: name: Documentation diff --git a/NEWS.md b/NEWS.md index db05c18067..1d0b547b7f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,9 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl Now Flux re-exports the optimisers from Optimisers.jl. Most users will be uneffected by this change. The module is still available for now, but will be removed in a future release. * Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible. +* `Flux.params` has been deprecated. Use Zygote's explicit differentiation instead, +`gradient(m -> loss(m, x, y), model)`, or use `Flux.trainables(model)` to get the trainable parameters. +* Flux now requires Functors.jl v0.5. This new release of Functors assumes all types to be functors by default. Therefore, applying `@layer` or `@functor` to a type is no longer strictly necessary for Flux's models. However, it is still recommended to use `@layer Model` for additional functionality like pretty printing. ## v0.14.22 * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl). diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 57d87c57b7..ff4d464b25 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -12,11 +12,6 @@ Metal GPU acceleration is available on Apple Silicon hardware. For more details In order to trigger GPU support in Flux, you need to call `using CUDA`, `using AMDGPU` or `using Metal` in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not required, but the package has to be installed in the environment. - -!!! compat "Flux ≤ 0.13" - Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is not a dependency anymore and has to be installed manually. - - ## Basic GPU Usage Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl). diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md index 5d5fca413a..3bb4358afe 100644 --- a/docs/src/guide/models/basics.md +++ b/docs/src/guide/models/basics.md @@ -226,7 +226,7 @@ m(5) # => 26 ## Layer Helpers -There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref Flux.train!) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: +We can give our layer some additional functionality, like nice printing, using the [`@layer`](@ref Flux.@layer) macro: ```julia Flux.@layer Affine diff --git a/docs/src/guide/models/custom_layers.md b/docs/src/guide/models/custom_layers.md index 01942c5ea2..723016dd00 100644 --- a/docs/src/guide/models/custom_layers.md +++ b/docs/src/guide/models/custom_layers.md @@ -18,7 +18,7 @@ function (m::CustomModel)(x) return m.chain(x) + x end -# Call @layer to allow for training. Described below in more detail. 
+# This is optional but recommended for pretty printing and other niceties Flux.@layer CustomModel ``` Notice that we parameterized the type of the `chain` field. This is necessary for fast Julia code, so that that struct field can be given a concrete type. `Chain`s have a type parameter fully specifying the types of the layers they contain. By using a type parameter, we are freeing Julia to determine the correct concrete type, so that we do not need to specify the full, possibly quite long, type ourselves. @@ -78,7 +78,7 @@ The exact same method of `trainable` can also be defined using the macro, for co Flux.@layer Affine trainable=(W,) ``` -There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. +There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. ## Custom multiple input or output layer @@ -87,7 +87,7 @@ Sometimes a model needs to receive several separate inputs at once or produce se We could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. That would mean a new struct for each different block, e.g. one would have a `TransformerBlock` struct for a transformer block, and a `ResNetBlock` struct for a ResNet block, each block being composed by smaller sub-blocks. This is often the simplest and cleanest way to implement complex models. -This guide instead will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. +This guide instead will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. Note that using the layers described below may make the definition of your model harder to read and to change. In that case, consider using the simpler approach of defining a custom structure, described above. ### Multiple inputs: a custom `Join` layer diff --git a/docs/src/guide/models/quickstart.md b/docs/src/guide/models/quickstart.md index 664f56ff04..a0c92e0ef3 100644 --- a/docs/src/guide/models/quickstart.md +++ b/docs/src/guide/models/quickstart.md @@ -5,48 +5,53 @@ If you have used neural networks before, then this simple example might be helpf If you haven't, then you might prefer the [Fitting a Straight Line](overview.md) page. ```julia -# This will prompt if neccessary to install everything, including CUDA: +# This will prompt if necessary to install everything, including CUDA. +# For CUDA acceleration, cuDNN.jl also has to be installed in your environment.
using Flux, CUDA, Statistics, ProgressMeter # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix: noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32} truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool} +# Use this object to move data and model to the GPU, if available +device = gpu_device() + # Define our model, a multi-layer perceptron with one hidden layer of size 3: model = Chain( - Dense(2 => 3, tanh), # activation function inside layer + Dense(2 => 3, tanh), # activation function inside layer BatchNorm(3), - Dense(3 => 2)) |> gpu # move model to GPU, if available + Dense(3 => 2)) |> device # move model to GPU, if available # The model encapsulates parameters, randomly initialised. Its initial output is: -out1 = model(noisy |> gpu) |> cpu # 2×1000 Matrix{Float32} -probs1 = softmax(out1) # normalise to get probabilities +out1 = model(noisy |> device) |> cpu # 2×1000 Matrix{Float32} +probs1 = softmax(out1) # normalise to get probabilities # To train the model, we use batches of 64 samples, and one-hot encoding: target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix -loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true); -# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix) +loader = Flux.DataLoader((noisy, target), batchsize=64, shuffle=true); -optim = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. +opt_state = Flux.setup(Flux.Adam(0.01), model) # will store optimiser momentum, etc. # Training loop, using the whole data set 1000 times: losses = [] @showprogress for epoch in 1:1_000 for (x, y) in loader + x, y = device((x, y)) loss, grads = Flux.withgradient(model) do m # Evaluate model and loss inside gradient context: y_hat = m(x) Flux.logitcrossentropy(y_hat, y) end - Flux.update!(optim, model, grads[1]) + Flux.update!(opt_state, model, grads[1]) push!(losses, loss) # logging, outside gradient context end end -optim # parameters, momenta and output have all changed -out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) -probs2 = softmax(out2) # normalise to get probabilities -mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far! +opt_state # parameters, momenta and output have all changed + +out2 = model(noisy |> device) |> cpu # first row is prob. of true, second row p(false) +probs2 = softmax(out2) # normalise to get probabilities +mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far! ``` ![](../../assets/quickstart/oneminute.png) @@ -95,9 +100,13 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. ```julia for epoch in 1:1_000 - Flux.train!(model, loader, optim) do m, x, y + Flux.train!(model, loader, opt_state) do m, x, y + x, y = device((x, y)) y_hat = m(x) Flux.logitcrossentropy(y_hat, y) end end ``` + +* In our simple example, we conveniently created the model as a [`Chain`](@ref Flux.Chain) of layers. +For more complex models, you can define a custom struct `MyModel` containing layers and arrays and implement the call operator `(::MyModel)(x) = ...` to define the forward pass. This is all that is needed for Flux to work. Marking the struct with [`Flux.@layer`](@ref) will add some more functionality, like pretty printing and the ability to mark some internal fields as trainable or not (also see [`trainable`](@ref Optimisers.trainable)).
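To illustrate the explicit style that the NEWS entry above recommends in place of `Flux.params`, here is a minimal sketch; the model, loss and data are made up for illustration:

```julia
using Flux

# Hypothetical model and loss, only to illustrate the explicit style:
model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
loss(m, x, y) = Flux.mse(m(x), y)

x = rand(Float32, 2, 8)
y = rand(Float32, 1, 8)

# Explicit differentiation: the gradient mirrors the model's structure.
grads = Flux.gradient(m -> loss(m, x, y), model)[1]

# Flux.trainables(model) collects the trainable parameter arrays,
# e.g. to count them or to build a penalty term.
ps = Flux.trainables(model)
n_params = sum(length, ps)
```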
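The closing note of the quickstart above suggests a custom struct with a call operator instead of a `Chain`. A sketch of that pattern, where `MyModel`, `encoder` and `decoder` are illustrative names rather than Flux API:

```julia
using Flux

# With Functors.jl v0.5 the fields are visible to Flux without any macro;
# `Flux.@layer` remains optional but adds pretty printing and other niceties.
struct MyModel{E,D}
    encoder::E
    decoder::D
end

(m::MyModel)(x) = m.decoder(m.encoder(x))  # the forward pass is the call operator

Flux.@layer MyModel

model = MyModel(Dense(4 => 8, relu), Dense(8 => 2))
opt_state = Flux.setup(Adam(0.01), model)  # sees the parameters inside both fields
```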
diff --git a/docs/src/reference/models/functors.md b/docs/src/reference/models/functors.md index 1637a7b8a6..1768c99f93 100644 --- a/docs/src/reference/models/functors.md +++ b/docs/src/reference/models/functors.md @@ -4,7 +4,7 @@ CollapsedDocStrings = true # Recursive transformations from Functors.jl -Flux models are deeply nested structures, and [Functors.jl](https://github.com/FluxML/Functors.jl) provides tools needed to explore such objects, apply functions to the parameters they contain, and re-build them. +Flux models are deeply nested structures, and [Functors.jl](https://github.com/FluxML/Functors.jl) provides tools needed to explore such objects, apply functions to the parameters they contain (e.g. for moving them to the GPU), and re-build them. !!! compat "Flux ≤ 0.14" All layers were previously defined with the `Functors.@functor` macro. @@ -12,6 +12,9 @@ Flux models are deeply nested structures, and [Functors.jl](https://github.com/F Both allow [`Flux.setup`](@ref Flux.setup) to see the parameters inside, and [`gpu`](@ref) to move them to the GPU, but [`Flux.@layer`](@ref Flux.@layer) also overloads printing, and offers a way to define `trainable` at the same time. +!!! compat "Functors 0.5" + With Functors.jl v0.5, which is required by Flux v0.15 and later, every custom type is a functor by default. This means that applying `Flux.@layer` to a type is no longer strictly necessary, but it is still recommended for additional features like pretty-printing and `trainable`. + `Functors.jl` has its own [notes on basic usage](https://fluxml.ai/Functors.jl/stable/#Basic-Usage-and-Implementation) for more details. Additionally, the [Advanced Model Building and Customisation](@ref man-advanced) page covers the use cases of `Functors` in greater details. ```@docs diff --git a/perf/recurrent.jl b/perf/recurrent.jl index 1550009bd3..d60d4912fb 100644 --- a/perf/recurrent.jl +++ b/perf/recurrent.jl @@ -3,7 +3,6 @@ struct RNNWrapper{T} rnn::T end -Flux.@functor RNNWrapper # Need to specialize for RNNWrapper. fw(r::RNNWrapper, X::Vector{<:AbstractArray}) = begin diff --git a/src/Flux.jl b/src/Flux.jl index 0ddea4a764..189db6d6c7 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -92,7 +92,9 @@ include("train.jl") using .Train using .Train: setup -using Adapt, Functors, OneHotArrays +using Adapt, OneHotArrays +using Functors: Functors, fmap, fmapstructure + include("utils.jl") include("functor.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 1b85345f5d..ab8f711f89 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -64,17 +64,6 @@ const FluxMetalAdaptor = MetalDevice ######## v0.15 deprecations ######################### -# Enable these when 0.16 is released, and delete const ClipGrad = Optimise.ClipValue etc: -# Base.@deprecate_binding Optimiser OptimiserChain -# Base.@deprecate_binding ClipValue ClipGrad - -# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( -# """On Flux 0.16, `train!` no longer accepts implicit `Zygote.Params`. -# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` -# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` -# where `loss_mxy` accepts the model as its first argument. -# """ -# )) function reset!(x) Base.depwarn("reset!(m) is deprecated. You can remove this call as it is no more needed.", :reset!)
@@ -87,7 +76,6 @@ function params!(p::Zygote.Params, x, seen = IdSet()) elseif x in seen nothing else - _check_new_macro(x) # complains if you used @functor not @layer push!(seen, x) for child in trainable(x) params!(p, child, seen) @@ -126,3 +114,19 @@ function Optimisers.update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tu `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`. """) end + + +### v0.16 deprecations #################### + + +# Enable these when 0.16 is released, and delete const ClipGrad = Optimise.ClipValue etc: +# Base.@deprecate_binding Optimiser OptimiserChain +# Base.@deprecate_binding ClipValue ClipGrad + +# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( +# """On Flux 0.16, `train!` no longer accepts implicit `Zygote.Params`. +# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` +# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` +# where `loss_mxy` accepts the model as its first argument. +# """ +# )) diff --git a/src/functor.jl b/src/functor.jl index e8c02b919f..e049959d09 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -1,9 +1,3 @@ -import Adapt: adapt, adapt_storage -using LinearAlgebra: Cholesky -using Zygote: IdSet -import Functors: Functors, @functor, functor, fmap, isleaf -using SparseArrays: AbstractSparseArray - """ testmode!(model, [mode]) -> model @@ -85,7 +79,7 @@ end cpu(m) Copies `m` onto the CPU, the opposite of [`gpu`](@ref). -Recurses into structs marked [`@functor`](@ref). +Recurses into structs (thanks to Functors.jl). # Example ```julia-repl @@ -125,16 +119,14 @@ end Copies `m` to the current GPU device (using current GPU backend), if one is available. If no GPU is available, it does nothing (but prints a warning the first time). - -On arrays, this calls CUDA's `cu`, which also changes arrays -with Float64 elements to Float32 while copying them to the device (same for AMDGPU). -To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref). +It recurses into structs according to Functors.jl. Use [`cpu`](@ref) to copy back to ordinary `Array`s. See also [`f32`](@ref) and [`f16`](@ref) to change element type only. -See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/) -to help identify the current device. +This function is just defined for convenience around [`gpu_device`](@ref), +and is equivalent to `gpu_device()(m)`. +You may consider defining `device = gpu_device()` once and then using `device(m)` to move data. # Example ```julia-repl @@ -153,10 +145,6 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer} """ gpu(x) = gpu_device()(x) -# TODO remove after https://github.com/LuxDL/Lux.jl/pull/1089 -ChainRulesCore.@non_differentiable gpu_device() -ChainRulesCore.@non_differentiable gpu_device(::Any) - # Precision struct FluxEltypeAdaptor{T} end @@ -222,10 +210,6 @@ Chain( """ f16(m) = _paramtype(Float16, m) -# Functors for certain Julia data structures -- PIRACY, should move to Functors.jl -@functor Cholesky -trainable(c::Cholesky) = () - """ gpu(data::DataLoader) diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 9f9d0435ec..56ead6dbcf 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -4,12 +4,8 @@ @layer :expand Chain @layer BatchNorm trainable=(β,γ) -This macro replaces most uses of `@functor`. 
Its basic purpose is the same: -When you define a new layer, this tells Flux to explore inside it -to see the parameters it trains, and also to move them to the GPU, change precision, etc. - -Like `@functor`, this assumes your struct has the default constructor, to enable re-building. -If you define an inner constructor (i.e. a function within the `struct` block) things may break. +This macro adds convenience functionality to a custom type to serve +as a neural network layer, module, or entire model. The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. @@ -30,15 +26,9 @@ julia> struct Trio; a; b; c end julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4)) Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) -julia> Flux.destructure(tri) # parameters are not yet visible to Flux -(Bool[], Restructure(Trio, ..., 0)) - julia> Flux.@layer :expand Trio -julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too -([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4)) - -julia> tri # and layer is printed like Chain +julia> tri # now the layer is printed like Chain Trio( Dense(2 => 1, tanh), # 3 parameters Dense(1 => 1; bias=false), # 1 parameters @@ -48,6 +38,10 @@ Trio( """ macro layer(exs...) + _layer_macro(exs...) +end + +function _layer_macro(exs...) out = quote end # These functions are defined in show.jl, and each return an expression overloading Base.show @@ -62,11 +56,8 @@ macro layer(exs...) push!(out.args, _macro_layer_show(esc(exs[1]))) exs end - - # This function exists only for depwarns when you use @functor directly - push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) - push!(out.args, _macro_functor(esc(type))) + push!(out.args, _macro_adapt(esc(type))) for j in 1:length(rest) ex = rest[j] @@ -82,61 +73,16 @@ macro layer(exs...) push!(out.args, _macro_trainable(esc(type), name, ex.args[2])) end - out + return out end -# Temporary depwarn function, called within `params`, is also called by `show`. 
- -function _check_new_macro(x::T) where T - Functors.isleaf(x) && return - Base.depwarn(LazyString("This type should probably now use `Flux.@layer` instead of `@functor`: ", T), Symbol("@functor")) -end -_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users -_check_new_macro(::NamedTuple) = nothing -_check_new_macro(::AbstractArray) = nothing -_check_new_macro(::Ref) = nothing - -# @layer's code for Functors & Adapt -# Unlike @functor, _default_functor doesn't need to eval anything - -function _macro_functor(type) +# @layer's code for Adapt +function _macro_adapt(type) quote - Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x) Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) end end -function _macro_functor(type, fields) - Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") - symbols = Tuple(map(_noquotenode, fields.args)) - quote - Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols)) - Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) - end -end -_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma - -function _default_functor(::Type{T}, x) where {T} - if @generated - F = fieldnames(T) - args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) - C = Base.typename(T).wrapper # constructor - # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) - recon = :(Base.splat($C)) - :((NamedTuple{$F}(($(args...),)), $recon)) - else - # Getting this parameterless type takes about 2μs, every time: - # spl = VERSION > v"1.9-" ? Splat : Base.splat - spl = Base.splat - namedtuple(x), spl(Base.typename(T).wrapper) - end -end - -function namedtuple(x::T) where T - F = fieldnames(T) - NamedTuple{F}(map(sy -> getfield(x, sy), F)) -end - # @layer's code for Optimisers.trainable, and perhaps anything else, # with the pattern that keywords mean function names & what fields they pick. diff --git a/src/layers/show.jl b/src/layers/show.jl index f3fc170ec5..b68340886d 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -20,13 +20,13 @@ function _macro_big_show(ex) end function _big_show(io::IO, obj, indent::Int=0, name=nothing) - pre, post = _show_pre_post(obj) children = _show_children(obj) if all(_show_leaflike, children) # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids, # but once all layers use @layer, they stop the recursion by defining a method for _big_show. _layer_show(io, obj, indent, name) else + pre, post = _show_pre_post(obj) println(io, " "^indent, isnothing(name) ? "" : "$name = ", pre) if obj isa Chain{<:NamedTuple} || obj isa NamedTuple # then we insert names -- can this be done more generically? @@ -66,7 +66,7 @@ _show_pre_post(obj) = string(nameof(typeof(obj)), "("), ")" _show_pre_post(::AbstractVector) = "[", "]" _show_pre_post(::NamedTuple) = "(;", ")" -_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: +_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error _show_leaflike(::Tuple{Vararg{Number}}) = true # e.g. stride of Conv @@ -146,7 +146,7 @@ function _big_finale(io::IO, m) end _childarray_sum(f, x::AbstractArray{<:Number}) = f(x) -_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x), +_childarray_sum(f, x) = Functors.isleaf(x) ? 
0 : sum(y -> _childarray_sum(f, y), Functors.children(x), init=0) # utility functions diff --git a/src/outputsize.jl b/src/outputsize.jl index c413405048..d45e3d7f6f 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -284,7 +284,7 @@ function (l::LazyLayer)(x::AbstractArray, ys::AbstractArray...) end function striplazy(m) - fs, re = functor(m) + fs, re = Functors.functor(m) re(map(striplazy, fs)) end function striplazy(l::LazyLayer) diff --git a/src/utils.jl b/src/utils.jl index 20d16596ed..f5b2f2c337 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -609,7 +609,7 @@ end Return an iterator over non-leaf objects that can be reached by recursing `m` over -the children given by [`functor`](@ref). +the children given by [`Functors.functor`](@ref). Useful for applying a function (e.g. a regularizer) over specific modules or subsets of the parameters diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 282d08911d..7334e31f59 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -76,7 +76,6 @@ end end SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ) (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias) - @functor SimpleDense model = SimpleDense(2, 4) x = randn(Float32, 2) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 666e3e761b..5071e57e95 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -293,7 +293,6 @@ using Flux: activations x end (l::L1)(x) = l.x * x - Flux.@functor L1 Base.:*(a::AbstractArray, b::Input) = a * b.x par = Parallel(+, L1(rand(Float32, 3,3)), L1(rand(Float32, 3,3))) diff --git a/test/layers/macro.jl b/test/layers/macro.jl index e41d5a2240..53585fb427 100644 --- a/test/layers/macro.jl +++ b/test/layers/macro.jl @@ -37,7 +37,13 @@ end m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) @test m23re isa MacroTest.TwoThirds - @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) + + function namedtuple(x::T) where T + F = fieldnames(T) + NamedTuple{F}(map(sy -> getfield(x, sy), F)) + end + + @test namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) @test Optimisers.trainable(m23) == (a = [1 2],) diff --git a/test/utils.jl b/test/utils.jl index 8372e33b69..eba026bf2b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -623,7 +623,6 @@ end paths::T end Split(paths...) = Split(paths) - Flux.@functor Split (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) n_input, n_batch, n_shared = 5, 13, 11
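The updated `gpu` docstring above describes it as a convenience equivalent to `gpu_device()(m)`. A minimal sketch of the recommended pattern, assuming a GPU backend package may or may not be loaded (without one, `device` falls back to the CPU):

```julia
using Flux
# using CUDA  # load a backend such as CUDA.jl to actually get GPU execution

device = gpu_device()            # current GPU backend, or a CPU fallback
model = Dense(2 => 3) |> device  # same effect as `model |> gpu`
x = rand(Float32, 2, 16) |> device

y = model(x)                     # runs on the selected device
y_cpu = y |> cpu                 # copy the result back to an ordinary Array
```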
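The `@layer` docstring and the custom-layers guide above keep the `trainable=` keyword. A sketch of how it restricts which fields the optimiser sees, reusing the guide's `Affine` layer with an illustrative constructor:

```julia
using Flux

struct Affine
    W
    b
end
Affine(in::Int, out::Int) = Affine(randn(Float32, out, in), zeros(Float32, out))
(m::Affine)(x) = m.W * x .+ m.b

Flux.@layer Affine trainable=(W,)  # only W will be updated during training

a = Affine(2, 3)
Flux.trainables(a)                 # contains only W; the bias b is excluded
```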
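Finally, a small sketch of the Functors.jl tooling that the reference page and the code above rely on; the model is illustrative, and Functors.jl is assumed to be available in the environment (it is a direct dependency of Flux):

```julia
using Flux
import Functors  # add Functors.jl to the environment if it is not already there

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))

# fmap walks every leaf and rebuilds the structure; here it widens all arrays.
model64 = Functors.fmap(x -> x isa AbstractArray{<:Number} ? Float64.(x) : x, model)

# functor splits one layer into its children plus a re-building function.
children, rebuild = Functors.functor(model[1])
children.weight   # the first Dense layer's weight matrix
rebuild(children) # reconstructs an equivalent Dense layer
```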