
Commit 4155bcd

explicitly preserve eltype (#56)
1 parent adc0e85 commit 4155bcd
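A quick editorial illustration of the behaviour named in the commit title (not part of the diff; it mirrors the "eltype preservation" testset added in test/runtests.jl below):

```julia
using Optimisers

# Float32 parameters stay Float32 after an update, even though both the
# learning rate (0.1) and the gradient here are Float64 values.
x = Float32[1, 2, 3]
state = Optimisers.setup(Descent(0.1), x)
_, x2 = Optimisers.update(state, x, [0.1, 0.2, 0.3])
eltype(x2)  # Float32
```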

File tree: 5 files changed (+26 / -16 lines)

src/Optimisers.jl

Lines changed: 4 additions & 4 deletions
@@ -24,11 +24,11 @@ The initial state is `init(rule::RuleType, parameters)`.
 
 # Example
 ```jldoctest
-julia> Optimisers.init(Descent(0.1), [1,2,3]) === nothing
+julia> Optimisers.init(Descent(0.1), Float32[1,2,3]) === nothing
 true
 
-julia> Optimisers.apply!(Descent(0.1), nothing, [1,2,3], [4,5,6])
-(nothing, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, ([4, 5, 6], 0.1)))
+julia> Optimisers.apply!(Descent(0.1), nothing, Float32[1,2,3], [4,5,6])
+(nothing, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, ([4, 5, 6], 0.1f0)))
 ```
 """
 apply!
@@ -41,7 +41,7 @@ This and [`apply!`](@ref) are the two functions which any new optimisation rule
 
 # Examples
 ```jldoctest
-julia> Optimisers.init(Descent(), [1,2,3]) # is `nothing`
+julia> Optimisers.init(Descent(), Float32[1,2,3]) # is `nothing`
 
 julia> Optimisers.init(Momentum(), [1.0, 2.0])
 2-element Vector{Float64}:
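The docstring examples now use `Float32` parameters so that the displayed learning rate is `0.1f0`. A hedged check of what that output implies (a sketch, not taken from the diff): materialising the lazy `Broadcasted` returned by `apply!` keeps the parameter eltype.

```julia
using Optimisers

# Descent converts its 0.1 hyperparameter to the parameter eltype (that is
# what the 0.1f0 in the updated doctest shows), so materialising the lazy
# broadcast gives a Float32 result rather than Float64.
_, dx = Optimisers.apply!(Descent(0.1), nothing, Float32[1, 2, 3], [4, 5, 6])
eltype(Base.materialize(dx))  # Float32
```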

src/interface.jl

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@ function setup(rule, x; seen = Base.IdSet())
   end
 end
 
-subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
+subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
 update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
 update!(::Nothing, x, x̄s...) = nothing, x
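The `subtract!` change is central to the commit: when the parameters cannot be updated in place, the out-of-place result `x .- x̄` is converted back to `eltype(x)`. A standalone sketch of that branch (hypothetical helper name, not the package's code):

```julia
# Out-of-place branch only; the real subtract! also keeps the in-place
# `x .= x .- x̄` path for writeable dense arrays.
subtract_oop(x, x̄) = eltype(x).(x .- x̄)

x = Float32[1, 2, 3]
x̄ = [0.1, 0.2, 0.3]          # a Float64 gradient
eltype(x .- x̄)               # Float64: the plain broadcast promotes
eltype(subtract_oop(x, x̄))   # Float32: the eltype is restored
```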
@@ -50,10 +50,10 @@ end
 apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
 
 isnumeric(x::AbstractArray{<:Number}) = isleaf(x)  # isleaf to allow for e.g. transposed shared weights
-isnumeric(x::AbstractArray{<:Bool}) = false  # convention of ChainRules is that Bool is non-differentiable
+isnumeric(x::AbstractArray{<:Integer}) = false
 isnumeric(x) = false
 
-iswriteable(::DenseArray{<:AbstractFloat}) = true  # more elaborate versions are possible, wait until needed?
+iswriteable(::DenseArray) = true  # more elaborate versions are possible, wait until needed?
 iswriteable(_) = false
 
 """

src/rules.jl

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ This is equivalent to `Descent(1)`.
 ```jldoctest
 julia> o = OptimiserChain(ClipGrad(1), Descent(0.1));
 
-julia> m = ([0,0,0],);
+julia> m = (zeros(3),);
 
 julia> s = Optimisers.setup(o, m)
 (Leaf(OptimiserChain(ClipGrad{Int64}(1), Descent{Float64}(0.1)), [nothing, nothing]),)
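The doctest model changes from `([0,0,0],)` to `(zeros(3),)`, presumably because integer arrays no longer count as numeric leaves after the `isnumeric` change above, so an `Int` vector would no longer receive a `Leaf` in `setup`. A small sketch of the new predicate's effect (hypothetical name, simplified signature):

```julia
# Simplified restatement of the dispatch added in src/interface.jl; the real
# method for Number arrays also checks isleaf to handle shared weights.
isnumeric_sketch(x::AbstractArray{<:Number})  = true
isnumeric_sketch(x::AbstractArray{<:Integer}) = false   # covers Bool too
isnumeric_sketch(x) = false

isnumeric_sketch([0, 0, 0])  # false: an Int array gets no optimiser state
isnumeric_sketch(zeros(3))   # true: a Float64 array stays trainable
```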

test/rules.jl

Lines changed: 1 addition & 5 deletions
@@ -118,11 +118,7 @@ end
 # Static version is truly out-of-place:
 mstatic = (SA{Float32}[1,2], SA{Float64}[3,4])  # , SA{Float16}[5,6]) with Float16, all fail
 upstatic = Optimisers.update(Optimisers.setup(o, mstatic), mstatic, mstatic)[2]
-if o isa OptimiserChain && o.opts[2] isa ADAM  # These promote to Float64
-  @test_broken map(eltype, upstatic) == types[1:2]
-else
-  @test map(eltype, upstatic) == types[1:2]
-end
+@test map(eltype, upstatic) == types[1:2]
 @test upstatic[1] isa SVector
 
 # With ordinary Array gradient, what happens? Not so important!
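The deleted branch singled out chains whose second rule is `ADAM`, which previously promoted StaticArray parameters to `Float64`; the test now asserts eltype preservation unconditionally. A hedged reconstruction of that case (the surrounding loop, the first rule in the chain, and the `types` tuple are not shown in this hunk, so the exact combination is an assumption):

```julia
using StaticArrays, Optimisers

# An OptimiserChain ending in ADAM, applied to static arrays of mixed precision.
o = OptimiserChain(ClipGrad(1), ADAM())
mstatic = (SA{Float32}[1, 2], SA{Float64}[3, 4])
upstatic = Optimisers.update(Optimisers.setup(o, mstatic), mstatic, mstatic)[2]
map(eltype, upstatic)  # (Float32, Float64): the Float32 entry is no longer promoted
```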

test/runtests.jl

Lines changed: 17 additions & 3 deletions
@@ -38,7 +38,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   end
 
   @testset "gradient clipping" begin
-    m = (α = ([0], sin), γ = rand(3))
+    m = (α = ([0.0], sin), γ = rand(3))
     s1 = Optimisers.setup(ClipGrad(13), m)
     _, m1 = Optimisers.update(s1, m, (α = nothing, γ = [1,10,100],))
     @test m.γ .- m1.γ ≈ [1, 10, 13]
@@ -58,7 +58,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   end
 
   @testset "OptimiserChain" begin
-    x = [1,10,100]; dx = [1,2,3];
+    x = [1, 10, 100.0]; dx = [1, 2, 3.0];
     @test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
     @test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]

@@ -81,7 +81,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
 
   @testset "trainable subset" begin
     # Foo has an old-style tuple trainable, both elements
-    mf = Foo([1,2], (a = sin, b = [3,4], c = 5))
+    mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5))
     sf = Optimisers.setup(Descent(0.1), mf)
     gf = (x = nothing, y = (a = nothing, b = [1,1], c = 1))
     _, mf2 = Optimisers.update(sf, mf, gf)
@@ -116,6 +116,20 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
     @test Optimisers.update!(s, m, g...)[2] isa Foo
   end
 
+  @testset "eltype preservation" begin
+    m = (Float16[1,2], Float32[3,4])
+    s1 = Optimisers.setup(Descent(0.1), m)
+    s2, m2 = Optimisers.update(s1, m, m)
+    @test eltype(m2[1]) == Float16  # because update copies & calls update!
+    @test eltype(m2[2]) == Float32
+
+    staticm = (SA{Float16}[1,2], SA{Float32}[3,4])
+    s3 = Optimisers.setup(Descent(0.1), staticm)
+    s4, m4 = Optimisers.update(s3, staticm, staticm)
+    @test eltype(m4[1]) == Float16  # because of explicit broadcast in subtract!
+    @test eltype(m4[2]) == Float32
+  end
+
   @testset "forgotten gradient" begin
     x = [1.0, 2.0]
     sx = Optimisers.setup(Descent(), x)
