@@ -25,7 +25,7 @@ init(o::Descent, x::AbstractArray) = nothing
 
 function apply!(o::Descent, state, x, dx)
   η = convert(float(eltype(x)), o.eta)
-
+
   return state, @lazy dx * η  # @lazy creates a Broadcasted, will later fuse with x .= x .- dx
 end
 
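For context, a minimal sketch of how this rule is exercised end to end through the package's exported `setup`/`update!` API; the array values are illustrative:

```julia
using Optimisers

x  = Float32[1.0, 2.0, 3.0]             # parameters (illustrative)
dx = Float32[0.1, 0.1, 0.1]             # gradient (illustrative)
st = Optimisers.setup(Descent(0.1), x)  # Descent keeps no state (init returns nothing)
st, x = Optimisers.update!(st, x, dx)   # fuses the lazy dx * η into x .= x .- dx * η
```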
@@ -51,7 +51,7 @@ init(o::Momentum, x::AbstractArray) = zero(x)
 function apply!(o::Momentum, state, x, dx)
   η, ρ, mvel = o.eta, o.rho, state
   @.. mvel = ρ * mvel + η * dx  # Macro @.. broadcasts into mvel if it can, else @. of rhs.
-
+
   return mvel, mvel
 end
 
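A standalone sketch of the recurrence the `@..` line implements; `momentum_step!` is a hypothetical helper, not the package code (which returns `mvel` and lets `update!` do the subtraction):

```julia
# Hypothetical helper illustrating classical momentum; the η, ρ defaults are made up.
function momentum_step!(x, mvel, dx; η = 0.01, ρ = 0.9)
    @. mvel = ρ * mvel + η * dx  # same in-place recurrence as the @.. line above
    @. x = x - mvel              # in the package, update! performs this subtraction
    return x, mvel
end
```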
@@ -79,7 +79,7 @@ function apply!(o::Nesterov, state, x, dx)
 
   newdx = @. -ρ^2 * vel + (1 + ρ) * η * dx  # Cannot be lazy as this needs the old velocity
   @.. vel = ρ * vel - η * dx
-
+
   return vel, newdx
 end
 
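The reasoning the comment compresses: `newdx` reads `vel`, and the next line overwrites `vel` in place; a lazy `Broadcasted` would be evaluated only after `vel` had changed, so `newdx` must be materialised eagerly with `@.`. A scalar illustration (values made up):

```julia
η, ρ = 0.001, 0.9
vel, dx = 0.5, 2.0
newdx = -ρ^2 * vel + (1 + ρ) * η * dx  # reads the OLD velocity: -0.405 + 0.0038
vel   = ρ * vel - η * dx               # only now is the velocity advanced
```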
@@ -125,7 +125,7 @@ function apply!(o::RMSProp, state, x, dx)
     @.. lin = ρ * lin + (1 - ρ) * dx
   end
   dx′ = @lazy dx * η / (sqrt(quad - abs2(lin)) + ϵ)
-
+
   return (quad, lin), dx′
 end
 
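In the centred variant shown here, `quad` tracks a running mean of `dx^2` and `lin` a running mean of `dx`, so `quad - abs2(lin)` estimates the gradient's variance, and the step is normalised by its square root. A scalar sketch (illustrative numbers):

```julia
ρ, η, ϵ = 0.9, 0.001, 1e-8
quad, lin, dx = 0.04, 0.1, 0.2
quad = ρ * quad + (1 - ρ) * abs2(dx)          # running mean of dx^2
lin  = ρ * lin + (1 - ρ) * dx                 # running mean of dx (centred only)
dx′  = dx * η / (sqrt(quad - abs2(lin)) + ϵ)  # variance-normalised step
```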
@@ -152,7 +152,7 @@ learning algorithm that depends only on the sign of the gradient.
 # Parameters
 - Learning rate (`η`): Amount by which gradients are discounted before updating
   the weights.
-
+
 - Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
 
 - Step sizes (`Γ::Tuple`): Minimal and maximal allowed step sizes.
@@ -168,14 +168,16 @@ Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η,
 init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
 
 function apply!(o::Rprop, state, x, dx)
-  ℓ, Γ = o.ell, o.gamma
+  T = eltype(x)
+  ℓ = map(T, o.ell)
+  Γ = map(T, o.gamma)
   g, η = state
 
   η = broadcast(g, η, dx) do g, η, dx
     g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
   end
   g = broadcast(g, dx) do g, dx
-    g * dx < 0 ? zero(dx) : dx
+    g * dx < 0 ? zero(T) : T(dx)
   end
   dx′ = @lazy η * sign(g)
 
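The substantive change in this hunk: `ℓ` and `Γ` are converted to `T = eltype(x)` up front, presumably so the per-weight step sizes `η` keep the parameters' element type instead of being promoted by the `Float32` defaults, and `zero(T)`/`T(dx)` pins the sign memory `g` likewise. A minimal illustration of the promotion being avoided (types only, values made up):

```julia
ℓ2  = 1.2f0                 # Float32 default scaling factor from the constructor
η16 = Float16(1e-3)         # a step size matching eltype(x) == Float16
typeof(η16 * ℓ2)            # Float32 -- mixed arithmetic widens the state
typeof(Float16(ℓ2) * η16)   # Float16 -- converting first preserves eltype(x)
```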
@@ -384,7 +386,7 @@ function apply!(o::AdaDelta, state, x, dx)
   # DON'T remove epsilon from numerator or even out of the square roots!
   dx′ = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)  # Cannot be lazy as this needs the old Δacc
   @.. Δacc = ρ * Δacc + (1 - ρ) * abs2(dx′)
-
+
   return (acc, Δacc), dx′
 end
 
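Why the shouting comment matters: on the very first step `Δacc == 0`, so without ϵ inside the numerator's square root `dx′` would be exactly zero, `Δacc` would then stay zero forever, and the optimiser would never move; ϵ under both roots bootstraps the recursion. Scalar check (illustrative numbers):

```julia
ρ, ϵ = 0.9, 1e-8
acc, Δacc, dx = 0.0, 0.0, 0.5
acc  = ρ * acc + (1 - ρ) * abs2(dx)         # acc is updated before dx′ in the rule
dx′  = dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)  # small but nonzero, thanks to ϵ
Δacc = ρ * Δacc + (1 - ρ) * abs2(dx′)       # now nonzero: the recursion can start
```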
@@ -454,7 +456,7 @@ function apply!(o::NAdam, state, x, dx)
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
   @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
-  dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
+  dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
               (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η
 
   return (mt, vt, βt .* β), dx′
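On the returned state: `βt` carries the running powers `(β[1]^t, β[2]^t)` used in the bias corrections above, and `βt .* β` advances both by one step. Illustrative:

```julia
β  = (0.9, 0.999)
βt = β          # state after step 1
βt = βt .* β    # state after step 2: (0.81, 0.998001)
```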
@@ -508,7 +510,7 @@ function apply!(o::AdaBelief, state, x, dx)
   @.. mt = β[1] * mt + (1 - β[1]) * dx
   @.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt) + ϵ
   dx′ = @lazy η * mt / (1 - βt[1]) / (sqrt(st / (1 - βt[2])) + ϵ)
-
+
   return (mt, st, βt .* β), dx′
 end
 
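AdaBelief's one departure from Adam is the `st` line: it accumulates `abs2(dx - mt)`, the squared deviation of the gradient from its own moving average, rather than `abs2(dx)`, so steps shrink when the gradient disagrees with the recent trend rather than merely when it is large. Scalar sketch (illustrative numbers):

```julia
β, ϵ = (0.9, 0.999), 1e-16
mt, st, dx = 0.1, 0.01, 0.12
mt = β[1] * mt + (1 - β[1]) * dx                 # first moment, as in Adam
st = β[2] * st + (1 - β[2]) * abs2(dx - mt) + ϵ  # variance around mt, not raw dx^2
```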