Skip to content

Commit 1bdbb86

Browse files
authored
CUFFT: Support Float16 by switching to Xt APIs. (#2430)
1 parent 9654745 commit 1bdbb86

File tree

10 files changed

+744
-481
lines changed

10 files changed

+744
-481
lines changed

lib/cufft/CUFFT.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using ..APIUtils
55
using ..CUDA_Runtime
66

77
using ..CUDA
8-
using ..CUDA: CUstream, cuComplex, cuDoubleComplex, libraryPropertyType
8+
using ..CUDA: CUstream, cuComplex, cuDoubleComplex, cudaDataType, libraryPropertyType
99
using ..CUDA: unsafe_free!, retry_reclaim, initialize_context
1010

1111
using CEnum: @cenum

lib/cufft/fft.jl

Lines changed: 102 additions & 226 deletions
Large diffs are not rendered by default.

lib/cufft/libcufft.jl

Lines changed: 286 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# This file is automatically generated. Do not edit!
2-
# To re-generated, execute res/wrap.jl
2+
# To re-generate, execute res/wrap/wrap.jl
33

44
using CEnum
55

@@ -25,6 +25,42 @@ end
2525
end
2626
end
2727

28+
@cenum cudaXtCopyType_t::UInt32 begin
29+
LIB_XT_COPY_HOST_TO_DEVICE = 0
30+
LIB_XT_COPY_DEVICE_TO_HOST = 1
31+
LIB_XT_COPY_DEVICE_TO_DEVICE = 2
32+
end
33+
34+
const cudaLibXtCopyType = cudaXtCopyType_t
35+
36+
@cenum libFormat_t::UInt32 begin
37+
LIB_FORMAT_CUFFT = 0
38+
LIB_FORMAT_UNDEFINED = 1
39+
end
40+
41+
const libFormat = libFormat_t
42+
43+
struct cudaXtDesc_t
44+
version::Cint
45+
nGPUs::Cint
46+
GPUs::NTuple{64,Cint}
47+
data::NTuple{64,Ptr{Cvoid}}
48+
size::NTuple{64,Csize_t}
49+
cudaXtState::Ptr{Cvoid}
50+
end
51+
52+
const cudaXtDesc = cudaXtDesc_t
53+
54+
struct cudaLibXtDesc_t
55+
version::Cint
56+
descriptor::Ptr{cudaXtDesc}
57+
library::libFormat
58+
subFormat::Cint
59+
libDescriptor::Ptr{Cvoid}
60+
end
61+
62+
const cudaLibXtDesc = cudaLibXtDesc_t
63+
2864
@cenum cufftResult_t::UInt32 begin
2965
CUFFT_SUCCESS = 0
3066
CUFFT_INVALID_PLAN = 1
@@ -321,6 +357,255 @@ end
321357
property::cufftProperty)::cufftResult
322358
end
323359

360+
@cenum cufftXtSubFormat_t::UInt32 begin
361+
CUFFT_XT_FORMAT_INPUT = 0
362+
CUFFT_XT_FORMAT_OUTPUT = 1
363+
CUFFT_XT_FORMAT_INPLACE = 2
364+
CUFFT_XT_FORMAT_INPLACE_SHUFFLED = 3
365+
CUFFT_XT_FORMAT_1D_INPUT_SHUFFLED = 4
366+
CUFFT_XT_FORMAT_DISTRIBUTED_INPUT = 5
367+
CUFFT_XT_FORMAT_DISTRIBUTED_OUTPUT = 6
368+
CUFFT_FORMAT_UNDEFINED = 7
369+
end
370+
371+
const cufftXtSubFormat = cufftXtSubFormat_t
372+
373+
@cenum cufftXtCopyType_t::UInt32 begin
374+
CUFFT_COPY_HOST_TO_DEVICE = 0
375+
CUFFT_COPY_DEVICE_TO_HOST = 1
376+
CUFFT_COPY_DEVICE_TO_DEVICE = 2
377+
CUFFT_COPY_UNDEFINED = 3
378+
end
379+
380+
const cufftXtCopyType = cufftXtCopyType_t
381+
382+
@cenum cufftXtQueryType_t::UInt32 begin
383+
CUFFT_QUERY_1D_FACTORS = 0
384+
CUFFT_QUERY_UNDEFINED = 1
385+
end
386+
387+
const cufftXtQueryType = cufftXtQueryType_t
388+
389+
struct cufftXt1dFactors_t
390+
size::Clonglong
391+
stringCount::Clonglong
392+
stringLength::Clonglong
393+
substringLength::Clonglong
394+
factor1::Clonglong
395+
factor2::Clonglong
396+
stringMask::Clonglong
397+
substringMask::Clonglong
398+
factor1Mask::Clonglong
399+
factor2Mask::Clonglong
400+
stringShift::Cint
401+
substringShift::Cint
402+
factor1Shift::Cint
403+
factor2Shift::Cint
404+
end
405+
406+
const cufftXt1dFactors = cufftXt1dFactors_t
407+
408+
@cenum cufftXtWorkAreaPolicy_t::UInt32 begin
409+
CUFFT_WORKAREA_MINIMAL = 0
410+
CUFFT_WORKAREA_USER = 1
411+
CUFFT_WORKAREA_PERFORMANCE = 2
412+
end
413+
414+
const cufftXtWorkAreaPolicy = cufftXtWorkAreaPolicy_t
415+
416+
@checked function cufftXtSetGPUs(handle, nGPUs, whichGPUs)
417+
initialize_context()
418+
@gcsafe_ccall libcufft.cufftXtSetGPUs(handle::cufftHandle, nGPUs::Cint,
419+
whichGPUs::Ptr{Cint})::cufftResult
420+
end
421+
422+
@checked function cufftXtMalloc(plan, descriptor, format)
423+
initialize_context()
424+
@gcsafe_ccall libcufft.cufftXtMalloc(plan::cufftHandle,
425+
descriptor::Ptr{Ptr{cudaLibXtDesc}},
426+
format::cufftXtSubFormat)::cufftResult
427+
end
428+
429+
@checked function cufftXtMemcpy(plan, dstPointer, srcPointer, type)
430+
initialize_context()
431+
@gcsafe_ccall libcufft.cufftXtMemcpy(plan::cufftHandle, dstPointer::CuPtr{Cvoid},
432+
srcPointer::CuPtr{Cvoid},
433+
type::cufftXtCopyType)::cufftResult
434+
end
435+
436+
@checked function cufftXtFree(descriptor)
437+
initialize_context()
438+
@gcsafe_ccall libcufft.cufftXtFree(descriptor::Ptr{cudaLibXtDesc})::cufftResult
439+
end
440+
441+
@checked function cufftXtSetWorkArea(plan, workArea)
442+
initialize_context()
443+
@gcsafe_ccall libcufft.cufftXtSetWorkArea(plan::cufftHandle,
444+
workArea::Ptr{Ptr{Cvoid}})::cufftResult
445+
end
446+
447+
@checked function cufftXtExecDescriptorC2C(plan, input, output, direction)
448+
initialize_context()
449+
@gcsafe_ccall libcufft.cufftXtExecDescriptorC2C(plan::cufftHandle,
450+
input::Ptr{cudaLibXtDesc},
451+
output::Ptr{cudaLibXtDesc},
452+
direction::Cint)::cufftResult
453+
end
454+
455+
@checked function cufftXtExecDescriptorR2C(plan, input, output)
456+
initialize_context()
457+
@gcsafe_ccall libcufft.cufftXtExecDescriptorR2C(plan::cufftHandle,
458+
input::Ptr{cudaLibXtDesc},
459+
output::Ptr{cudaLibXtDesc})::cufftResult
460+
end
461+
462+
@checked function cufftXtExecDescriptorC2R(plan, input, output)
463+
initialize_context()
464+
@gcsafe_ccall libcufft.cufftXtExecDescriptorC2R(plan::cufftHandle,
465+
input::Ptr{cudaLibXtDesc},
466+
output::Ptr{cudaLibXtDesc})::cufftResult
467+
end
468+
469+
@checked function cufftXtExecDescriptorZ2Z(plan, input, output, direction)
470+
initialize_context()
471+
@gcsafe_ccall libcufft.cufftXtExecDescriptorZ2Z(plan::cufftHandle,
472+
input::Ptr{cudaLibXtDesc},
473+
output::Ptr{cudaLibXtDesc},
474+
direction::Cint)::cufftResult
475+
end
476+
477+
@checked function cufftXtExecDescriptorD2Z(plan, input, output)
478+
initialize_context()
479+
@gcsafe_ccall libcufft.cufftXtExecDescriptorD2Z(plan::cufftHandle,
480+
input::Ptr{cudaLibXtDesc},
481+
output::Ptr{cudaLibXtDesc})::cufftResult
482+
end
483+
484+
@checked function cufftXtExecDescriptorZ2D(plan, input, output)
485+
initialize_context()
486+
@gcsafe_ccall libcufft.cufftXtExecDescriptorZ2D(plan::cufftHandle,
487+
input::Ptr{cudaLibXtDesc},
488+
output::Ptr{cudaLibXtDesc})::cufftResult
489+
end
490+
491+
@checked function cufftXtQueryPlan(plan, queryStruct, queryType)
492+
initialize_context()
493+
@gcsafe_ccall libcufft.cufftXtQueryPlan(plan::cufftHandle, queryStruct::CuPtr{Cvoid},
494+
queryType::cufftXtQueryType)::cufftResult
495+
end
496+
497+
@cenum cufftXtCallbackType_t::UInt32 begin
498+
CUFFT_CB_LD_COMPLEX = 0
499+
CUFFT_CB_LD_COMPLEX_DOUBLE = 1
500+
CUFFT_CB_LD_REAL = 2
501+
CUFFT_CB_LD_REAL_DOUBLE = 3
502+
CUFFT_CB_ST_COMPLEX = 4
503+
CUFFT_CB_ST_COMPLEX_DOUBLE = 5
504+
CUFFT_CB_ST_REAL = 6
505+
CUFFT_CB_ST_REAL_DOUBLE = 7
506+
CUFFT_CB_UNDEFINED = 8
507+
end
508+
509+
const cufftXtCallbackType = cufftXtCallbackType_t
510+
511+
# typedef cufftComplex ( * cufftCallbackLoadC ) ( void * dataIn , size_t offset , void * callerInfo , void * sharedPointer )
512+
const cufftCallbackLoadC = Ptr{Cvoid}
513+
514+
# typedef cufftDoubleComplex ( * cufftCallbackLoadZ ) ( void * dataIn , size_t offset , void * callerInfo , void * sharedPointer )
515+
const cufftCallbackLoadZ = Ptr{Cvoid}
516+
517+
# typedef cufftReal ( * cufftCallbackLoadR ) ( void * dataIn , size_t offset , void * callerInfo , void * sharedPointer )
518+
const cufftCallbackLoadR = Ptr{Cvoid}
519+
520+
# typedef cufftDoubleReal ( * cufftCallbackLoadD ) ( void * dataIn , size_t offset , void * callerInfo , void * sharedPointer )
521+
const cufftCallbackLoadD = Ptr{Cvoid}
522+
523+
# typedef void ( * cufftCallbackStoreC ) ( void * dataOut , size_t offset , cufftComplex element , void * callerInfo , void * sharedPointer )
524+
const cufftCallbackStoreC = Ptr{Cvoid}
525+
526+
# typedef void ( * cufftCallbackStoreZ ) ( void * dataOut , size_t offset , cufftDoubleComplex element , void * callerInfo , void * sharedPointer )
527+
const cufftCallbackStoreZ = Ptr{Cvoid}
528+
529+
# typedef void ( * cufftCallbackStoreR ) ( void * dataOut , size_t offset , cufftReal element , void * callerInfo , void * sharedPointer )
530+
const cufftCallbackStoreR = Ptr{Cvoid}
531+
532+
# typedef void ( * cufftCallbackStoreD ) ( void * dataOut , size_t offset , cufftDoubleReal element , void * callerInfo , void * sharedPointer )
533+
const cufftCallbackStoreD = Ptr{Cvoid}
534+
535+
@checked function cufftXtSetCallback(plan, callback_routine, cbType, caller_info)
536+
initialize_context()
537+
@gcsafe_ccall libcufft.cufftXtSetCallback(plan::cufftHandle,
538+
callback_routine::Ptr{Ptr{Cvoid}},
539+
cbType::cufftXtCallbackType,
540+
caller_info::Ptr{Ptr{Cvoid}})::cufftResult
541+
end
542+
543+
@checked function cufftXtClearCallback(plan, cbType)
544+
initialize_context()
545+
@gcsafe_ccall libcufft.cufftXtClearCallback(plan::cufftHandle,
546+
cbType::cufftXtCallbackType)::cufftResult
547+
end
548+
549+
@checked function cufftXtSetCallbackSharedSize(plan, cbType, sharedSize)
550+
initialize_context()
551+
@gcsafe_ccall libcufft.cufftXtSetCallbackSharedSize(plan::cufftHandle,
552+
cbType::cufftXtCallbackType,
553+
sharedSize::Csize_t)::cufftResult
554+
end
555+
556+
@checked function cufftXtMakePlanMany(plan, rank, n, inembed, istride, idist, inputtype,
557+
onembed, ostride, odist, outputtype, batch, workSize,
558+
executiontype)
559+
initialize_context()
560+
@gcsafe_ccall libcufft.cufftXtMakePlanMany(plan::cufftHandle, rank::Cint,
561+
n::Ptr{Clonglong}, inembed::Ptr{Clonglong},
562+
istride::Clonglong, idist::Clonglong,
563+
inputtype::cudaDataType,
564+
onembed::Ptr{Clonglong}, ostride::Clonglong,
565+
odist::Clonglong, outputtype::cudaDataType,
566+
batch::Clonglong, workSize::Ptr{Csize_t},
567+
executiontype::cudaDataType)::cufftResult
568+
end
569+
570+
@checked function cufftXtGetSizeMany(plan, rank, n, inembed, istride, idist, inputtype,
571+
onembed, ostride, odist, outputtype, batch, workSize,
572+
executiontype)
573+
initialize_context()
574+
@gcsafe_ccall libcufft.cufftXtGetSizeMany(plan::cufftHandle, rank::Cint,
575+
n::Ptr{Clonglong}, inembed::Ptr{Clonglong},
576+
istride::Clonglong, idist::Clonglong,
577+
inputtype::cudaDataType,
578+
onembed::Ptr{Clonglong}, ostride::Clonglong,
579+
odist::Clonglong, outputtype::cudaDataType,
580+
batch::Clonglong, workSize::Ptr{Csize_t},
581+
executiontype::cudaDataType)::cufftResult
582+
end
583+
584+
@checked function cufftXtExec(plan, input, output, direction)
585+
initialize_context()
586+
@gcsafe_ccall libcufft.cufftXtExec(plan::cufftHandle, input::CuPtr{Cvoid},
587+
output::CuPtr{Cvoid}, direction::Cint)::cufftResult
588+
end
589+
590+
@checked function cufftXtExecDescriptor(plan, input, output, direction)
591+
initialize_context()
592+
@gcsafe_ccall libcufft.cufftXtExecDescriptor(plan::cufftHandle,
593+
input::Ptr{cudaLibXtDesc},
594+
output::Ptr{cudaLibXtDesc},
595+
direction::Cint)::cufftResult
596+
end
597+
598+
@checked function cufftXtSetWorkAreaPolicy(plan, policy, workSize)
599+
initialize_context()
600+
@gcsafe_ccall libcufft.cufftXtSetWorkAreaPolicy(plan::cufftHandle,
601+
policy::cufftXtWorkAreaPolicy,
602+
workSize::Ptr{Csize_t})::cufftResult
603+
end
604+
605+
const CUDA_XT_DESCRIPTOR_VERSION = 0x01000000
606+
607+
const MAX_CUDA_DESCRIPTOR_GPUS = 64
608+
324609
# Skipping MacroDefinition: CUFFTAPI __attribute__ ( ( visibility ( "default" ) ) )
325610

326611
const MAX_CUFFT_ERROR = 0x11

lib/cufft/util.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,20 @@
1-
const cufftNumber = Union{cufftDoubleReal,cufftReal,cufftDoubleComplex,cufftComplex}
2-
const cufftReals = Union{cufftDoubleReal,cufftReal}
3-
const cufftComplexes = Union{cufftDoubleComplex,cufftComplex}
4-
const cufftDouble = Union{cufftDoubleReal,cufftDoubleComplex}
5-
const cufftSingle = Union{cufftReal,cufftComplex}
6-
const cufftTypeDouble = Union{Type{cufftDoubleReal},Type{cufftDoubleComplex}}
7-
const cufftTypeSingle = Union{Type{cufftReal},Type{cufftComplex}}
1+
const cufftReals = Union{cufftDoubleReal,cufftReal,Float16}
2+
const cufftComplexes = Union{cufftDoubleComplex,cufftComplex,Complex{Float16}}
3+
const cufftNumber = Union{cufftReals,cufftComplexes}
84

95
cufftfloat(x) = _cufftfloat(float(x))
10-
_cufftfloat(::Type{T}) where {T<:cufftReals} = T
11-
_cufftfloat(::Type{Float16}) = Float32
12-
_cufftfloat(::Type{Complex{T}}) where {T} = Complex{_cufftfloat(T)}
6+
_cufftfloat(::Type{T}) where {T<:cufftNumber} = T
137
_cufftfloat(::Type{T}) where {T} = error("type $T not supported")
148
_cufftfloat(x::T) where {T} = _cufftfloat(T)(x)
159

16-
complexfloat(x::DenseCuArray{Complex{<:cufftReals}}) = x
1710
realfloat(x::DenseCuArray{<:cufftReals}) = x
11+
realfloat(x::DenseCuArray{T}) where {T<:Real} = copy1(cufftfloat(T), x)
12+
realfloat(x::DenseCuArray{T}) where {T} = error("type $T not supported")
1813

19-
complexfloat(x::DenseCuArray{T}) where {T<:Complex} = copy1(typeof(cufftfloat(zero(T))), x)
20-
complexfloat(x::DenseCuArray{T}) where {T<:Real} = copy1(typeof(complex(cufftfloat(zero(T)))), x)
21-
22-
realfloat(x::DenseCuArray{T}) where {T<:Real} = copy1(typeof(cufftfloat(zero(T))), x)
14+
complexfloat(x::DenseCuArray{<:cufftComplexes}) = x
15+
complexfloat(x::DenseCuArray{T}) where {T<:Complex} = copy1(cufftfloat(T), x)
16+
complexfloat(x::DenseCuArray{T}) where {T<:Real} = copy1(cufftfloat(complex(T)), x)
17+
complexfloat(x::DenseCuArray{T}) where {T} = error("type $T not supported")
2318

2419
function copy1(::Type{T}, x) where T
2520
y = CuArray{T}(undef, map(length, axes(x)))

0 commit comments

Comments
 (0)