Skip to content

Commit da60ea5

Browse files
authored
fix Flot64 beta in Adam etc (#158)
1 parent 6a4f948 commit da60ea5

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/rules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ end
206206
epsilon = 1e-8
207207
end
208208

209-
init(o::Adam, x::AbstractArray) = (zero(x), zero(x), o.beta)
209+
init(o::Adam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
210210

211211
function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
212212
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
@@ -267,7 +267,7 @@ end
267267
epsilon = 1e-8
268268
end
269269

270-
init(o::RAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, 1)
270+
init(o::RAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), 1)
271271

272272
function apply!(o::RAdam, state, x::AbstractArray{T}, dx) where T
273273
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
@@ -307,7 +307,7 @@ end
307307
epsilon = 1e-8
308308
end
309309

310-
init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)
310+
init(o::AdaMax, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
311311

312312
function apply!(o::AdaMax, state, x::AbstractArray{T}, dx) where T
313313
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
@@ -340,7 +340,7 @@ is a variant of Adam adding an "optimistic" term suitable for adversarial traini
340340
epsilon = 1e-8
341341
end
342342

343-
init(o::OAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
343+
init(o::OAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), zero(x))
344344

345345
function apply!(o::OAdam, state, x::AbstractArray{T}, dx) where T
346346
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
@@ -471,7 +471,7 @@ Parameters don't need tuning.
471471
epsilon = 1e-8
472472
end
473473

474-
init(o::NAdam, x::AbstractArray) = (zero(x), zero(x), o.beta)
474+
init(o::NAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
475475

476476
function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
477477
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
@@ -524,7 +524,7 @@ Adam optimiser.
524524
epsilon = 1e-16
525525
end
526526

527-
init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x), o.beta)
527+
init(o::AdaBelief, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
528528

529529
function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
530530
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)

0 commit comments

Comments
 (0)