
To the GPU we go #5

@torfjelde

One major difficulty of using CUDA.jl + anything Bayesian is that you immediately need to redefine all the (at the very least, univariate) distributions all over again. But for starters we don't need all of the functionality of a Distribution to do something like Bayesian inference; e.g. for AdvancedHMC.jl we really only need logpdf and its adjoint, since the only call to rand is going to be for the momentum (which can be sampled directly on the GPU using CUDA.randn).
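
For instance, sampling the momentum on the device already works out of the box with CUDA.jl's native RNG (a trivial sketch, just to show that this part needs nothing extra):

    using CUDA

    # Momentum for HMC can be drawn directly on the GPU; Float32 to match the
    # usual CuArray element type.
    p = CUDA.randn(Float32, 1_000)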

But even trying to redefine the logpdf of a Distribution to work on the GPU is often non-trivial.

Issue #1

Before going into the issue, it's important to know the following:

  • All functions that you want to vectorize/broadcast need to be replaced with the corresponding CUDA function, e.g. Base.log and CUDA.log are actually different functions.
  • AFAIK, CUDA.jl achieves this by effectively converting all expressions of the form f.(args...) to CUDA.cufunc(f).(args...) whenever any of the args is a CUDA.CuArray, by overloading Broadcast.broadcasted. E.g. CUDA.cufunc(::typeof(log)) = CUDA.log (see the sketch right after this list).
  • Therefore, if a function f does not already have a cufunc defined and you do f.(args...), then, if you're lucky, you'll get an error; sometimes the entire Julia session will crash instead.
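
Schematically the mechanism looks something like the following. This is a simplified paraphrase of what lives in CUDA.jl's src/broadcast.jl, not something you'd want to redefine yourself:

    using CUDA
    import Base.Broadcast: broadcasted, Broadcasted

    # Fallback: most functions are left untouched...
    cufunc(f) = f
    # ...but functions with a libdevice-backed counterpart get substituted.
    cufunc(::typeof(log)) = CUDA.log

    # On the CUDA broadcast style, swap `f` for `cufunc(f)` before the broadcast is fused.
    broadcasted(::CUDA.CuArrayStyle{N}, f, args...) where {N} =
        Broadcasted{CUDA.CuArrayStyle{N}}(cufunc(f), args, nothing)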

So what do we do?

  • Well, we define a cufunc(::typeof(f)) = ... which will allow you to broadcast f over CuArrays.
  • Great stuff! Then you go and try to take the gradient of this thing and everything breaks again. So you also need to define AD rules for this f.
    • AFAIK, for broadcasting on the GPU Zygote.jl uses ForwardDiff.jl to obtain the adjoints, and so we gotta define these rules using DiffRules and evaluate them using ForwardDiff (a sketch of this step follows the list).
    • You then try to figure out where the current DiffRules definition is for f and you copy-paste, replacing methods with cufunc methods.
  • Okay, so at this point we have all these lower-level functions with their corresponding cufunc definitions and their corresponding @define_diffrule, so we're good to go right? Now we can just call StatsFuns.gammalogpdf.(α, θ, x), right?
    • Answer: No. AFAIK, there are several things that can fail:
      • A lot of the functions used within the Distributions.jl ecosystem are not pure Julia under the hood, but there are often pure-Julia versions for more generic number types so that one can, for example, just AD through them without issues, e.g. https://github.com/JuliaStats/StatsFuns.jl/blob/8dfda2c0ee33d5f85eca5c039d31d85c90f363f2/src/distrs/gamma.jl#L19. BUT this doesn't help compat with CUDA.jl because element types of a CUDA.CuArray aren't special, i.e. it's just a Float32. And so the function we dispatch on when broadcasting over a CUDA.CuArray will be some function outside of the Julia ecosystem, and so things start blowing up. EDIT: this is only an issue for eltype Float64, not Float32, as pointed out by @devmotion!
      • Observing this, we specialize by overloading the method for Float32 and so on to use the pure-Julia implementation, e.g.
        StatsFuns.gammalogpdf(k::Float32, θ::Float32, x::Float32) = -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
      • BUT it still fails because the overloading of broadcasted won't be properly nested, so cufunc will only be called on StatsFuns.gammalogpdf, not on the methods used within! So, we instead do
        cugammalogpdf(k, θ, x) = -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) + (k - 1) * CUDA.cufunc(log)(x) - x / θ
        CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf
        which is really not fun.
      • I.e. it's not enough to define cufunc for the leaves of the method hierarchy! This sucks.
      • Upshot: for most AD frameworks (Zygote.jl and ForwardDiff.jl, since Zygote.jl uses ForwardDiff.jl for broadcasting) we get the AD rules for all non-leaves in the method hierarchy this way.
  • Of course there is work on doing automatic method substitution (Method overrides using a method overlay table, JuliaGPU/GPUCompiler.jl#122), but in the meantime we need to go through the above process.
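
To make the recipe in the list above concrete, here's a hand-written sketch for a single unary function. The names culoggamma and cudigamma (and their CPU placeholder bodies) are made up for illustration, and instead of going through DiffRules this writes the ForwardDiff.Dual method directly, which achieves the same end for Zygote's GPU broadcasting:

    using CUDA, ForwardDiff, SpecialFunctions

    # Hypothetical GPU-friendly versions; in practice the bodies may only call
    # functions that are themselves CUDA-compatible.
    culoggamma(x) = SpecialFunctions.loggamma(x)  # placeholder body
    cudigamma(x) = SpecialFunctions.digamma(x)    # placeholder body

    # Make broadcasting over CuArrays pick up the GPU-friendly version.
    CUDA.cufunc(::typeof(SpecialFunctions.loggamma)) = culoggamma

    # Forward-mode rule so that Zygote's broadcasting (which pushes ForwardDiff.Dual
    # numbers through the kernel) can differentiate it: d/dx loggamma(x) = digamma(x).
    function culoggamma(d::ForwardDiff.Dual{T}) where {T}
        x = ForwardDiff.value(d)
        ForwardDiff.Dual{T}(culoggamma(x), cudigamma(x) * ForwardDiff.partials(d))
    end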

Potential solutions

  1. Do it by hand, but with the help of an "improved" version of the CUDA.@cufunc macro that I've implemented (https://github.com/JuliaGPU/CUDA.jl/blob/c011ffc0971ab1089f9d56dd338ef4b31e24ecc7/src/broadcast.jl#L101-L112) which has the following additional features:
    • Replaces all functions f in the body with cufunc(f), where the default impl is cufunc(f) = f. I.e. leave almost all functions untouched, but replace those which have a cufunc impl (a sketch of the resulting expansion follows this list).
    • Correctly handles namespaces, e.g. @cufunc SpecialFunctions.gamma(x) = ... is converted into cugamma(x) = ...; cufunc(::typeof(SpecialFunctions.gamma)) = cugamma.
    • [OPTIONAL] If f is present in DiffRules.diffrules(), then we extract the corresponding diffrule and replace all functions g within the diffrule with cufunc(g). I.e. IF there is a scalar rule for f, then we make it CUDA compatible (assuming the functions in the rule have a cufunc implementation), otherwise we leave it to ForwardDiff.jl to figure it out.
  2. Wait until the work on method-substitution is done.
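
For reference, with the improved @cufunc from (1) the gammalogpdf example from Issue #1 would look roughly like this (going by the description above; this is a sketch of the intended expansion, not the macro's actual output):

    using CUDA, StatsFuns, SpecialFunctions

    # Writing
    #
    #   @cufunc StatsFuns.gammalogpdf(k, θ, x) =
    #       -SpecialFunctions.loggamma(k) - k * log(θ) + (k - 1) * log(x) - x / θ
    #
    # would generate roughly the hand-written definitions from Issue #1:

    cugammalogpdf(k, θ, x) =
        -CUDA.cufunc(SpecialFunctions.loggamma)(k) - k * CUDA.cufunc(log)(θ) +
        (k - 1) * CUDA.cufunc(log)(x) - x / θ
    CUDA.cufunc(::typeof(StatsFuns.gammalogpdf)) = cugammalogpdf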

Personally, I'm in favour of solution (1).

Issue #2

Now, there's also an additional "annoyance" even after solving the above issue. We cannot do something like logpdf.(Gamma.(α, θ), x) because this will first do map(Gamma, ...) before calling logpdf. There's the possibility that this could have been inlined to the point of completely removing the call to Gamma once it's sufficiently lowered, but GPUCompiler.jl will complain before it reaches that stage (as this is not always guaranteed, plus I believe it will try to fuse all the broadcasts together into a single operation for efficiency). Therefore we either need to:

  1. Use the underlying methods directly, e.g. gammalogpdf.(α, θ, x).
  2. Define a Vectorize(D, args), e.g. Vectorize(Gamma, (α, θ)), which has a logpdf that lazily calls the underlying method, e.g. logpdf(v::Vectorize{Gamma}, x) = gammalogpdf.(v.args..., x). Equipped with this, we can speed up the implementation quite a bit by potentially doing something like the following (see the sketch after this list):
    1. Overload broadcasted so that if we're using the CUDA.CuArrayStyle and f <: UnivariateDistribution we can materialize args earlier and then wrap it in Vectorize, i.e. Vectorize(f, args).
    2. Define a macro similar to Distributions.@__delegate_statsfuns or whatever to more easily define logpdf(v::Vectorize{D}, x) for different D.
      Worth mentioning that this requires a small redef of this method in Zygote (https://github.com/FluxML/Zygote.jl/blob/2b17256e79b2eca9a6512207284219d279398fc9/src/lib/broadcast.jl#L225-L228), though it should def. be possible to make it work even though we're overloading broadcasted.
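
A minimal sketch of what (2) could look like. The Vectorize name comes from the description above, but the exact layout, the constructor, and the delegation method are just one possible way to wire it up; it also assumes the cufunc definitions from Issue #1 so that gammalogpdf actually broadcasts on the GPU:

    using CUDA, Distributions, StatsFuns

    # Lazy wrapper: holds the (already materialized) parameter arrays and remembers
    # which distribution they parametrize. Step (2.1) would additionally overload
    # Broadcast.broadcasted on CUDA.CuArrayStyle to construct this automatically.
    struct Vectorize{D<:UnivariateDistribution,Targs}
        args::Targs
    end
    Vectorize(::Type{D}, args) where {D<:UnivariateDistribution} = Vectorize{D,typeof(args)}(args)

    # Step (2.2): delegate logpdf to the underlying StatsFuns call; a macro would
    # generate one such method per distribution.
    Distributions.logpdf(v::Vectorize{Gamma}, x) = StatsFuns.gammalogpdf.(v.args..., x)

    # Usage:
    α, θ, x = CUDA.rand(Float32, 100), CUDA.rand(Float32, 100), CUDA.rand(Float32, 100)
    logpdf(Vectorize(Gamma, (α, θ)), x)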
