This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit e6b5376

Merge pull request #690 from AStupidBear/conv1d

fix 1d convolution

2 parents c38da71 + 4783f83
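
For context, this is the kind of user-facing call the merge is meant to unbreak. A minimal sketch, assuming a Flux version of that era whose Conv layer dispatches to CuArrays' cuDNN wrappers; the layer, shapes, and use of gpu are illustrative, not taken from the PR:

using Flux, CuArrays   # assumes Flux routes conv! for CuArray inputs to these cuDNN wrappers

# 1-D convolution input layout is (width, channels, batch).
x = rand(Float32, 100, 3, 8) |> gpu
layer = Conv((5,), 3 => 4, relu) |> gpu

# Before this fix the 1-D case reached cuDNN directly and raised a CUDNNError;
# with fix1d the arrays and dims gain a singleton second spatial dimension,
# so this call should now run and return a (96, 4, 8) CuArray.
y = layer(x)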

2 files changed (+24, -10 lines)


src/dnn/nnlib.jl (23 additions, 9 deletions)
@@ -41,13 +41,25 @@ end
 
 # Convolution
 
+# Since CUDNN does not support 1D convolution, Conv in Flux will give a CUDNNError if the size is 1-dimensional.
+# We have to reshape the CuArray/PoolDims/DenseConvDims to 4D before feeding to CUDNN.
+fix1d(x) = x
+
+fix1d(x::CuArray{T, 3}) where T = reshape(x, size(x, 1), 1, size(x, 2), size(x, 3))
+
+fix1d(cdims::DenseConvDims{1,K,C_in,C_out,S,P,D,F}) where {K,C_in,C_out,S,P,D,F} =
+  DenseConvDims{2,(K...,1),C_in,C_out,(S...,1),(P...,0,0),(D...,1),F}((cdims.I...,1))
+
+fix1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D} =
+  PoolDims{2,(K...,1),(S...,1),(P...,0,0),(D...,1)}((pdims.I..., 1), pdims.C_in)
+
 function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims;
                alpha=1, algo=0) where T<:CUDNNFloat
   if version() < v"6"
     all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
   end
-
-  cudnnConvolutionForward(y, x, w, cdims, alpha=alpha, algo=algo)
+  cudnnConvolutionForward(fix1d(y), fix1d(x), fix1d(w), fix1d(cdims), alpha=alpha, algo=algo)
+  return y
 end
 
 function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
@@ -56,7 +68,8 @@ function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
     all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
   end
 
-  cudnnConvolutionBackwardFilter(dw, x, dy, cdims, alpha=alpha, algo=algo)
+  cudnnConvolutionBackwardFilter(fix1d(dw), fix1d(x), fix1d(dy), fix1d(cdims), alpha=alpha, algo=algo)
+  return dw
 end
 
 function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, w::CuArray{T},
@@ -65,22 +78,23 @@ function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, w::CuArray{T},
     all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
   end
 
-  cudnnConvolutionBackwardData(dx, w, dy, cdims, alpha=alpha, algo=algo)
+  cudnnConvolutionBackwardData(fix1d(dx), fix1d(w), fix1d(dy), fix1d(cdims), alpha=alpha, algo=algo)
+  return dx
 end
 
 ∇conv_bias!(db::CuArray{T}, dy::CuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat =
-  cudnnConvolutionBackwardBias(db, dy, alpha=alpha, beta=beta)
+  (cudnnConvolutionBackwardBias(fix1d(db), fix1d(dy), alpha=alpha, beta=beta); return db)
 
 maxpool!(y::CuArray{T}, x::CuArray{T}, pdims::PoolDims) where T<:CUDNNFloat =
-  cudnnPoolingForward(y, x, pdims; mode=0)
+  (cudnnPoolingForward(fix1d(y), fix1d(x), fix1d(pdims); mode=0); return y)
 
 ∇maxpool!(dx::CuArray{T}, dy::CuArray{T}, y::CuArray{T}, x::CuArray{T},
           pdims::PoolDims) where T<:CUDNNFloat =
-  cudnnPoolingBackward(dx, dy, x, y, pdims, mode=0)
+  (cudnnPoolingBackward(fix1d(dx), fix1d(dy), fix1d(x), fix1d(y), fix1d(pdims), mode=0); return dx)
 
 meanpool!(y::CuArray{T}, x::CuArray{T}, pdims::PoolDims) where T<:CUDNNFloat =
-  cudnnPoolingForward(y, x, pdims, mode=1)
+  (cudnnPoolingForward(fix1d(y), fix1d(x), fix1d(pdims), mode=1); return y)
 
 ∇meanpool!(dx::CuArray{T}, dy::CuArray{T}, y::CuArray{T}, x::CuArray{T},
           pdims::PoolDims) where T<:CUDNNFloat =
-  cudnnPoolingBackward(dx, dy, x, y, pdims, mode=1)
+  (cudnnPoolingBackward(fix1d(dx), fix1d(dy), fix1d(x), fix1d(y), fix1d(pdims), mode=1); return dx)
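
The heart of the change is the fix1d family added above: cuDNN only exposes 4-D (and 5-D) tensor descriptors, so a 1-D convolution is rewritten as a degenerate 2-D one by inserting a singleton spatial dimension into the arrays and into the DenseConvDims/PoolDims metadata. A minimal CPU-only sketch of the same reshape trick, using NNlib's reference conv (the shapes are illustrative assumptions):

using NNlib

# 1-D data: (width, C_in, batch); 1-D kernel: (width, C_in, C_out).
x = rand(Float32, 20, 3, 2)
w = rand(Float32, 5, 3, 4)

# Insert a singleton second spatial dimension, mirroring what fix1d does for CuArrays.
x4 = reshape(x, size(x, 1), 1, size(x, 2), size(x, 3))   # (20, 1, 3, 2)
w4 = reshape(w, size(w, 1), 1, size(w, 2), size(w, 3))   # (5, 1, 3, 4)

# The degenerate 2-D convolution agrees with the native 1-D one.
y1 = conv(x, w)    # (16, 4, 2)
y2 = conv(x4, w4)  # (16, 1, 4, 2)
@assert y1 ≈ dropdims(y2; dims=2)

On the GPU the reshape only rewraps the existing buffer, so the trick should not add a data copy; at the Flux level it is what lets a 1-dimensional Conv run on a CuArray without the CUDNNError mentioned in the comment above.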

test/dnn.jl (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ else
   @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))
 
   # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
-  for num_spatial_dims in (2, 3)
+  for num_spatial_dims in (1, 2, 3)
     # Initialize data we'll run our tests over
     C_in = 3
     C_out = 4
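
For reference, roughly what the new num_spatial_dims = 1 iteration exercises: build a 1-D problem, run it through NNlib on the CPU and through the cuDNN path above on the GPU, and check agreement. The sizes and the explicit DenseConvDims call are illustrative assumptions rather than lines from the test file, and a CUDA device with cuDNN is assumed to be available:

using NNlib, CuArrays

C_in, C_out, batch = 3, 4, 2
x = rand(Float32, 8, C_in, batch)     # (width, C_in, batch)
w = rand(Float32, 2, C_in, C_out)     # (width, C_in, C_out)
cdims = DenseConvDims(x, w)

y_cpu = NNlib.conv(x, w, cdims)
y_gpu = NNlib.conv(CuArray(x), CuArray(w), cdims)   # dispatches to the conv! defined in src/dnn/nnlib.jl
@assert y_cpu ≈ collect(y_gpu)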
