Skip to content

Commit c90e0b9

Browse files
committed
Use NNlib.conv_bias_act for Conv
1 parent 79dbbd6 commit c90e0b9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,13 @@ end
163163
function (c::Conv)(x::AbstractArray)
164164
σ, b = c.σ, reshape(c.bias, ntuple(_ -> 1, length(c.stride))..., :, 1)
165165
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
166-
σ.(conv(x, c.weight, cdims) .+ b)
166+
_conv_bias_act(x, c.weight, cdims, b, σ)
167167
end
168168

169+
_conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ)
170+
_conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) =
171+
_conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ)
172+
169173
_channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
170174
_channels_out(l::Conv) = size(l.weight, ndims(l.weight))
171175

0 commit comments

Comments
 (0)