Skip to content

Commit 6573f65

Browse files
committed
Use NNlib.conv_bias_act for Conv
1 parent 7b56813 commit 6573f65

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,13 @@ function (c::Conv)(x::AbstractArray)
162162
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
163163
σ = NNlib.fast_act(c.σ, x)
164164
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
165-
σ.(conv(x, c.weight, cdims) .+ b)
165+
_conv_bias_act(x, c.weight, cdims, b, σ)
166166
end
167167

168+
_conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ)
169+
_conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) =
170+
_conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ)
171+
168172
_channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
169173
_channels_out(l::Conv) = size(l.weight, ndims(l.weight))
170174

0 commit comments

Comments
 (0)