Skip to content

Remove initialstep, rework default_varinfo #938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

penelopeysm
Copy link
Member

Other changes in service of TuringLang/Turing.jl#2555 --- but I think these also make life easier.

The point was that AbstractMCMC.step(rng, ::Model, ::DynamicPPL.Sampler) would do a few things:

  1. It would create a default_varinfo for the model (which meant typed varinfo, unless you specifically overrode it)
  2. It would then initialise the parameters in the varinfo
  3. Finally, it would call DynamicPPL.initialstep(...) with that new varinfo

Because I'm now changing sample and step to work with LogDensityFunctions rather than Model, this means that the work to create the varinfo has to be done before calling step. See:

https://github.com/TuringLang/Turing.jl/blob/9054b0af3f6cd625c1e981b7b96c800ade63ca3f/src/mcmc/sample.jl#L186-L200

So, this PR:

  • Moves the parameter initialisation into default_varinfo, and adds an argument to link the varinfo (which Turing samplers can then make use of, by declaring that they need unconstrained space)
  • Since default_varinfo is called higher up, it no longer needs to be part of AbstractMCMC.step, and thus initialstep is no longer necessary.

Samplers that implemented DynamicPPL.initialstep(rng, model, spl, vi, ...) should just implement AbstractMCMC.step(rng, ldf, spl) in exactly the same way.

Note that all of these are merely interface changes -- they don't actually affect any code in DynamicPPL as none of the actual sampling is implemented here.

I'm not sure if this should be a breaking change. Technically, none of these functions were exported, so I have erred on the side of danger and marked it as a patch.

Copy link
Contributor

github-actions bot commented May 25, 2025

Benchmark Report for Commit ef6522f

Computer Information

Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.7 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                730.2 |                35.9 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                429.9 |                46.1 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1250.8 |                27.9 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3408.4 |                24.2 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1480.3 |                30.4 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                966.9 |                 5.3 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5551.6 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1048.3 |                 8.7 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              63021.4 |                 3.8 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               9335.6 |                 9.6 |
|               Dynamic |        10 |    mooncake |             typed |   true |                132.4 |                14.4 |
|              Submodel |         1 |    mooncake |             typed |   true |                 14.4 |                 6.3 |
|                   LDA |        12 | reversediff |             typed |   true |                456.9 |                 6.1 |

@penelopeysm penelopeysm marked this pull request as draft May 25, 2025 19:04
@penelopeysm
Copy link
Member Author

This is going to be messy....

Copy link

codecov bot commented May 25, 2025

Codecov Report

Attention: Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Project coverage is 34.11%. Comparing base (a8a7026) to head (ef6522f).

Files with missing lines Patch % Lines
src/sampler.jl 0.00% 2 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (a8a7026) and HEAD (ef6522f). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (a8a7026) HEAD (ef6522f)
12 6
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #938       +/-   ##
===========================================
- Coverage   82.91%   34.11%   -48.80%     
===========================================
  Files          36       36               
  Lines        3962     3928       -34     
===========================================
- Hits         3285     1340     -1945     
- Misses        677     2588     +1911     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@coveralls
Copy link

coveralls commented May 25, 2025

Pull Request Test Coverage Report for Build 15240990436

Details

  • 0 of 2 (0.0%) changed or added relevant lines in 1 file are covered.
  • 36 unchanged lines in 6 files lost coverage.
  • Overall coverage decreased (-48.9%) to 34.158%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/sampler.jl 0 2 0.0%
Files with Coverage Reduction New Missed Lines %
ext/DynamicPPLForwardDiffExt.jl 1 63.64%
src/contexts.jl 2 26.18%
src/model.jl 5 39.17%
src/simple_varinfo.jl 5 34.68%
src/logdensityfunction.jl 8 63.83%
src/compiler.jl 15 63.11%
Totals Coverage Status
Change from base Build 15240387609: -48.9%
Covered Lines: 1340
Relevant Lines: 3923

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants