Skip to content

Commit c7ed5fe

Browse files
authored
Merge pull request #2081 from Saransh-cpp/create_bias
Back to create_bias
2 parents b08cb67 + 2a1e4d2 commit c7ed5fe

File tree

3 files changed

+9
-12
lines changed

3 files changed

+9
-12
lines changed

src/layers/basic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct Dense{F, M<:AbstractMatrix, B}
155155
bias::B
156156
σ::F
157157
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
158-
b = _create_bias(W, bias, size(W,1))
158+
b = create_bias(W, bias, size(W,1))
159159
new{F,M,typeof(b)}(W, b, σ)
160160
end
161161
end
@@ -228,7 +228,7 @@ struct Scale{F, A<:AbstractArray, B}
228228
bias::B
229229
σ::F
230230
function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
231-
b = _create_bias(scale, bias, size(scale)...)
231+
b = create_bias(scale, bias, size(scale)...)
232232
new{F, A, typeof(b)}(scale, b, σ)
233233
end
234234
end
@@ -403,7 +403,7 @@ struct Bilinear{F,A,B}
403403
σ::F
404404
function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
405405
ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
406-
b = _create_bias(W, bias, size(W,1))
406+
b = create_bias(W, bias, size(W,1))
407407
new{F,A,typeof(b)}(W, b, σ)
408408
end
409409
end

src/layers/conv.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
156156
stride = expand(Val(N-2), stride)
157157
dilation = expand(Val(N-2), dilation)
158158
pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
159-
bias = _create_bias(w, b, size(w, N))
159+
bias = create_bias(w, b, size(w, N))
160160
return Conv(σ, w, bias, stride, pad, dilation, groups)
161161
end
162162

@@ -293,7 +293,7 @@ function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
293293
stride = expand(Val(N-2), stride)
294294
dilation = expand(Val(N-2), dilation)
295295
pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
296-
b = _create_bias(w, bias, size(w, N-1) * groups)
296+
b = create_bias(w, bias, size(w, N-1) * groups)
297297
return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
298298
end
299299

@@ -441,7 +441,7 @@ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
441441
stride = expand(Val(N-2), stride)
442442
dilation = expand(Val(N-2), dilation)
443443
pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
444-
b = _create_bias(w, bias, size(w, N))
444+
b = create_bias(w, bias, size(w, N))
445445
return CrossCor(σ, w, b, stride, pad, dilation)
446446
end
447447

src/utils.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
504504
randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)
505505

506506
"""
507-
_create_bias(weights, bias, size...)
507+
create_bias(weights, bias, size...)
508508
509509
Return a bias parameter for a layer, based on the value given
510510
to the constructor's keyword `bias=bias`.
@@ -514,17 +514,14 @@ to the constructor's keyword `bias=bias`.
514514
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
515515
It does not at present correct the `eltype` to match that of `weights`.
516516
"""
517-
function _create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
517+
function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
518518
bias ? fill!(similar(weights, dims...), 0) : false
519519
end
520-
function _create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
520+
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
521521
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
522522
bias
523523
end
524524

525-
# TODO figure out whether we want to document or deprecate this
526-
const create_bias = _create_bias
527-
528525

529526
# Other
530527

0 commit comments

Comments
 (0)