|
206 | 206 | epsilon = 1e-8
|
207 | 207 | end
|
208 | 208 |
|
209 |
| -init(o::Adam, x::AbstractArray) = (zero(x), zero(x), o.beta) |
| 209 | +init(o::Adam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta)) |
210 | 210 |
|
211 | 211 | function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
|
212 | 212 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
|
267 | 267 | epsilon = 1e-8
|
268 | 268 | end
|
269 | 269 |
|
270 |
| -init(o::RAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, 1) |
| 270 | +init(o::RAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), 1) |
271 | 271 |
|
272 | 272 | function apply!(o::RAdam, state, x::AbstractArray{T}, dx) where T
|
273 | 273 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
|
307 | 307 | epsilon = 1e-8
|
308 | 308 | end
|
309 | 309 |
|
310 |
| -init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta) |
| 310 | +init(o::AdaMax, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta)) |
311 | 311 |
|
312 | 312 | function apply!(o::AdaMax, state, x::AbstractArray{T}, dx) where T
|
313 | 313 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
@@ -340,7 +340,7 @@ is a variant of Adam adding an "optimistic" term suitable for adversarial traini
|
340 | 340 | epsilon = 1e-8
|
341 | 341 | end
|
342 | 342 |
|
343 |
| -init(o::OAdam, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x)) |
| 343 | +init(o::OAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), zero(x)) |
344 | 344 |
|
345 | 345 | function apply!(o::OAdam, state, x::AbstractArray{T}, dx) where T
|
346 | 346 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
@@ -471,7 +471,7 @@ Parameters don't need tuning.
|
471 | 471 | epsilon = 1e-8
|
472 | 472 | end
|
473 | 473 |
|
474 |
| -init(o::NAdam, x::AbstractArray) = (zero(x), zero(x), o.beta) |
| 474 | +init(o::NAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta)) |
475 | 475 |
|
476 | 476 | function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
|
477 | 477 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
@@ -524,7 +524,7 @@ Adam optimiser.
|
524 | 524 | epsilon = 1e-16
|
525 | 525 | end
|
526 | 526 |
|
527 |
| -init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x), o.beta) |
| 527 | +init(o::AdaBelief, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta)) |
528 | 528 |
|
529 | 529 | function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
|
530 | 530 | η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
|
|
0 commit comments