Closed
Description
MWE (make sure to use the `breaking` branch, or any other branch that includes accumulators):
using DynamicPPL, Distributions
using ADTypes: AutoEnzyme
import Enzyme: Enzyme, Forward, Reverse, set_runtime_activity, Const
# Toy model with a single variable `x ~ Normal()`.
@model function f()
    x ~ Normal()
end

model = f()
# VarInfo built from `model`; `logp` unflattens parameter vectors into its layout.
vi = VarInfo(model)
ctx = DefaultContext()
# Flattened parameter vector at which the log joint is evaluated and differentiated.
params = Float64[0.5]
"""
    logp(x, model, vi, ctx)

Evaluate `model` with its parameters taken from the flat vector `x` and return
`getlogjoint` of the resulting varinfo.

`vi` provides the variable layout that `x` is unflattened into (via
`DynamicPPL.unflatten`); `ctx` is the context passed to `DynamicPPL.evaluate!!`.
"""
function logp(x, model, vi, ctx)
    # `evaluate!!` returns a collection whose last element is the updated varinfo.
    evaluated = DynamicPPL.evaluate!!(model, DynamicPPL.unflatten(vi, x), ctx)
    return getlogjoint(last(evaluated))
end
# Plain (primal) evaluation of the log joint as a sanity check.
logp(params, model, vi, ctx)

# Differentiate with respect to `params` only; the model, varinfo and context
# are wrapped in `Const`. Both Enzyme modes are wrapped with `set_runtime_activity`.
const_args = (Const(model), Const(vi), Const(ctx))
forward_mode = set_runtime_activity(Forward)
reverse_mode = set_runtime_activity(Reverse)
Enzyme.gradient(forward_mode, logp, params, const_args...)
Enzyme.gradient(reverse_mode, logp, params, const_args...)
Metadata
Metadata
Assignees
Labels
No labels