
Commit 6824c6d

nikopj (Nikola Janjusevic) authored
cudnn complex convolution via gauss trick (#517)
* cudnn complex convolution via gauss trick
* removed debug line, added gauss trick comment
* fix typo in comment
* conv doc updated, mixed real and complex cuda conv added
* update complexconv tests
* complex conv test updates
* doc reference mention complex conv on CUDA and CPU

Co-authored-by: Nikola <npj226@nyu.edu>
Co-authored-by: Janjusevic <npj226@a100-4002.cm.cluster>
1 parent ecdc95b commit 6824c6d

File tree: 5 files changed (+187, -25 lines)

docs/src/reference.md

Lines changed: 2 additions & 1 deletion
@@ -74,7 +74,8 @@ pad_zeros
 
 ## Convolution
 
-`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
+`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
+`NNlib.conv` supports complex datatypes on CPU and CUDA devices.
 
 !!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true).
     Therefore for every regular convolution (flipkernel=false)
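As a quick illustration of the line added above (not part of the diff; sizes are arbitrary and a working CUDA device is assumed), the same `NNlib.conv` call accepts complex arrays on the CPU and dispatches to the new cuDNN-backed methods on the GPU:

```julia
using NNlib, CUDA

x = rand(ComplexF32, 16, 16, 3, 1)   # width × height × channels × batch
w = rand(ComplexF32, 3, 3, 3, 8)     # kernel width × height × C_in × C_out

y  = NNlib.conv(x, w)                       # complex convolution on the CPU
yg = NNlib.conv(CuArray(x), CuArray(w))     # cuDNN-backed complex convolution
@assert y ≈ collect(yg)
```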

ext/NNlibCUDACUDNNExt/conv.jl

Lines changed: 107 additions & 1 deletion
@@ -9,6 +9,7 @@ using cuDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
     cudnnConvolutionBackwardBias
 
 const CUDNNFloat = Union{Float16,Float32,Float64}
+const CUDNNComplexFloat = Union{ComplexF16,ComplexF32,ComplexF64}
 
 function cudnnConvolutionDescriptorAndPaddedInput(cdims::DenseConvDims, x::DenseCuArray{T}) where T
     # The main purpose of this function is to catch asymmetric padding which cudnn does not support
@@ -49,12 +50,22 @@ function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}, pa
         convdims(NNlib.stride(cdims),size(x),1),
         convdims(NNlib.dilation(cdims),size(x),1),
         mode,
-        cudnnDataType(T),
+        cudnnDataType(real(T)),
         math_mode(),
         CUDNN_DEFAULT_REORDER,
         Cint(NNlib.groupcount(cdims)))
 end
 
+@inline function _complex!(y::DenseCuArray{T1}, yr::DenseCuArray{T2}, yi::DenseCuArray{T2}; bias=zero(T1), alpha=one(T1), beta=zero(T1), σ=identity) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
+    # if y is from similar(), it may have NaNs, and beta*NaN will propagate.
+    if beta != 0
+        @. y = σ(alpha*(yr + im*yi) + bias + beta*y)
+    else
+        @. y = σ(alpha*(yr + im*yi) + bias)
+    end
+    return y
+end
+
 function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
                alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
     if cudnnversion() < v"6"
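A side note on the `beta != 0` branch in `_complex!` above: `similar` returns uninitialized memory that may contain `NaN`, and `0 * NaN` is still `NaN`, so the `beta*y` term must be skipped rather than relied upon to vanish. A tiny sketch of the failure mode (illustrative only, not part of the diff):

```julia
y    = fill(NaN, 3)   # stands in for an uninitialized output buffer
beta = 0.0
beta .* y             # == [NaN, NaN, NaN]; the garbage would propagate into the result
```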
@@ -67,6 +78,43 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
     cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
 end
 
+# Complex convolution with Gauss's trick (1 complex mul === 3 real mul):
+# Consider x = xr + im*xi, y = yr + im*yi,
+# so x*y = (xr*yr - xi*yi) + im*(xr*yi + xi*yr).
+# Let a = xr*yr,
+#     b = xi*yi,
+#     c = (xr + xi)*(yr + yi) = xr*yr + xr*yi + xi*yr + xi*yi.
+# Then,
+#     x*y = (a - b) + im*(c - a - b).
+# Convolution is linear so this multiplication trick translates to convolution.
+function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
+               alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
+    xr, xi = reim(x)
+    wr, wi = reim(w)
+    a = conv!(similar(real(y)), xr, wr, cdims; algo=algo)
+    b = conv!(similar(a), xi, wi, cdims; algo=algo)
+    c = conv!(similar(a), xr + xi, wr + wi, cdims; algo=algo)
+    return _complex!(y, a - b, c - a - b; alpha=alpha, beta=beta)
+end
+
+# (xr + im*xi) * w = xr*w + im*(xi*w)
+function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T1}, w::DenseCuArray{T2}, cdims::DenseConvDims;
+               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
+    xr, xi = reim(x)
+    yr = conv!(similar(real(y)), xr, w, cdims; algo=algo)
+    yi = conv!(similar(yr), xi, w, cdims; algo=algo)
+    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
+end
+
+# x * (wr + im*wi) = x*wr + im*(x*wi)
+function conv!(y::DenseCuArray{T1}, x::DenseCuArray{T2}, w::DenseCuArray{T1}, cdims::DenseConvDims;
+               alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
+    wr, wi = reim(w)
+    yr = conv!(similar(real(y)), x, wr, cdims; algo=algo)
+    yi = conv!(similar(yr), x, wi, cdims; algo=algo)
+    return _complex!(y, yr, yi; alpha=alpha, beta=beta)
+end
+
 function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
                         cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
                         z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
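The Gauss identity in the comment above can be checked numerically against the CPU path (an illustrative sketch, not part of the commit; it relies on the CPU `conv` handling complex arrays, as the doc change states):

```julia
using NNlib

x = rand(ComplexF64, 12, 12, 1, 1)
w = rand(ComplexF64, 3, 3, 1, 1)
cdims = NNlib.DenseConvDims(x, w)

xr, xi = reim(x); wr, wi = reim(w)
a = NNlib.conv(xr, wr, cdims)             # Re * Re
b = NNlib.conv(xi, wi, cdims)             # Im * Im
c = NNlib.conv(xr + xi, wr + wi, cdims)   # (Re+Im) * (Re+Im)

# three real convolutions reproduce the complex one
@assert NNlib.conv(x, w, cdims) ≈ (a - b) + im * (c - a - b)
```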
@@ -86,6 +134,17 @@ function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{
     return y
 end
 
+function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
+                        cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
+                        z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
+    xr, xi = reim(x)
+    wr, wi = reim(w)
+    a = conv!(similar(real(y)), xr, wr, cdims; alpha=1, beta=0, algo=algo)
+    b = conv!(similar(a), xi, wi, cdims; alpha=1, beta=0, algo=algo)
+    c = conv!(similar(a), xr + xi, wr + wi, cdims; alpha=1, beta=0, algo=algo)
+    return _complex!(y, a - b, c - a - b; bias=bias, alpha=alpha, beta=beta, σ=σ)
+end
+
 function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
                      cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
     if cudnnversion() < v"6"
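For the fused bias-and-activation path, a usage sketch on a CUDA device (hypothetical sizes; `abs2` is used because the new tests apply a real-valued activation to the complex output):

```julia
using NNlib, CUDA

x    = CuArray(rand(ComplexF32, 16, 16, 3, 1))
w    = CuArray(rand(ComplexF32, 3, 3, 3, 4))
bias = CuArray(rand(ComplexF32, 1, 1, 4, 1))
cdims = NNlib.DenseConvDims(x, w)

# Gauss-trick convolution, then bias add and activation inside _complex!
y = NNlib.conv_bias_act(x, w, cdims, bias, abs2)
```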
@@ -104,6 +163,26 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
     return depad(dx)
 end
 
+function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray{T},
+                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
+    dyr, dyi = reim(dy)
+    wr, wi = reim(w)
+    # note: w is conjugated, i.e. wi is negated below
+    a = ∇conv_data!(similar(real(dx)), dyr, wr, cdims; alpha=1, beta=0, algo=algo)
+    b = ∇conv_data!(similar(a), dyi, -wi, cdims; alpha=1, beta=0, algo=algo)
+    c = ∇conv_data!(similar(a), dyr + dyi, wr - wi, cdims; alpha=1, beta=0, algo=algo)
+    return _complex!(dx, a - b, c - a - b; alpha=alpha, beta=beta)
+end
+
+# dx = (dyr + im*dyi)*w = dyr*w + im*(dyi*w)
+function ∇conv_data!(dx::DenseCuArray{T1}, dy::DenseCuArray{T1}, w::DenseCuArray{T2},
+                     cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
+    dyr, dyi = reim(dy)
+    dxr = ∇conv_data!(similar(real(dx)), dyr, w, cdims; alpha=1, beta=0, algo=algo)
+    dxi = ∇conv_data!(similar(dxr), dyi, w, cdims; alpha=1, beta=0, algo=algo)
+    return _complex!(dx, dxr, dxi; alpha=alpha, beta=beta)
+end
+
 function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                        cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
     if cudnnversion() < v"6"
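Because `∇conv_data` applies the conjugate of `w`, the Gauss recombination above uses `wr - im*wi`. A CPU sketch of that decomposition (illustrative only; it assumes the CPU implementation follows the same conjugation convention, which is what the new CPU-vs-GPU tests check):

```julia
using NNlib

x  = rand(ComplexF64, 12, 12, 3, 1)
w  = rand(ComplexF64, 3, 3, 3, 4)
cdims = NNlib.DenseConvDims(x, w)
dy = rand(ComplexF64, size(NNlib.conv(x, w, cdims)))

dyr, dyi = reim(dy); wr, wi = reim(w)
a = NNlib.∇conv_data(dyr, wr, cdims)
b = NNlib.∇conv_data(dyi, -wi, cdims)             # wi negated: w enters conjugated
c = NNlib.∇conv_data(dyr + dyi, wr - wi, cdims)

@assert NNlib.∇conv_data(dy, w, cdims) ≈ (a - b) + im * (c - a - b)
```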
@@ -122,9 +201,36 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
     return dw
 end
 
+function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
+                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where T<:CUDNNComplexFloat
+    xr, xi = reim(x)
+    dyr, dyi = reim(dy)
+    # note: x is conjugated, i.e. xi is negated below
+    a = ∇conv_filter!(similar(real(dw)), xr, dyr, cdims; alpha=1, beta=0, algo=algo)
+    b = ∇conv_filter!(similar(a), -xi, dyi, cdims; alpha=1, beta=0, algo=algo)
+    c = ∇conv_filter!(similar(a), xr - xi, dyr + dyi, cdims; alpha=1, beta=0, algo=algo)
+    return _complex!(dw, a - b, c - a - b; alpha=alpha, beta=beta)
+end
+
+# dw = x*(dyr + im*dyi) = x*dyr + im*(x*dyi)
+function ∇conv_filter!(dw::DenseCuArray{T1}, x::DenseCuArray{T2}, dy::DenseCuArray{T1},
+                       cdims::DenseConvDims; alpha=1, beta=0, algo=-1) where {T1<:CUDNNComplexFloat, T2<:CUDNNFloat}
+    dyr, dyi = reim(dy)
+    dwr = ∇conv_filter!(similar(real(dw)), x, dyr, cdims; alpha=1, beta=0, algo=algo)
+    dwi = ∇conv_filter!(similar(dwr), x, dyi, cdims; alpha=1, beta=0, algo=algo)
+    return _complex!(dw, dwr, dwi; alpha=alpha, beta=beta)
+end
+
 function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNFloat
     alpha,beta = scalingParameter(T,alpha), scalingParameter(T,beta)
     bDesc, yDesc = cudnnTensorDescriptor.((db,dy))
     cudnnConvolutionBackwardBias(handle(), alpha, yDesc, dy, beta, bDesc, db)
     return db
 end
+
+function ∇conv_bias!(db::DenseCuArray{T}, dy::DenseCuArray{T}; alpha=1, beta=0) where T<:CUDNNComplexFloat
+    dyr, dyi = reim(dy)
+    dbr = ∇conv_bias!(similar(real(db)), dyr; alpha=1, beta=0)
+    dbi = ∇conv_bias!(similar(dbr), dyi; alpha=1, beta=0)
+    return _complex!(db, dbr, dbi; alpha=alpha, beta=beta)
+end

src/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@
     conv(x, w; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1)
 
 Apply convolution filter `w` to input `x`. `x` and `w` are 3d/4d/5d tensors
-in 1d/2d/3d convolutions respectively.
+in 1d/2d/3d convolutions respectively. `x` and `w` may have real or complex element types.
 """
 function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1, flipped = false, groups = 1) where {T, N}
     stride = expand(Val(N - 2), stride)
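A short sketch of the mixed-eltype case implied by the updated docstring (sizes are arbitrary, chosen to mirror the new tests):

```julia
using NNlib

x = rand(ComplexF64, 32, 3, 1)   # complex 1-d input: width × channels × batch
w = rand(Float64, 5, 3, 8)       # real filter
y = NNlib.conv(x, w)             # mixing real and complex eltypes is allowed
@assert eltype(y) <: Complex
```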

test/ext_cuda/conv.jl

Lines changed: 70 additions & 17 deletions
@@ -1,17 +1,28 @@
 using NNlib: DenseConvDims
 
 @testset "convolution" begin
-a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1)
+@testset "$T" for T in (Float64, ComplexF64)
+a, b, c = rand(T, 10, 10, 3, 1), rand(T, 2, 2, 3, 4), rand(T, 9, 9, 4, 1)
 da, db, dc = CuArray(a), CuArray(b), CuArray(c)
 cdims = DenseConvDims(a, b)
 @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
 @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
 @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))
 
+if T <: Complex
+    @testset "mixed real and complex" begin
+        @test NNlib.conv(real(a), b, cdims) ≈ collect(NNlib.conv(real(da), db, cdims))
+        @test NNlib.conv(a, real(b), cdims) ≈ collect(NNlib.conv(da, real(db), cdims))
+        @test ∇conv_data(c, real(b), cdims) ≈ collect(∇conv_data(dc, real(db), cdims))
+        @test ∇conv_filter(real(a), c, cdims) ≈ collect(∇conv_filter(real(da), dc, cdims))
+    end
+end
+
 # Test Conv Bias Activation
-bias = rand(Float64, 1, 1, 4, 1)
+bias = rand(T, 1, 1, 4, 1)
 dbias = CuArray(bias)
-@test conv_bias_act(a, b, cdims, bias, NNlib.relu) ≈ collect(conv_bias_act(da, db, cdims, dbias, NNlib.relu))
+act = T <: Complex ? abs2 : NNlib.relu
+@test conv_bias_act(a, b, cdims, bias, act) ≈ collect(conv_bias_act(da, db, cdims, dbias, act))
 @test conv_bias_act(a, b, cdims, bias, identity) ≈ collect(conv_bias_act(da, db, cdims, dbias, identity))
 
 # Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
@@ -26,16 +37,20 @@ using NNlib: DenseConvDims
 C_out = 4
 batch_size = 1
 
-for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
+# we use this activation for the gpu tests
+# as we can't take gradients of complex quantities
+act = T <: Complex ? x-> abs2(x) : identity
+@testset "groups=$groups, num_spatial_dims=$num_spatial_dims" for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
     # Make `C_in = C_out` when using grouped convolution.
     C_in = groups == 1 ? C_in_ : C_out
     # Initialize data we'll run our tests over
-    x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size)
-    w = rand(Float64, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)
+    x = rand(T, fill(8, num_spatial_dims)..., C_in, batch_size)
+    w = rand(T, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)
 
-    for opts in options
+    @testset "opts #$i" for (i,opts) in enumerate(options)
         opts[:groups] = groups
-
+
+
         if :padding in keys(opts)
             padding = opts[:padding]
             if 1 < length(padding) && length(padding) != 2num_spatial_dims
@@ -47,18 +62,56 @@ using NNlib: DenseConvDims
         y = NNlib.conv(x, w, cdims)
 
         # Test that basic convolution is equivalent across GPU/CPU
-        gputest((x, w) -> NNlib.conv(x, w, cdims), x, w)
-        gputest((y, w) -> NNlib.∇conv_data(y, w, cdims), y, w)
-        gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims), x, y, checkgrad=false) # TODO fix grad
+        @testset "cpu==gpu" begin
+            @testset "conv" begin
+                gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, w)
+                if T <: Complex
+                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), real(x), w)
+                    gputest((x, w) -> act.(NNlib.conv(x, w, cdims)), x, real(w))
+                end
+            end
+            @testset "∇conv_data" begin
+                gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, w)
+                if T <: Complex
+                    gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims)), y, real(w))
+                end
+            end
+            @testset "∇conv_filter" begin
+                gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), x, y)
+                if T <: Complex
+                    gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims)), real(x), y)
+                end
+            end
+        end
 
         # Scaling factors
-        gputest((x, w) -> NNlib.conv(x, w, cdims; alpha=2.0), x, w, checkgrad=false) # TODO
-        gputest((y, w) -> NNlib.∇conv_data(y, w, cdims; alpha=2.0), y, w, checkgrad=false) # TODO
-        gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims; alpha=2.0), x, y, checkgrad=false) # TODO
+        @testset "scale-alpha" begin
+            gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, w, checkgrad=false) # TODO
+            gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, w, checkgrad=false) # TODO
+            gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), x, y, checkgrad=false) # TODO
+
+            if T <: Complex
+                gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), real(x), w, checkgrad=false)
+                gputest((x, w) -> act.(NNlib.conv(x, w, cdims; alpha=T(2.0))), x, real(w), checkgrad=false) # TODO
+                gputest((y, w) -> act.(NNlib.∇conv_data(y, w, cdims; alpha=T(2.0))), y, real(w), checkgrad=false) # TODO
+                gputest((x, y) -> act.(NNlib.∇conv_filter(x, y, cdims; alpha=T(2.0))), real(x), y, checkgrad=false) # TODO
+            end
+        end
+
+        @testset "scale-beta" begin
+            gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, w, checkgrad=false, broken=false)
+            gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, x, y, checkgrad=false, broken=false)
+            gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, w, checkgrad=false, broken=true)
+
+            if T <: Complex
+                gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, real(x), w, checkgrad=false)
+                gputest((y, x, w) -> act.(NNlib.conv!(copy(y), x, w, cdims; beta=T(2.0))), y, x, real(w), checkgrad=false)
+                gputest((x, y, w) -> act.(NNlib.∇conv_data!(copy(x), y, w, cdims; beta=T(2.0))), x, y, real(w), checkgrad=false)
+                gputest((w, x, y) -> act.(NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=T(2.0))), w, real(x), y, checkgrad=false)
+            end
+        end
 
-        gputest((y, x, w) -> NNlib.conv!(copy(y), x, w, cdims; beta=2.0), y, x, w, checkgrad=false) # TODO
-        # @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO
-        gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO
     end
 end
 end
+end
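The `act` and `mean` pattern in the tests above exists because reverse-mode AD needs a real scalar loss, so complex outputs are mapped through `abs2` first. A sketch of the kind of gradient these tests exercise (assuming Zygote, which supplies `gradient` in this test suite):

```julia
using NNlib, Zygote

x = rand(ComplexF64, 16, 16, 3, 1)
w = rand(ComplexF64, 3, 3, 3, 4)
cdims = NNlib.DenseConvDims(x, w)

# real-valued loss of a complex convolution output
loss(x, w) = sum(abs2, NNlib.conv(x, w, cdims))
dx, dw = gradient(loss, x, w)   # pullbacks call ∇conv_data / ∇conv_filter
```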

test/ext_cuda/test_utils.jl

Lines changed: 7 additions & 5 deletions
@@ -1,19 +1,21 @@
-function gputest(f, xs...; checkgrad=true, atol=1e-10, kws...)
+function gputest(f, xs...; checkgrad=true, rtol=1e-7, atol=1e-10, broken=false, broken_grad=false, kws...)
     cpu_in = xs
     gpu_in = CuArray.(xs)
 
     cpu_out = f(cpu_in...; kws...)
     gpu_out = f(gpu_in...; kws...)
-    @test collect(cpu_out) ≈ collect(gpu_out)
+    @test collect(cpu_out) ≈ collect(gpu_out) rtol=rtol atol=atol broken=broken
 
     if checkgrad
-        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)
-        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)
+        # use mean instead of sum to prevent error accumulation (for larger
+        # tensors) which causes error to go above atol
+        cpu_grad = gradient((x...) -> mean(f(x...; kws...)), cpu_in...)
+        gpu_grad = gradient((x...) -> mean(f(x...; kws...)), gpu_in...)
         for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
             if cpu_g === nothing
                 @test gpu_g === nothing
             else
-                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol
+                @test collect(cpu_g) ≈ collect(gpu_g) rtol=rtol atol=atol broken=broken_grad
             end
         end
     end
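One-line intuition for the `mean` change above (illustrative numbers only): per-element error at the 1e-12 level accumulates in a `sum` over many elements to well above `atol = 1e-10`, while the `mean` stays at the per-element level:

```julia
err = fill(1e-12, 10^6)   # toy per-element discrepancies
sum(err)                  # ≈ 1e-6, which would exceed atol = 1e-10
sum(err) / length(err)    # ≈ 1e-12, comparable to a single element
```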
