
Commit b917a32

bors[bot] and atiyo authored
Merge #1454
1454: Add sparse initialization r=CarloLucibello a=atiyo

Add sparse initialization, documentation and tests. Trim whitespace in edited files.

This PR is intended to address one of the outstanding points in bringing Flux to parity with PyTorch's features, so it partially addresses #1431 and fully addresses #1450.

The implementation follows the method given in the [PyTorch implementation](https://pytorch.org/docs/stable/_modules/torch/nn/init.html#sparse_): a normally distributed array is created, then a fixed proportion of randomly chosen row indices is zeroed out in every column. Like the PyTorch version, it is restricted to 2-d Arrays.

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: atiyo <atiyo@users.noreply.github.com>
Co-authored-by: Atiyo Ghosh <atiyo@users.noreply.github.com>
2 parents 8dfe4fa + a31ddf8 commit b917a32
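
As a quick usage sketch (not part of this commit; the layer sizes and `sparsity` value are arbitrary), the new initializer returns a plain weight matrix, so it can be passed straight to a layer constructor. Because `sparsity` has no default, init-style keywords would need a closure that fixes it:

```julia
using Flux

# Build a 64×784 weight matrix with roughly 10% zeros in every column,
# then wrap it in a Dense layer explicitly.
W = Flux.sparse_init(64, 784; sparsity = 0.1)
layer = Dense(W, zeros(Float32, 64), relu)

# If your Flux version exposes an init keyword (e.g. `init` or `initW`),
# a closure can fix the required `sparsity` argument:
# Dense(784, 64, relu; init = (dims...) -> Flux.sparse_init(dims...; sparsity = 0.1))

count(iszero, W) / length(W)   # ≈ 0.1 (exactly ceil(0.1 * 64) / 64 per column)
```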

File tree

4 files changed: +91 -17 lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 * Excise datasets in favour of other providers in the julia ecosystem.
 * Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to eliminating `bias` from being trained.
 * Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379)).
+* Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
 * Other new features and bug fixes (see GitHub releases page)

 ## v0.11.2

docs/src/utilities.md

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ Flux.glorot_uniform
 Flux.glorot_normal
 Flux.kaiming_uniform
 Flux.kaiming_normal
+Flux.sparse_init
 ```

 ## Model Building

src/utils.jl

Lines changed: 50 additions & 2 deletions

@@ -56,6 +56,7 @@ julia> Flux.glorot_uniform(2, 3)
 * glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
 * kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
 * kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
+* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
 * calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

 # References
@@ -88,6 +89,7 @@ julia> Flux.glorot_normal(3, 2)
 * glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
 * kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
 * kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
+* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
 * calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

 # References
@@ -120,6 +122,7 @@ julia> Flux.kaiming_uniform(3, 2)
 * kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
 * glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
 * glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
+* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
 * calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

 # References
@@ -156,6 +159,7 @@ julia> Flux.kaiming_normal(3, 2)
 * kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
 * glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
 * glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
+* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
 * calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

 # References
@@ -170,14 +174,58 @@ end
 kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
 kaiming_normal(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; kwargs...)

+"""
+    sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
+
+Return an `Array` of size `dims` where each column contains a fixed fraction of
+zero elements given by `sparsity`. Non-zero elements are normally distributed
+with a mean of zero and standard deviation `std`.
+
+This method is described in [1].
+
+# Examples
+```jldoctest; setup = :(using Random; Random.seed!(0))
+julia> Flux.sparse_init(3, 2, sparsity=0.1)
+3×2 Array{Float32,2}:
+  0.00828413   0.0
+ -0.00353007   0.00297336
+  0.0          0.00586617
+```
+
+# See also
+
+* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
+* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
+* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
+* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
+
+# References
+
+[1] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010.
+"""
+function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
+  if length(dims) != 2
+    throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
+  end
+  rows, cols = dims
+  prop_zero = min(1.0, sparsity)
+  num_zeros = ceil(Integer, prop_zero * rows)
+  sparse_array = randn(rng, Float32, dims...) .* Float32(std)
+  sparse_array[1:num_zeros, :] .= 0f0
+  return mapslices(shuffle, sparse_array, dims=1)
+end
+
+sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
+sparse_init(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; kwargs...)
+
 ones(T::Type, dims...) = Base.ones(T, dims...)
 zeros(T::Type, dims...) = Base.zeros(T, dims...)

 ones(dims...) = Base.ones(Float32, dims...)
 zeros(dims...) = Base.zeros(Float32, dims...)

 """
-    create_bias(shallcreate::Bool, iftrue, dims...)
+    create_bias(shallcreate::Bool, iftrue, dims...)
     create_bias(x, ::Any...)

 Return a bias parameter for a layer.
@@ -188,7 +236,7 @@ Essentially handles the allowed input options for the `bias` keyword:
 If not a boolean, return self to handle the case of bias=somearray.
 """
 create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros()
-create_bias(x, ::Any...) = x
+create_bias(x, ::Any...) = x

 """
     unsqueeze(xs, dim)
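
As an illustrative check (not part of the commit), the behaviour promised by the docstring, namely a fixed number of zeros in every column and non-zero entries whose spread matches `std`, can be observed directly; the sizes and keyword values here are arbitrary:

```julia
using Flux, Statistics

v = Flux.sparse_init(100, 8; sparsity = 0.25, std = 0.05)

# Each column contains exactly ceil(100 * 0.25) = 25 zeros, at randomly shuffled rows.
[count(iszero, col) for col in eachcol(v)]   # => [25, 25, 25, 25, 25, 25, 25, 25]

# The remaining entries are drawn from N(0, std), so their sample standard
# deviation lands close to 0.05 (the new tests allow a ±10% band).
std(v[v .!= 0])
```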

test/utils.jl

Lines changed: 39 additions & 15 deletions

@@ -1,5 +1,5 @@
 using Flux
-using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, stack, unstack, Zeros
+using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, sparse_init, stack, unstack, Zeros
 using StatsBase: var, std
 using Random
 using Test
@@ -95,6 +95,30 @@
       @test eltype(v) == Float32
     end
   end
+
+  @testset "sparse_init" begin
+    # sparse_init should yield an error for non 2-d dimensions
+    # sparse_init should yield no zero elements if sparsity < 0
+    # sparse_init should yield all zero elements if sparsity > 1
+    # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for other sparsity values
+    # sparse_init should yield a kernel in its non-zero elements consistent with the std parameter
+
+    @test_throws ArgumentError sparse_init(100, 100, 100, sparsity=0.1)
+    v = sparse_init(100, 100, sparsity=-0.1)
+    @test sum(v .== 0) == 0
+    @test eltype(v) == Float32
+    v = sparse_init(100, 100, sparsity=1.1)
+    @test sum(v .== 0) == length(v)
+    @test eltype(v) == Float32
+
+    for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
+      expected_zeros = ceil(Integer, n_in * sparsity)
+      v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ)
+      @test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out])
+      @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
+      @test eltype(v) == Float32
+    end
+  end
 end

 @testset "Params" begin
@@ -141,22 +165,22 @@

   @testset "Explicit" begin
     gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...)
-    g = gfun(o, z)
+    g = gfun(o, z)
     @test gfun(o, Z) == (g[1], nothing)

-    g = gfun(z, o)
+    g = gfun(z, o)
     @test gfun(Z, o) == (nothing, g[2])
   end

   @testset "Implicit" begin
     gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args)))
-    g = gfun(o, z)
+    g = gfun(o, z)

     gres = gfun(o, Z)
     @test gres[o] == g[o]
     @test Z ∉ gres.params

-    g = gfun(z, o)
+    g = gfun(z, o)
     gres = gfun(Z, o)
     @test gres[o] == g[o]
     @test Z ∉ gres.params
@@ -170,14 +194,14 @@

   @testset "Explicit" begin
     gfun(args...) = gradient((x, y) -> sum(x ./ y), args...)
-    g = gfun(z, o)
+    g = gfun(z, o)
     @test gfun(Z, o) == (nothing, g[2])
   end

   @testset "Implicit" begin
     gfun(x,y) = gradient(() -> sum(x ./ y), params([x,y]))

-    g = gfun(z, o)
+    g = gfun(z, o)
     gres = gfun(Z, o)
     @test gres[o] == g[o]
     @test Z ∉ gres.params
@@ -193,21 +217,21 @@
   @testset "Explicit" begin
     gfun(args...) = gradient((x, y) -> sum(op(x,y)), args...)

-    g = gfun(o, z)
+    g = gfun(o, z)
     @test gfun(o, Z) == (g[1], nothing)

-    g = gfun(z, o)
+    g = gfun(z, o)
     @test gfun(Z, o) == (nothing, g[2])
   end

   @testset "Implicit" begin
     gfun(args...) = gradient(() -> sum(op(args...)), params(collect(args)))
-    g = gfun(o, z)
+    g = gfun(o, z)
     gres = gfun(o, Z)
     @test gres[o] == g[o]
     @test Z ∉ gres.params

-    g = gfun(z, o)
+    g = gfun(z, o)
     gres = gfun(Z, o)
     @test gres[o] == g[o]
     @test Z ∉ gres.params
@@ -225,7 +249,7 @@

 @testset "Param remapping" begin
   ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...)
-  dl(nin, nout, bias) = Dense(ls(nin, nout), bias(nout))
+  dl(nin, nout, bias) = Dense(ls(nin, nout), bias(nout))
   dm(bias) = Chain(
     dl(3, 5, bias),
     dl(5, 4, bias),
@@ -239,10 +263,10 @@
     @test typeof(l1.b) === typeof(l2.b)
   end

-  @testset "loadparams!" begin
+  @testset "loadparams!" begin
     import Flux: loadparams!
     pars(w, b::Zeros) = [w, zeros(size(w,2))]
-    pars(w, b) = [w, b]
+    pars(w, b) = [w, b]
     pars(l) = pars(l.W, l.b)
     pararray(m) = mapreduce(pars, vcat, m)
     weights(m) = mapreduce(l -> [l.W], vcat, m)
@@ -285,4 +309,4 @@
   @test c[1].testing
   trainmode!(c)
   @test !c[1].testing
-end
+end
