
Commit d5a374b

YichengDWu and mcabbott authored
fix type instability of Rprop (#103)
* fix type instability
* Update rules.jl
* fix
* better style
* add testing
* Update test/runtests.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent 31267ab commit d5a374b
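
What the fix addresses: `Rprop`'s hyper-parameters `ℓ` and `Γ` default to `Float32` values (and the new test below sets up a `Float64` rule). The old `apply!` used `o.ell` and `o.gamma` directly in its broadcasts, so the per-weight step sizes it re-creates on every call could be promoted to the rule's precision instead of keeping `eltype(x)`. The change below converts both tuples to `eltype(x)` once, up front. A minimal standalone sketch of the promotion (not code from this repository):

    η16 = fill(Float16(1e-3), 3)   # per-weight step sizes matching Float16 parameters
    ℓ   = (5f-1, 1.2f0)            # Rprop's default Float32 scale factors

    eltype(η16 .* ℓ[2])            # Float32 -- the re-created state array has widened

    T  = Float16                   # what the fix does: convert once to eltype(x)
    ℓT = map(T, ℓ)
    eltype(η16 .* ℓT[2])           # Float16 -- the element type is preserved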

File tree

2 files changed: +30, -17 lines

src/rules.jl

Lines changed: 12 additions & 10 deletions
@@ -25,7 +25,7 @@ init(o::Descent, x::AbstractArray) = nothing

 function apply!(o::Descent, state, x, dx)
   η = convert(float(eltype(x)), o.eta)
-
+
   return state, @lazy dx * η  # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
 end

@@ -51,7 +51,7 @@ init(o::Momentum, x::AbstractArray) = zero(x)
 function apply!(o::Momentum, state, x, dx)
   η, ρ, mvel = o.eta, o.rho, state
   @.. mvel = ρ * mvel + η * dx  # Macro @.. broadcasts into mvel if it can, else @. of rhs.
-
+
   return mvel, mvel
 end

@@ -79,7 +79,7 @@ function apply!(o::Nesterov, state, x, dx)

   newdx = @. - ρ^2 * vel + (1+ρ) * η * dx  # Cannot be lazy as this needs the old velocity
   @.. vel = ρ * vel - η * dx
-
+
   return vel, newdx
 end

@@ -125,7 +125,7 @@ function apply!(o::RMSProp, state, x, dx)
     @.. lin = ρ * lin + (1 - ρ) * dx
   end
   dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ)
-
+
   return (quad, lin), dx′
 end

@@ -152,7 +152,7 @@ learning algorithm that depends only on the sign of the gradient.
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
-
+
 - Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.

 - Step sizes (`Γ::Tuple`): Minimal and maximal allowed step sizes.
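
For reference, the parameters documented above are the three positional arguments of the constructor shown in the next hunk; a hedged usage sketch with the documented defaults:

    opt = Rprop()                                    # defaults: η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)
    opt = Rprop(1f-3, (5f-1, 1.2f0), (1f-6, 50f0))   # η, ℓ = (decrease, increase), Γ = (min, max) step size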
@@ -168,14 +168,16 @@ Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η,
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))

 function apply!(o::Rprop, state, x, dx)
-  ℓ, Γ = o.ell, o.gamma
+  T = eltype(x)
+  ℓ = map(T, o.ell)
+  Γ = map(T, o.gamma)
   g, η = state

   η = broadcast(g, η, dx) do g, η, dx
     g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
   end
   g = broadcast(g, dx) do g, dx
-    g * dx < 0 ? zero(dx) : dx
+    g * dx < 0 ? zero(T) : T(dx)
   end
   dx′ = @lazy η * sign(g)

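With `T = eltype(x)` threaded through the hunk above, the state arrays that `apply!` rebuilds keep the parameter array's element type. A hedged sketch of the effect through the public `setup`/`update` API (the arrays here are illustrative, not taken from the test suite):

    using Optimisers

    x  = Float16[1, 2, 3]
    st = Optimisers.setup(Rprop(0.1), x)                 # Float64 rule, Float16 parameters
    st2, x2 = Optimisers.update(st, x, Float64[1, 2, 3])

    eltype(x2)            # Float16
    eltype(st2.state[1])  # Float16 -- the stored gradient g is not widened
    eltype(st2.state[2])  # Float16 -- the per-weight step sizes η are not widened
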
@@ -384,7 +386,7 @@ function apply!(o::AdaDelta, state, x, dx)
   # DON'T remove epsilon from numerator or even out of the square roots!
   dx′ = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)  # Cannot be lazy as this needs the old Δacc
   @.. Δacc = ρ * Δacc + (1 - ρ) * abs2(dx′)
-
+
   return (acc, Δacc), dx′
 end

@@ -454,7 +456,7 @@ function apply!(o::NAdam, state, x, dx)

   @.. mt = β[1] * mt + (1 - β[1]) * dx
   @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
-  dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
+  dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
           (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

   return (mt, vt, βt .* β), dx′
@@ -508,7 +510,7 @@ function apply!(o::AdaBelief, state, x, dx)
   @.. mt = β[1] * mt + (1 - β[1]) * dx
   @.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt) + ϵ
   dx′ = @lazy η * mt / (1 - βt[1]) / (sqrt(st / (1 - βt[2])) + ϵ)
-
+
   return (mt, st, βt .* β), dx′
 end

test/runtests.jl

Lines changed: 18 additions & 7 deletions
@@ -22,7 +22,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   g = ([25, 33],)
   o = Descent(0.1)
   s = Optimisers.setup(o, m)
-
+
   s2, m2 = Optimisers.update(s, m, g)
   @test m[1] == 1:2  # not mutated
   @test Optimisers.maywrite(m[1])
@@ -129,8 +129,19 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   s4, m4 = Optimisers.update(s3, staticm, staticm)
   @test eltype(m4[1]) == Float16  # because of explicit broadcast in subtract!
   @test eltype(m4[2]) == Float32
+
+  # Rprop re-creates its state arrays, check they don't get widened:
+  s5 = Optimisers.setup(Rprop(0.1), m)  # Float64 rule
+  grad64 = ([1.0,2.0], SA[3.0,4.0])  # Float64 gradients
+  s6, m6 = Optimisers.update(s5, m, grad64)
+  @test eltype(m6[1]) == Float16
+  @test eltype(m6[2]) == Float32
+  @test eltype(s6[1].state[1]) == Float16
+  @test eltype(s6[1].state[2]) == Float16
+  @test eltype(s6[2].state[1]) == Float32
+  @test eltype(s6[2].state[2]) == Float32
 end
-
+
 @testset "adjusting parameters" begin
   # Simple momentum:
   m = (α = ([0.0], sin), γ = Float32[4,3,2])
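
The new assertions above reach into the optimiser state tree: `Optimisers.setup` on a tuple model returns a matching tuple of `Leaf`s, and for `Rprop` each `Leaf`'s state is the pair `(g, η)` created by `init`. A rough sketch of those shapes, assuming a stand-in model of the same form as the test's `m` (the exact arrays in runtests.jl may differ):

    m  = (Float16[1, 2], Float32[3, 4])   # hypothetical stand-in for the test model
    s5 = Optimisers.setup(Rprop(0.1), m)  # -> (Leaf(Rprop(0.1), (g, η)), Leaf(Rprop(0.1), (g, η)))

    s5[1].rule    # the shared Rprop(0.1) rule
    s5[1].state   # (g, η), both arrays with the same eltype as m[1], i.e. Float16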
@@ -139,25 +150,25 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   @test m.γ .- m1.γ ≈ [0.1, 1, 10]
   @test s1.γ.rule.eta == 0.1
   @test s1.γ.state ≈ [0.1, 1, 10]
-
+
   s2 = Optimisers.adjust(s1, 0.2)
   @test s2.γ.rule.eta == 0.2
   @test s2.γ.rule.rho == 0.9
   @test s2.γ.state == s1.γ.state
   @test s2.α[1].rule.eta == 0.2
   @test s2.α[1].state == s1.α[1].state
-
+
   s3 = Optimisers.adjust(s1; eta=0.3, rho=0.7)
   @test s3.γ.rule.eta == 0.3
   @test s3.γ.rule.rho == 0.7
   @test s3.γ.state == s1.γ.state
   @test s3.α[1].rule.rho == 0.7
-
+
   _, m3 = Optimisers.update(s3, m, (α = nothing, γ = [1,10,100],))
   @test !(m.γ .- m3.γ ≈ [1, 10, 100])

   @test s1 == Optimisers.adjust(s1, zeta = "this does nothing")
-
+
   # OptimiserChain
   sc = Optimisers.setup(OptimiserChain(ClipGrad(2), Adam()), m)
   sc1, mc1 = Optimisers.update(sc, m, (α = nothing, γ = [1,10,100],))
@@ -168,7 +179,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   @test sc2.γ.rule.opts[1].delta == 2  # unchanged
   @test sc2.γ.rule.opts[2].eta === 0.2f0
   @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
-
+
   sc2 = Optimisers.adjust(sc1; delta = 2.5)  # ClipGrad(2) does not store an Int, for this reason
   @test sc2.γ.rule.opts[1].delta == 2.5
   @test sc2.γ.rule.opts[2].eta === 0.001f0  # unchanged

0 commit comments
