Add getadtype function to AbstractSampler interface #158

Closed
penelopeysm wants to merge 2 commits into master from py/getadtype

Conversation

penelopeysm
Member

penelopeysm commented Feb 10, 2025

Right now, Turing.jl contains some pretty ad hoc code to determine the underlying AD backend for a sampler:

https://github.com/TuringLang/Turing.jl/blob/ddd74b1cf4694b66698f225db549d2cfb5eb826c/src/mcmc/Inference.jl#L181-L189

In the process, it does a bunch of type piracy and this definition also makes it hard to move any of this behaviour to DynamicPPL.

In particular, it actually means that the AD backend (inside a LogDensityProblemsAD.ADgradient object) is only specified by code inside Turing.jl, even though DynamicPPL should have enough information to infer the AD type from the sampler it's given.

This PR proposes to move the AD type specification to the lower layer of AbstractMCMC so that this information is accessible across more layers.

It's not 100% obvious to me that gradient-based samplers necessarily need to declare their adtype (maybe depending on the sampler implementation, it might be declared elsewhere, e.g. in the model), but having getadtype be a property of the sampler does at least match what we currently do.
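
To make this concrete, here is a minimal sketch of what the proposed interface could look like (the default of returning nothing, and the MyGradientSampler type, are illustrative only, not the final API):

using AbstractMCMC: AbstractSampler
using ADTypes: AbstractADType, AutoForwardDiff

# Gradient-free samplers need not declare anything; the default is nothing.
getadtype(::AbstractSampler) = nothing

# A hypothetical gradient-based sampler that carries its AD backend.
struct MyGradientSampler{AD<:AbstractADType} <: AbstractSampler
    adtype::AD
end

getadtype(spl::MyGradientSampler) = spl.adtype

getadtype(MyGradientSampler(AutoForwardDiff()))  # returns AutoForwardDiff()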

Followups if this is merged

  • Implement getadtype(::DynamicPPL.Sampler)
  • Remove extra code from Turing
  • (lower priority) Implement in AdvancedHMC

@yebai
Member

yebai commented Feb 10, 2025

It feels slightly unusual to attach the AD type to samplers; we do this in Turing, but I'm not sure it's the best design.

@penelopeysm
Member Author

penelopeysm commented Feb 10, 2025

I somewhat agree - that's why I left it as optional (defaulting to nothing). I think the sampler is probably the 'least bad place' to put it - the other options are the model (weird) or explicitly in the evaluation context (might be better, but ugly if we force users to specify AD type via context).

Edit: I no longer think it's weird to put it in the model, see comment below

@yebai
Member

yebai commented Feb 10, 2025

The combination of logdensitymodel and sampler jointly determines the adtype. So, I'd suggest we use getadtype(LogDensityModel, Sampler) instead of getadtype(Sampler) for the interface.

@penelopeysm
Member Author

So, I'd suggest we use getadtype(LogDensityModel, Sampler) instead of getadtype(Sampler) for the interface

I support that :) One thing: we don't actually use LogDensityModel in Turing, so I think we want to keep it general with

getadtype(::AbstractModel, ::AbstractSampler)

I also wonder if we could use the same strategy as getparams and have a default implementation

getadtype(::AbstractModel, spl::AbstractSampler) = getadtype(spl)

and let users overload the method they prefer to use?
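
For example, something along these lines (a rough sketch; MyModel is just a hypothetical model type that carries its own AD backend):

using AbstractMCMC: AbstractModel, AbstractSampler
using ADTypes: AbstractADType

# Default fallback: defer to whatever the sampler declares.
getadtype(model::AbstractModel, spl::AbstractSampler) = getadtype(spl)

# A package could instead specialise on its model type and ignore the sampler.
struct MyModel{AD<:AbstractADType} <: AbstractModel
    adtype::AD
end

getadtype(model::MyModel, ::AbstractSampler) = model.adtype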

@mhauru
Member

mhauru commented Feb 12, 2025

I'd be happy with the

getadtype(::AbstractModel, spl::AbstractSampler) = getadtype(spl)

approach.

What's the convention for dependencies of interface packages? I feel a bit awkward about explicitly requiring that the return type is ADTypes.AbstractADType, but not having ADTypes as a dependency. If Julia allowed us to enforce the interface, we would definitely have it as a dependency. But I can also see it both ways: it's not great to add a dependency that the code doesn't use (only the docs do).

@penelopeysm penelopeysm marked this pull request as draft February 12, 2025 13:23
@penelopeysm
Member Author

penelopeysm commented Feb 12, 2025

The more I think about it, the more I'm convinced that the adtype should be in the model, not the sampler. But at the same time I don't know if it would be worth it to break everything :/

(1)

From a gradient-based sampler's point of view, it should be agnostic to how the gradient is calculated; it just cares that it can get a gradient. For example, if the model obeys the LogDensityProblems interface, the sampler only really cares that logdensity_and_gradient(model, params) works; it shouldn't care how that is implemented.

Indeed, one might find that for a particular model it is possible to implement logdensity_and_gradient analytically, in which case we can still use HMC just fine even without any AD backend. (The AdvancedHMC-specific sample interface allows for this, but not Turing.HMC.)
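
For instance, a toy target with a hand-written gradient satisfies the LogDensityProblems interface without any AD backend at all (AnalyticNormal is a made-up example type, not anything in our packages):

using LogDensityProblems

# Standard-normal log-density with an analytic gradient; no AD involved.
struct AnalyticNormal
    dim::Int
end

LogDensityProblems.dimension(p::AnalyticNormal) = p.dim
LogDensityProblems.logdensity(::AnalyticNormal, x) = -sum(abs2, x) / 2
LogDensityProblems.capabilities(::Type{AnalyticNormal}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity_and_gradient(p::AnalyticNormal, x) = (LogDensityProblems.logdensity(p, x), -x)

# A sampler consuming this interface never needs to know whether the gradient
# came from an AD backend or was written by hand.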

(2)

The best AD backend to use is almost entirely determined by the model (e.g. the number of parameters, or the functionality inside the model) and thus it makes sense that the choice of AD backend should be tied to the model.

(3)

In DynamicPPL we might often want to test AD correctness and/or performance with a model, along these lines:

@model f() = ...

ref_logp, ref_grad = logdensity_and_gradient_with(f(), AutoForwardDiff())
logp, grad = logdensity_and_gradient_with(f(), AutoMooncake())
@test isapprox(logp, ref_logp)
@test isapprox(grad, ref_grad)

We definitely don't want to have to construct a dummy sampler just to pass the adtype to the model.

(4)

Currently the way that Turing uses AD is this:

  • wrap the sampler in a SamplingContext
  • wrap the model and the context in a LogDensityFunction
  • when we need to calculate logdensity_and_gradient, get the context of the LogDensityFunction, get the sampler of the context, and get the AD type of the sampler
  • construct a LogDensityProblemsAD.ADgradient from that and finally calculate the gradient

Specifying the AD type as part of the model would let us simplify this chain of indirection and, in particular, would enable things like (3) above.

(5)

If we say that it's part of the model, then actually this whole PR does not need to exist. It can just be part of the interface of DynamicPPL.Model rather than AbstractModel, because as pointed out in point (1) above, not all models necessarily need to use AD to calculate the gradient.

Problems

The obvious, huge problem with this is that we would need to rework the whole sampling interface. E.g. instead of

sample(f(), NUTS(; adtype=AutoMooncake()))

we would need to write something like

sample(setadtype(f(), AutoMooncake()), NUTS())

In an ideal world I would actually prefer this. Of course, this would also be horribly breaking.

The other drawback is that we can't mix and match adtypes, so we can't do this:

sample(f(), Gibbs(x=NUTS(; adtype=AutoForwardDiff()), y=NUTS(; adtype=AutoMooncake())))

I don't see a convincing reason why someone would want to do this, though, so I don't feel bad about removing this capability.

The final implication is that the default of ForwardDiff would have to be pushed up to DynamicPPL rather than sitting in Turing. I'm also quite happy for this to be the case because that does reflect our current position (i.e. running AD on DynamicPPL.Model is 'supposed' to be correct with ForwardDiff, and it's the reference against which we compare other AD backends).

@penelopeysm
Member Author

penelopeysm commented Feb 12, 2025

I have something of a plan to gradually shift us from declaring adtype in samplers to declaring adtype in models, i.e. to go from

sample(f(), NUTS(; adtype=AutoMooncake()))

to

sample(setadtype(f(), AutoMooncake()), NUTS())

Let's discuss this at a Monday meeting. I think we should definitely try to get it right, because this is one of the areas where we have awkward entanglement between the model and the sampler.

@yebai
Member

yebai commented Feb 12, 2025

I agree that adtypes should be attached to models instead of inference algorithms. One idea that I had:

sample(f(), Gibbs(x=NUTS(), y=NUTS()), Autograd(x=AutoForwardDiff(), y=AutoMooncake()))

or, more concisely (this also gives automatic consistency guarantees for the parameter partitioning between Gibbs and Autograd):

sample(f(), Gibbs(x=(NUTS(), AutoMooncake()), y=(NUTS(), AutoMooncake())))

EDIT: we could lower

  • sample(f(), NUTS(; adtype=AutoMooncake())) to sample(f(), NUTS(), AutoMooncake())
  • sample(f(), Gibbs(x=NUTS(; adtype=AutoForwardDiff()), y=NUTS(; adtype=AutoMooncake()))) to sample(f(), Gibbs(x=NUTS(), y=NUTS()), Autograd(x=AutoForwardDiff(), y=AutoMooncake()))

internally for backwards compatibility.

@penelopeysm
Member Author

penelopeysm commented Feb 14, 2025

I no longer think this PR is the right way to do it, so closing. But I will add this general discussion to the agenda for Monday's meeting.

@penelopeysm penelopeysm deleted the py/getadtype branch February 14, 2025 19:13

codecov bot commented Feb 14, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 0.00%. Comparing base (3fffd3e) to head (ae297c3).

