Skip to content

Commit 4b87ec0

Browse files
authored
Implement reverse lookup (Ptr->Tuple) for CUDNN descriptors. (#1948)
1 parent dc16d92 commit 4b87ec0

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

lib/cudnn/src/convolution.jl

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,28 @@ function cudnnConvolutionForwardAD(w, x, bias, z; y, activation, convDesc, wDesc
120120
return y
121121
end
122122

123+
function cudnnGetConvolutionDescriptor(d::cudnnConvolutionDescriptor)
124+
# we don't know the dimension of the convolution, so we start by
125+
# allocating the maximum size it can be.
126+
nbDimsRequested = CUDNN_DIM_MAX - 2
127+
# later, here we get the actual dimensionality of the convolution
128+
arrlen = Ref{Cint}(nbDimsRequested)
129+
padding = Array{Cint}(undef, nbDimsRequested)
130+
stride = Array{Cint}(undef, nbDimsRequested)
131+
dilation = Array{Cint}(undef, nbDimsRequested)
132+
mode = Ref{cuDNN.cudnnConvolutionMode_t}(CUDNN_CONVOLUTION)
133+
dataType = Ref{cuDNN.cudnnDataType_t}(cuDNN.CUDNN_DATA_FLOAT)
134+
135+
cudnnGetConvolutionNdDescriptor(d, nbDimsRequested, arrlen, padding, stride, dilation,
136+
mode, dataType)
137+
T = juliaDataType(dataType[])
138+
SZ = arrlen[]
139+
P = (padding[1:SZ]..., )
140+
S = (stride[1:SZ]..., )
141+
D = (dilation[1:SZ]..., )
142+
return T, mode[], SZ, P, S, D
143+
end
144+
123145
# Helper for cudnnConvolutionDescriptor
124146
function cudnnSetConvolutionDescriptor(
125147
ptr::cudnnConvolutionDescriptor_t,
@@ -179,9 +201,15 @@ const cudnnConvolutionFwdAlgoPerfCacheLock = ReentrantLock()
179201
It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
180202
"""
181203
function cudnnConvolutionFwdAlgoPerf(xDesc, x, wDesc, w, convDesc, yDesc, y, biasDesc, activation, allocateTmpBuf=true)
182-
key = (xDesc, wDesc, convDesc, biasDesc, activation)
204+
xDesc_native = cudnnGetTensorDescriptor(xDesc)
205+
wDesc_native = cudnnGetFilterDescriptor(wDesc)
206+
convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
207+
biasDesc_native = (isnothing(biasDesc) ? nothing
208+
: cudnnGetTensorDescriptor(biasDesc))
209+
210+
key = (xDesc_native, wDesc_native, convDesc_native, biasDesc, activation)
183211
val = lock(cudnnConvolutionFwdAlgoPerfCacheLock) do
184-
get(cudnnConvolutionFwdAlgoPerfCache, key, nothing)
212+
get(cudnnConvolutionFwdAlgoPerfCache, key, nothing)
185213
end
186214
if val === nothing
187215
requestedAlgoCount = Int(CUDNN_CONVOLUTION_FWD_ALGO_COUNT)
@@ -210,7 +238,11 @@ const cudnnConvolutionBwdDataAlgoPerfCacheLock = ReentrantLock()
210238
It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
211239
"""
212240
function cudnnConvolutionBwdDataAlgoPerf(wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, allocateTmpBuf=true)
213-
key = (wDesc, dyDesc, convDesc)
241+
wDesc_native = cudnnGetFilterDescriptor(wDesc)
242+
dyDesc_native = cudnnGetTensorDescriptor(dyDesc)
243+
convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
244+
245+
key = (wDesc_native, dyDesc_native, convDesc_native)
214246
val = lock(cudnnConvolutionBwdDataAlgoPerfCacheLock) do
215247
get(cudnnConvolutionBwdDataAlgoPerfCache, key, nothing)
216248
end
@@ -241,7 +273,11 @@ const cudnnConvolutionBwdFilterAlgoPerfCacheLock = ReentrantLock()
241273
It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
242274
"""
243275
function cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, dyDesc, dy, convDesc, dwDesc, dw, allocateTmpBuf=true)
244-
key = (xDesc, dyDesc, convDesc)
276+
xDesc_native = cudnnGetTensorDescriptor(xDesc)
277+
dyDesc_native = cudnnGetTensorDescriptor(dyDesc)
278+
convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
279+
280+
key = (xDesc_native, dyDesc_native, convDesc_native)
245281
val = lock(cudnnConvolutionBwdFilterAlgoPerfCacheLock) do
246282
get(cudnnConvolutionBwdFilterAlgoPerfCache, (xDesc, dyDesc, convDesc), nothing)
247283
end

0 commit comments

Comments
 (0)