Skip to content

Commit 4f87f2b

Browse files
committed
Fix layer init functions, kwargs was shadowed
There were a few functions of the form: ``` foo(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> foo(rng, dims...; kwargs...) ``` The intention was that you could do `foo(my_kw=42)` to get a callable, and then later call that callable with the kwarg already set. However, the `kwargs` variable of the lambda shadowed the initial `kwargs`. This is a fix.
1 parent 7e9a180 commit 4f87f2b

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
135135
end
136136

137137
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
138-
kaiming_uniform(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; kwargs...)
138+
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
139139

140140
"""
141141
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -172,7 +172,7 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
172172
end
173173

174174
kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
175-
kaiming_normal(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; kwargs...)
175+
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
176176

177177
"""
178178
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
@@ -216,7 +216,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
216216
end
217217

218218
sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
219-
sparse_init(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; kwargs...)
219+
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
220220

221221
ones(T::Type, dims...) = Base.ones(T, dims...)
222222
zeros(T::Type, dims...) = Base.zeros(T, dims...)

test/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,22 @@ end
119119
@test eltype(v) == Float32
120120
end
121121
end
122+
123+
@testset "partial_application" begin
124+
big = 1e9
125+
126+
partial_ku = kaiming_uniform(gain=big)
127+
@test maximum(partial_ku(8, 8)) > big / 2
128+
@test maximum(partial_ku(8, 8, gain=1)) < big / 2
129+
130+
partial_kn = kaiming_normal(gain=big)
131+
@test maximum(partial_kn(8, 8)) > big / 2
132+
@test maximum(partial_kn(8, 8, gain=1)) < big / 2
133+
134+
partial_si = sparse_init(sparsity=1)
135+
@test maximum(partial_si(8, 8)) == 0
136+
@test maximum(partial_si(8, 8, sparsity=0)) > 0
137+
end
122138
end
123139

124140
@testset "Params" begin

0 commit comments

Comments
 (0)