One major difficulty of using CUDA.jl + anything Bayesian is that you immediately need to define all the (at the very least, univariate) distributions all over again. But for starters we don't need all of the functionality of a `Distribution` to do something like Bayesian inference, e.g. for AdvancedHMC.jl we really only need `logpdf` and its adjoint, as the only call to `rand` is going to be for the momentum (which can be sampled directly on the GPU using `CUDA.randn`).

But even trying to redefine the `logpdf` of a `Distribution` to work on the GPU is often non-trivial.
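As a rough illustration of that point (a minimal sketch; the target below is just a standard-Gaussian stand-in, not AdvancedHMC.jl API):

```julia
using CUDA

# The momentum can be drawn directly on the device, and the target only needs
# to expose a logpdf-like function of the position; no `Distribution` object
# is required on the GPU.
q = CUDA.rand(Float32, 1_000)        # positions (illustrative)
p = CUDA.randn(Float32, 1_000)       # momentum, sampled with CUDA.randn
logdensity(q) = -sum(abs2, q) / 2    # stand-in for the target logpdf
logdensity(q)
```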
Issue #1
Before going into the issue, it's important to know the following:
- All functions that you want to vectorize/broadcast need to be replaced with the corresponding CUDA function, e.g. `Base.log` and `CUDA.log` are actually different methods.
- AFAIK, CUDA.jl achieves this by effectively converting all expressions of the form `f.(args...)` to `CUDA.cufunc(f).(args...)` whenever any of the `args` is a `CUDA.CuArray`, by overloading `Broadcast.broadcasted`. E.g. `CUDA.cufunc(::typeof(log)) = CUDA.log` (see the sketch below).
- Therefore, in the case where a function `f` does not have a `cufunc` already defined and you do `f.(args...)`, you'll, if you're lucky, get an error, but sometimes the entire Julia session will crash.
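As a rough, self-contained sketch of what this substitution amounts to (using a made-up `my_cufunc` so as not to touch CUDA.jl's own `cufunc`):

```julia
using CUDA

# Default: leave the function untouched.
my_cufunc(f) = f
# Special case: use the device intrinsic instead of `Base.log`.
my_cufunc(::typeof(log)) = CUDA.log

# CUDA.jl's `broadcasted` overload effectively turns `f.(args...)` into
# `cufunc(f).(args...)` when one of the `args` is a `CuArray`:
xs = CUDA.rand(Float32, 8)
ys = my_cufunc(log).(xs)   # i.e. CUDA.log.(xs)
```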
So what do we do?
- Well, we define a `cufunc(::typeof(f)) = ...`, which will allow you to broadcast over:
  - simple custom-defined functions (e.g. `digamma` from this PR by @xukai92: Gamma family function support JuliaGPU/CuArrays.jl#321 (comment));
  - other native CUDA functions, e.g. `loggamma` can be replaced by `CUDA.lgamma` (see the sketch after this list).
- Great stuff! Then you go and try to take the gradient of this thing and everything breaks again. So you need to also define rules for this `f`.
  - AFAIK, for `f` on the GPU, Zygote.jl uses ForwardDiff.jl to obtain the adjoints for broadcasting, so we've got to define these rules using `DiffRules` and evaluate them using `ForwardDiff`.
  - You then try to figure out where the current `DiffRules` definition for `f` is, and you copy-paste it, replacing the methods inside with their `cufunc` counterparts.
- Okay, so at this point we have all these lower-level functions with their corresponding `cufunc` definitions and their corresponding `@define_diffrule`, so we're good to go, right? Now we can just call `StatsFuns.gammalogpdf.(α, θ, x)`, right?
  - Answer: No. AFAIK, there are several things that can fail:
    - A lot of the functions used within the Distributions.jl ecosystem are not pure Julia under the hood, but there are often pure-Julia versions for more generic number types so that one can, for example, just AD through them without issues, e.g. https://github.com/JuliaStats/StatsFuns.jl/blob/8dfda2c0ee33d5f85eca5c039d31d85c90f363f2/src/distrs/gamma.jl#L19. BUT this doesn't help compat with CUDA.jl, because eltypes of a `CUDA.CuArray` aren't special, i.e. it's just a `Float32`. And so the function we dispatch on when broadcasting over a `CUDA.CuArray` will be some function outside of the Julia ecosystem, and so things start blowing up. EDIT: this is only an issue for eltype `Float64`, not `Float32`, as pointed out by @devmotion!
    - Observing this, we specialize by overloading the method for `Float32` and so on to use the pure-Julia implementation, e.g. `StatsFuns.gammalogpdf(k::Float32, θ::Float32, x::Float32) = -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ`.
    - BUT it still fails, because the overloading of `broadcasted` won't be properly nested, so `cufunc` will only be called on `StatsFuns.gammalogpdf`, not on the methods used within! So we instead do

      ```julia
      cugammalogpdf(k, θ, x) = -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) + (k - 1) * CUDA.cufunc(log)(x) - x / θ
      CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf
      ```

      which is really not fun.
    - I.e. it's not enough to define `cufunc` for the leaves of the method hierarchy! This sucks.
  - Upshot: for most AD frameworks (Zygote.jl and ForwardDiff.jl, since Zygote.jl uses ForwardDiff.jl for broadcasting) this is how we end up getting the AD rules for all the non-leaves in the method hierarchy.
- Of course, there is work on doing automatic method substitution (Method overrides using a method overlay table, JuliaGPU/GPUCompiler.jl#122), but in the meantime we need to go through the above process.
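To make the two manual steps above concrete, here is a hedged sketch (assuming a CUDA.jl version that still exposes `CUDA.cufunc`, and with the `cu*`-prefixed names being purely illustrative) of registering a GPU-friendly replacement for a scalar function and then mirroring its existing `DiffRules` entry as a hand-written `ForwardDiff.Dual` method:

```julia
using CUDA, SpecialFunctions, DiffRules
using ForwardDiff: Dual, value, partials

# Route broadcasts of `loggamma` over CuArrays to the device intrinsic.
CUDA.cufunc(::typeof(SpecialFunctions.loggamma)) = CUDA.lgamma

# The existing scalar rule says d/dx loggamma(x) = digamma(x):
DiffRules.diffrule(:SpecialFunctions, :loggamma, :x)  # :(SpecialFunctions.digamma(x))

# Copy-paste of that rule as a `Dual` method (ForwardDiff is what Zygote uses
# for broadcasted adjoints), with the inner call routed through `cufunc` as
# well. This assumes a GPU-capable `digamma` replacement has been registered,
# e.g. along the lines of the PR linked above.
culoggamma(x) = CUDA.lgamma(x)
culoggamma(d::Dual{T}) where {T} =
    Dual{T}(culoggamma(value(d)),
            CUDA.cufunc(SpecialFunctions.digamma)(value(d)) * partials(d))
```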
Potential solutions
1. Do it by hand, but with the help of an "improved" version of the `CUDA.@cufunc` macro that I've implemented (https://github.com/JuliaGPU/CUDA.jl/blob/c011ffc0971ab1089f9d56dd338ef4b31e24ecc7/src/broadcast.jl#L101-L112), which has the following additional features (see the sketch below for what an invocation expands into):
   - Replaces all functions `f` in the body with `cufunc(f)`, with the default impl being `cufunc(f) = f`. I.e. we do nothing to almost all methods, but those which have a `cufunc` impl we replace.
   - Correctly handles namespaces, e.g. `@cufunc SpecialFunctions.gamma(x) = ...` is converted into `cugamma(x) = ...; cufunc(::typeof(SpecialFunctions.gamma)) = cugamma`.
   - [OPTIONAL] If `f` is present in `DiffRules.diffrules()`, then we extract the corresponding diffrule and replace all functions `g` within the diffrule with `cufunc(g)`. I.e. IF there is a scalar rule for `f`, then we make it CUDA compatible (assuming the methods in the rule have a `cufunc` implementation); otherwise we leave it to ForwardDiff.jl to figure it out.
2. Wait until the work on method substitution is done.
Personally, I'm in favour of solution (1).
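For concreteness, a hedged sketch of what an invocation of the improved macro from (1) would look like and (roughly) expand into, using the `gammalogpdf` example from Issue #1 (the generated names are illustrative):

```julia
using CUDA, StatsFuns, SpecialFunctions

# Writing
CUDA.@cufunc StatsFuns.gammalogpdf(k, θ, x) =
    -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ

# would (roughly) expand into
cugammalogpdf(k, θ, x) =
    -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) +
    (k - 1) * CUDA.cufunc(log)(x) - x / θ
CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf

# plus, optionally, a cufunc'd version of the DiffRules entry for
# `gammalogpdf`, if one exists.
```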
Issue #2
Now, there's also an additional "annoyance" even after solving the above issue. We cannot do something like `logpdf.(Gamma.(α, θ), x)`, because this will first do `map(Gamma, ...)` before calling `logpdf`. There's the possibility that this could have been inlined, completely removing the call to `Gamma` once it's sufficiently lowered, but `GPUCompiler.jl` will complain before it reaches that stage (as this is not always a guarantee + I believe it will try to fuse all the broadcasts together into a single operation for efficiency). Therefore we either need to:
- Use the underlying methods directly, e.g. `gammalogpdf.(α, θ, x)`.
- Define a `Vectorize(D, args)`, e.g. `Vectorize(Gamma, (α, θ))`, which has a `logpdf` that lazily calls the underlying method, e.g. `logpdf(v::Vectorize{Gamma}, x) = gammalogpdf.(v.args..., x)` (see the sketch below). Equipped with this, we can speed up implementation quite a bit by potentially doing something like:
  - Overload `broadcasted` so that if we're using the `CUDA.CuArrayStyle` and `f <: UnivariateDistribution`, we can `materialize` the `args` earlier and then wrap everything in a `Vectorize`, i.e. `Vectorize(f, args)`.
  - Define a macro similar to `Distributions.@__delegate_statsfuns` or whatever to more easily define `logpdf(v::Vectorize{D}, x)` for different `D`.

  Worth mentioning that this requires a small redefinition of this method in Zygote (https://github.com/FluxML/Zygote.jl/blob/2b17256e79b2eca9a6512207284219d279398fc9/src/lib/broadcast.jl#L225-L228), though it should definitely be possible to make it work even though we're overloading `broadcasted`.
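Finally, a hedged sketch of what such a `Vectorize` wrapper could look like (names and layout are illustrative, and it assumes the `cufunc` machinery from Issue #1 is in place so that `gammalogpdf` broadcasts over `CuArray`s):

```julia
using CUDA, StatsFuns
import Distributions: Gamma, logpdf

# Hold the distribution type and its (already materialized) parameters,
# without ever constructing `Gamma` objects on the GPU.
struct Vectorize{D,T}
    args::T
end
Vectorize(::Type{D}, args) where {D} = Vectorize{D,typeof(args)}(args)

# Lazily delegate to the underlying scalar method when `logpdf` is called.
logpdf(v::Vectorize{Gamma}, x) = StatsFuns.gammalogpdf.(v.args..., x)

# Usage:
α, θ, x = CUDA.rand(Float32, 16), CUDA.rand(Float32, 16), CUDA.rand(Float32, 16)
logpdf(Vectorize(Gamma, (α, θ)), x)
```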