Skip to content

Commit ee4c130

Browse files
bors[bot] and mcabbott
authored
Merge #1661
1661: Deprecate `Flux.zeros` r=mcabbott a=mcabbott Seems like a footgun to have a function with the same name & almost the same function as Base's. It seems to have been used inconsistently, with some functions defining their own closure to be less confusing. This gives it a new name, `zeros32`, and uses it everywhere. Ditto `ones32`. Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
2 parents 1022f1c + 0263e30 commit ee4c130

File tree

7 files changed

+42
-37
lines changed

7 files changed

+42
-37
lines changed

src/deprecations.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,16 @@ function Base.getproperty(a::Dense, s::Symbol)
1818
end
1919
return getfield(a, s)
2020
end
21+
22+
# Deprecated: `Flux.ones` shadowed `Base.ones` with a Float32 default, which
# proved to be a footgun. Warn, then behave exactly as before so old callers
# still receive an Array{Float32}.
function ones(dims...)
  Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones)
  # `Base.depwarn` returns `nothing`; without this explicit call the shim
  # would break every deprecated caller expecting an array.
  return Base.ones(Float32, dims...)
end
25+
# With an explicit element type there is no ambiguity with Base, so forward
# silently — no deprecation message needed.
function ones(T::Type, dims...)
  return Base.ones(T, dims...)
end
26+
27+
# Deprecated: `Flux.zeros` shadowed `Base.zeros` with a Float32 default.
# Warn, then behave exactly as before so old callers still receive an
# Array{Float32}.
function zeros(dims...)
  # funcsym must be `:zeros` — it was copy-pasted as `:ones`, which merges
  # this warning's per-function bookkeeping with the `ones` deprecation.
  Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros)
  # `Base.depwarn` returns `nothing`; return the array the old API produced.
  return Base.zeros(Float32, dims...)
end
30+
# Explicit element type is unambiguous — forward to Base with no warning.
function zeros(T::Type, dims...)
  return Base.zeros(T, dims...)
end
31+
32+
# Guard method: reject `ones32(T, dims...)` calls, pointing users at the
# correct API for choosing an element type.
function ones32(::Type, dims...)
  throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
end
33+
# Guard method: reject `zeros32(T, dims...)` calls, pointing users at the
# correct API for choosing an element type.
function zeros32(::Type, dims...)
  throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
end

src/layers/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ function Diagonal(sz::Integer...; initα = nothing, initβ = nothing)
180180
Base.depwarn("keyword initα is deprecated, please simply supply the desired vectors", :Diagonal)
181181
initα(sz...)
182182
else
183-
ones(sz...)
183+
ones32(sz...)
184184
end
185185
β = if initβ !== nothing
186186
Base.depwarn("keyword initβ is deprecated, please simply supply the desired vectors", :Diagonal)
187187
initβ(sz...)
188188
else
189-
zeros(sz...)
189+
zeros32(sz...)
190190
end
191191
Diagonal(α, β)
192192
end

src/layers/normalise.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ end
198198

199199
"""
200200
BatchNorm(channels::Integer, λ=identity;
201-
initβ=zeros, initγ=ones,
201+
initβ=zeros32, initγ=ones32,
202202
ϵ=1f-5, momentum= 0.1f0)
203203
204204
[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
@@ -246,15 +246,14 @@ mutable struct BatchNorm{F,V,N,W}
246246
end
247247

248248
function BatchNorm(chs::Int, λ=identity;
249-
initβ = i -> zeros(Float32, i),
250-
initγ = i -> ones(Float32, i),
249+
initβ=zeros32, initγ=ones32,
251250
affine=true, track_stats=true,
252251
ϵ=1f-5, momentum=0.1f0)
253252

254253
β = affine ? initβ(chs) : nothing
255254
γ = affine ? initγ(chs) : nothing
256-
μ = track_stats ? zeros(Float32, chs) : nothing
257-
σ² = track_stats ? ones(Float32, chs) : nothing
255+
μ = track_stats ? zeros32(chs) : nothing
256+
σ² = track_stats ? ones32(chs) : nothing
258257

259258
return BatchNorm(λ, β, γ,
260259
μ, σ², ϵ, momentum,
@@ -286,7 +285,7 @@ end
286285

287286
"""
288287
InstanceNorm(channels::Integer, λ=identity;
289-
initβ=zeros, initγ=ones,
288+
initβ=zeros32, initγ=ones32,
290289
affine=false, track_stats=false,
291290
ϵ=1f-5, momentum=0.1f0)
292291
@@ -323,15 +322,14 @@ mutable struct InstanceNorm{F,V,N,W}
323322
end
324323

325324
function InstanceNorm(chs::Int, λ=identity;
326-
initβ = i -> zeros(Float32, i),
327-
initγ = i -> ones(Float32, i),
325+
initβ=zeros32, initγ=ones32,
328326
affine=false, track_stats=false,
329327
ϵ=1f-5, momentum=0.1f0)
330328

331329
β = affine ? initβ(chs) : nothing
332330
γ = affine ? initγ(chs) : nothing
333-
μ = track_stats ? zeros(Float32, chs) : nothing
334-
σ² = track_stats ? ones(Float32, chs) : nothing
331+
μ = track_stats ? zeros32(chs) : nothing
332+
σ² = track_stats ? ones32(chs) : nothing
335333

336334
return InstanceNorm(λ, β, γ,
337335
μ, σ², ϵ, momentum,
@@ -363,8 +361,7 @@ end
363361

364362
"""
365363
GroupNorm(channels::Integer, G::Integer, λ=identity;
366-
initβ = (i) -> zeros(Float32, i),
367-
initγ = (i) -> ones(Float32, i),
364+
initβ=zeros32, initγ=ones32,
368365
affine=true, track_stats=false,
369366
ϵ=1f-5, momentum=0.1f0)
370367
@@ -406,17 +403,16 @@ end
406403
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
407404

408405
function GroupNorm(chs::Int, G::Int, λ=identity;
409-
initβ = (i) -> zeros(Float32, i),
410-
initγ = (i) -> ones(Float32, i),
406+
initβ=zeros32, initγ=ones32,
411407
affine=true, track_stats=false,
412-
ϵ=1f-5, momentum=0.1f0)
408+
ϵ=1f-5, momentum=0.1f0)
413409

414410
chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")
415411

416412
β = affine ? initβ(chs) : nothing
417413
γ = affine ? initγ(chs) : nothing
418-
μ = track_stats ? zeros(Float32, G) : nothing
419-
σ² = track_stats ? ones(Float32, G) : nothing
414+
μ = track_stats ? zeros32(G) : nothing
415+
σ² = track_stats ? ones32(G) : nothing
420416

421417
return GroupNorm(G, λ,
422418
β, γ,

src/layers/recurrent.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ struct RNNCell{F,A,V,S}
7777
state0::S
7878
end
7979

80-
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
80+
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
8181
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
8282

8383
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
@@ -127,8 +127,8 @@ end
127127

128128
function LSTMCell(in::Integer, out::Integer;
129129
init = glorot_uniform,
130-
initb = zeros,
131-
init_state = zeros)
130+
initb = zeros32,
131+
init_state = zeros32)
132132
cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))
133133
cell.b[gate(out, 2)] .= 1
134134
return cell
@@ -190,7 +190,7 @@ struct GRUCell{A,V,S}
190190
state0::S
191191
end
192192

193-
GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
193+
GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
194194
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
195195

196196
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}

src/utils.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ julia> Flux.identity_init(3,3,2,2)
346346
```
347347
"""
348348
# Assume bias
349-
identity_init(cols; gain=1, shift=0) = zeros(Float32, cols)
349+
identity_init(cols; gain=1, shift=0) = zeros32(cols)
350350

351351
# Assume matrix multiplication
352352
identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)
@@ -355,7 +355,7 @@ identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain,
355355
function identity_init(dims...; gain=1, shift=0)
356356
nin, nout = dims[end-1], dims[end]
357357
centers = map(d -> cld(d, 2), dims[1:end-2])
358-
weights = zeros(Float32, dims)
358+
weights = zeros32(dims)
359359
for i in 1:min(nin,nout)
360360
weights[centers..., i, i] = gain
361361
end
@@ -366,12 +366,8 @@ identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs
366366
identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
367367
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
368368

369-
370-
ones(T::Type, dims...) = Base.ones(T, dims...)
371-
zeros(T::Type, dims...) = Base.zeros(T, dims...)
372-
373-
ones(dims...) = Base.ones(Float32, dims...)
374-
zeros(dims...) = Base.zeros(Float32, dims...)
369+
"""
    ones32(dims...)

Return an `Array{Float32}` of ones of the given size — shorthand for
`Base.ones(Float32, dims...)`.
"""
function ones32(dims...)
  return Base.ones(Float32, dims...)
end
370+
"""
    zeros32(dims...)

Return an `Array{Float32}` of zeros of the given size — shorthand for
`Base.zeros(Float32, dims...)`.
"""
function zeros32(dims...)
  return Base.zeros(Float32, dims...)
end
375371

376372
"""
377373
create_bias(weights, bias, length)

test/layers/basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ import Flux: activations
166166
@test b3.bias isa Vector{Float16}
167167
@test size(b3(rand(4), rand(5))) == (3,)
168168

169-
b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros)
169+
b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros32)
170170
@test_skip b4.bias isa Vector{Float32}
171171

172172
@test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array

test/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,20 +349,20 @@ end
349349
import Flux: loadparams!
350350
pars(w, b) = [w, b]
351351
import Flux: loadparams!, Zeros
352-
pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))]
352+
pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))]
353353
pars(l) = pars(l.W, l.b)
354354
pararray(m) = mapreduce(pars, vcat, m)
355355
weights(m) = mapreduce(l -> [l.W], vcat, m)
356-
@testset "Bias type $bt" for bt in (Flux.zeros, nobias)
356+
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
357357
m = dm(bt)
358358
loadparams!(m, params(m))
359359
testdense(m, bt)
360360
end
361361

362362
@testset "$b1 to $b2" for (b1, b2, be) in (
363-
(Flux.zeros, Flux.ones, Flux.ones), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
364-
(Flux.ones, nobias, Flux.zeros), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
365-
(nobias, Flux.ones, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
363+
(Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
364+
(Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias-> model gets zeros as bias
365+
(nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias-> model bias does not change
366366
)
367367
m1 = dm(b1)
368368
m2 = dm(b2)

0 commit comments

Comments
 (0)