Mark OptimiserChain as @functor and improve inference for apply! #115

Merged
merged 3 commits into from
Nov 27, 2022

jondeuce
Contributor

I noticed that OptimiserChain was not marked @functor; this PR marks it as such to allow one to modify the rules internal to OptimiserChain via fmap.

I also modified the optimiser chain state to be a Tuple instead of an AbstractArray for better type inference in apply!(o::OptimiserChain, ...). If this is considered breaking/undesired I can remove it from the PR. Example of the improved inference:

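using Optimisers
using InteractiveUtils  # provides @code_warntype outside the REPL

x = zeros(Float32, 3)
dx = zero(x)
rule = AdamW()  # an OptimiserChain of Adam and WeightDecay
state = Optimisers.setup(rule, x)
@code_warntype Optimisers.apply!(rule, state.state, x, dx)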

Before:

MethodInstance for Optimisers.apply!(::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}, ::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, ::Vector{Float32}, ::Vector{Float32})
  from apply!(o::OptimiserChain, states, x, dx, dxs...) in Optimisers at /home/jdoucette/.julia/dev/Optimisers/src/rules.jl:623
Arguments
  #self#::Core.Const(Optimisers.apply!)
  o::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}
  states::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
  x::Vector{Float32}
  dx@_5::Vector{Float32}
  dxs::Tuple{}
Locals
  @_7::Union{Nothing, Tuple{Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}}
  new_states::Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
  @_9::Int64
  @_10::Int64
  @_11::Int64
  state::Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}
  opt::Union{Adam{Float32}, WeightDecay{Float32}}
  i::Int64
  dx@_15::Any
Body::Tuple{Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Any}
1 ─       (dx@_15 = dx@_5)
│         (new_states = Optimisers.similar(states))
│   %3  = Base.getproperty(o, :opts)::Tuple{Adam{Float32}, WeightDecay{Float32}}
│   %4  = Optimisers.zip(%3, states)::Base.Iterators.Zip{Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}
│   %5  = Optimisers.enumerate(%4)::Base.Iterators.Enumerate{Base.Iterators.Zip{Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}
│         (@_7 = Base.iterate(%5))
│   %7  = (@_7::Union{Nothing, Tuple{Tuple{Int64, Tuple{Adam{Float32}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}} === nothing)::Bool
│   %8  = Base.not_int(%7)::Bool
└──       goto #4 if not %8
2 ┄ %10 = @_7::Tuple{Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}, Tuple{Int64, Tuple{Int64, Int64}}}
│   %11 = Core.getfield(%10, 1)::Tuple{Int64, Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}
│   %12 = Base.indexed_iterate(%11, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (i = Core.getfield(%12, 1))
│         (@_11 = Core.getfield(%12, 2))
│   %15 = Base.indexed_iterate(%11, 2, @_11::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Int64}, Any[Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Core.Const(3)])
│   %16 = Core.getfield(%15, 1)::Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}
│   %17 = Base.indexed_iterate(%16, 1)::Core.PartialStruct(Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Int64}, Any[Union{Adam{Float32}, WeightDecay{Float32}}, Core.Const(2)])
│         (opt = Core.getfield(%17, 1))
│         (@_10 = Core.getfield(%17, 2))
│   %20 = Base.indexed_iterate(%16, 2, @_10::Core.Const(2))::Core.PartialStruct(Tuple{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Int64}, Any[Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Core.Const(3)])
│         (state = Core.getfield(%20, 1))
│   %22 = Core.getfield(%10, 2)::Tuple{Int64, Tuple{Int64, Int64}}
│   %23 = Core.tuple(opt, state, x, dx@_15)::Tuple{Union{Adam{Float32}, WeightDecay{Float32}}, Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Vector{Float32}, Any}
│   %24 = Core._apply_iterate(Base.iterate, Optimisers.apply!, %23, dxs)::Tuple{Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Any}
│   %25 = Base.indexed_iterate(%24, 1)::Core.PartialStruct(Tuple{Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Int64}, Any[Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}, Core.Const(2)])
│   %26 = Core.getfield(%25, 1)::Union{Nothing, Tuple{Any, Any, Tuple{Float32, Float32}}}
│         (@_9 = Core.getfield(%25, 2))
│   %28 = Base.indexed_iterate(%24, 2, @_9::Core.Const(2))::Tuple{Any, Int64}
│         (dx@_15 = Core.getfield(%28, 1))
│         Base.setindex!(new_states, %26, i)
│         (@_7 = Base.iterate(%5, %22))
│   %32 = (@_7 === nothing)::Bool
│   %33 = Base.not_int(%32)::Bool
└──       goto #4 if not %33
3 ─       goto #2
4 ┄ %36 = Core.tuple(new_states, dx@_15)::Tuple{Vector{Union{Nothing, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, Any}
└──       return %36

After:

MethodInstance for Optimisers.apply!(::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}, ::Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, ::Vector{Float32}, ::Vector{Float32})
  from apply!(o::OptimiserChain, states, x, dx, dxs...) in Optimisers at /home/jdoucette/.julia/dev/Optimisers/src/rules.jl:625
Arguments
  #self#::Core.Const(Optimisers.apply!)
  o::OptimiserChain{Tuple{Adam{Float32}, WeightDecay{Float32}}}
  states::Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}
  x::Vector{Float32}
  dx::Vector{Float32}
  dxs::Tuple{}
Locals
  #81::Optimisers.var"#81#82"{Vector{Float32}, Tuple{}}
Body::Tuple{Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}}
1 ─ %1  = Optimisers.:(var"#81#82")::Core.Const(Optimisers.var"#81#82")
│   %2  = Core.typeof(x)::Core.Const(Vector{Float32})
│   %3  = Core.typeof(dxs)::Core.Const(Tuple{})
│   %4  = Core.apply_type(%1, %2, %3)::Core.Const(Optimisers.var"#81#82"{Vector{Float32}, Tuple{}})
│         (#81 = %new(%4, x, dxs))
│   %6  = #81::Optimisers.var"#81#82"{Vector{Float32}, Tuple{}}
│   %7  = (:init,)::Core.Const((:init,))
│   %8  = Core.apply_type(Core.NamedTuple, %7)::Core.Const(NamedTuple{(:init,)})
│   %9  = ()::Core.Const(())
│   %10 = Core.tuple(%9, dx)::Tuple{Tuple{}, Vector{Float32}}
│   %11 = Core.tuple(%10)::Tuple{Tuple{Tuple{}, Vector{Float32}}}
│   %12 = (%8)(%11)::NamedTuple{(:init,), Tuple{Tuple{Tuple{}, Vector{Float32}}}}
│   %13 = Core.kwfunc(Optimisers.foldl)::Core.Const(Base.var"#foldl##kw"())
│   %14 = Base.getproperty(o, :opts)::Tuple{Adam{Float32}, WeightDecay{Float32}}
│   %15 = Base.broadcasted(Optimisers.tuple, %14, states)::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(tuple), Tuple{Tuple{Adam{Float32}, WeightDecay{Float32}}, Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}}}
│   %16 = Base.materialize(%15)::Tuple{Tuple{Adam{Float32}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Tuple{WeightDecay{Float32}, Nothing}}
│   %17 = (%13)(%12, Optimisers.foldl, %6, %16)::Core.PartialStruct(Tuple{Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}}, Any[Tuple{Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}, Nothing}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), 
Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}}, Any[Core.Const(+), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), 
Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}}, Any[Core.Const(*), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Float32}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), 
Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}, Any[Vector{Float32}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}, Any[Core.Const(-), Core.PartialStruct(Tuple{Int64, Float32}, Any[Core.Const(1), Float32]), Core.Const(nothing)])]), Core.Const(nothing)]), Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Float32}}, Any[Core.Const(+), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, 
typeof(-), Tuple{Int64, Float32}}}}}}, Float32}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sqrt), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}}, Any[Core.Const(sqrt), Core.PartialStruct(Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}}, Any[Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(/), Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}}, Any[Core.Const(/), Core.PartialStruct(Tuple{Vector{Float32}, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}}, Any[Vector{Float32}, Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float32}}, Any[Core.Const(-), Core.PartialStruct(Tuple{Int64, Float32}, Any[Core.Const(1), Float32]), Core.Const(nothing)])]), Core.Const(nothing)])]), Core.Const(nothing)]), Float32]), Core.Const(nothing)])]), Core.Const(nothing)]), Float32]), Tuple{Base.OneTo{Int64}}]), Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Float32, Vector{Float32}}}]), Tuple{Base.OneTo{Int64}}])])
└──       return %17
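
The two listings differ mainly in the container that holds the per-rule state: a Vector whose element type is a Union of every rule's state type, versus a Tuple that records each state type by position. The following is a minimal, self-contained sketch of that idea. It does not use Optimisers.jl; `Scale`, `Momentum`, `apply_rule`, `chain_vec` and `chain_tup` are hypothetical names invented for illustration, and the `foldl` form only mirrors the shape visible in the "After" listing (`tuple.(o.opts, states)` folded with `init = ((), dx)`).

```julia
# Hypothetical stand-in rules; these are NOT the Optimisers.jl types.
struct Scale
    eta::Float32
end
struct Momentum
    rho::Float32
end

# Each rule maps (state, dx) to (new_state, new_dx), the same shape apply! returns.
apply_rule(r::Scale, ::Nothing, dx) = (nothing, r.eta .* dx)
function apply_rule(r::Momentum, v::Vector{Float32}, dx)
    v2 = r.rho .* v .+ dx
    return (v2, v2)
end

# Vector-backed state: the element type is a Union, so the compiler only
# knows that each state is "one of" the possible state types.
function chain_vec(rules, states::AbstractVector, dx)
    new_states = similar(states)
    for (i, (r, s)) in enumerate(zip(rules, states))
        new_states[i], dx = apply_rule(r, s, dx)
    end
    return new_states, dx
end

# Tuple-backed state: fold over (rule, state) pairs, growing a tuple of new
# states, so the type of every position is carried in the container type.
function chain_tup(rules::Tuple, states::Tuple, dx)
    foldl(tuple.(rules, states); init = ((), dx)) do (acc, g), (r, s)
        s2, g2 = apply_rule(r, s, g)
        return (acc..., s2), g2
    end
end

rules = (Scale(0.5f0), Momentum(0.9f0))
dx = ones(Float32, 3)

states_vec, dx_vec = chain_vec(rules, [nothing, zeros(Float32, 3)], dx)
states_tup, dx_tup = chain_tup(rules, (nothing, zeros(Float32, 3)), dx)

println(typeof(states_vec))  # Vector with a Union element type
println(typeof(states_tup))  # Tuple with one concrete type per rule
```

Both functions compute the same update; only the state container changes. Whether the compiler fully infers the `foldl` version for a given chain length and Julia version is something to check with `@code_warntype`, as done above, rather than assume.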

Member

@ToucheSir ToucheSir left a comment

Looks pretty reasonable to me. I recall we had some discussion about OptimiserChain's using a vector for state instead of a tuple, but not the conclusion. @darsnack might know.

@@ -32,13 +32,13 @@ julia> fieldnames(Adam)
 (:eta, :beta, :epsilon)

 julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
-(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), (nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))), fun = ())
Member

@ToucheSir ToucheSir Oct 26, 2022

I wouldn't have expected fun = nothing to change to fun = () with this PR. Is this the output you see locally? If so, can you change this codeblock to a jldoctest so that we catch changes like this in the future?

Member

Changed by #106 I think.

I guess we just need to remember the syntax for naming a doctest block, so that this one uses m, st from the one before.
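
For reference, Documenter.jl lets several `jldoctest` blocks share state by giving them the same label after the `jldoctest` keyword; blocks with the same label on a page are evaluated in order in one module. The sketch below assumes that mechanism. The label `shared_env` and the NamedTuple `m` are placeholders, not the definitions used in the Optimisers.jl docs.

```jldoctest shared_env
julia> m = (vec = Float32[1, 2], fun = sin);

julia> keys(m)
(:vec, :fun)
```

```jldoctest shared_env
julia> m.vec
2-element Vector{Float32}:
 1.0
 2.0
```

The second block can use `m` only because it carries the same label as the first; an unlabeled block would start from a fresh module.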

Contributor Author

@jondeuce jondeuce Oct 26, 2022

Yes, that change does not arise from this PR; sorry for the confusion. I changed nothing to () for consistency with the doctest above it. I can look up the syntax to make that block into a proper doctest as well, @mcabbott.

Contributor Author

Fixed now.

Member

@mcabbott mcabbott left a comment

We should do this, I think.

@mcabbott mcabbott merged commit 0b2d32b into FluxML:master Nov 27, 2022