-
Notifications
You must be signed in to change notification settings - Fork 19
Addition of step_warmup
#117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
0987f5f
30c9f12
7faa73f
bd0bdc7
c620cca
572a286
6b842ee
ca03832
0441773
ea369ff
8e0ca53
6877978
b3b3148
ddc5254
ffbd32f
87480ff
76f2f23
c00d0c9
ef09c19
49b8406
f005746
9dccd8a
ff00e6e
7ce9f6b
3a217b2
de9bb2c
85d938f
0a667a4
91f5a10
7603171
ef68d04
25afc66
1886fa8
0ea293a
6e8f88e
44c55bb
3b4f6db
f9142a6
295fdc1
e6acb1f
2e9fa5c
366fceb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -91,7 +91,29 @@ function StatsBase.sample( | |||||||||||
end | ||||||||||||
|
||||||||||||
# Default implementations of regular and parallel sampling. | ||||||||||||
|
||||||||||||
""" | ||||||||||||
mcmcsample(rng, model, sampler, N_or_is_done; kwargs...) | ||||||||||||
|
||||||||||||
Default implementation of `sample` for a `model` and `sampler`. | ||||||||||||
|
||||||||||||
# Arguments | ||||||||||||
- `rng::Random.AbstractRNG`: the random number generator to use. | ||||||||||||
- `model::AbstractModel`: the model to sample from. | ||||||||||||
- `sampler::AbstractSampler`: the sampler to use. | ||||||||||||
- `N::Integer`: the number of samples to draw. | ||||||||||||
|
||||||||||||
# Keyword arguments | ||||||||||||
- `progress`: whether to display a progress bar. Defaults to `true`. | ||||||||||||
- `progressname`: the name of the progress bar. Defaults to `"Sampling"`. | ||||||||||||
- `callback`: a function that is called after each [`AbstractMCMC.step`](@ref). | ||||||||||||
Defaults to `nothing`. | ||||||||||||
- `num_warmup`: number of warmup samples to draw. Defaults to `0`. | ||||||||||||
- `discard_initial`: number of initial samples to discard. Defaults to `num_warmup`. | ||||||||||||
- `thinning`: number of samples to discard between samples. Defaults to `1`. | ||||||||||||
- `chain_type`: the type to pass to [`AbstractMCMC.bundle_samples`](@ref) at the | ||||||||||||
end of sampling to wrap up the resulting samples nicely. Defaults to `Any`. | ||||||||||||
- `kwargs...`: Additional keyword arguments to pass on to [`AbstractMCMC.step`](@ref). | ||||||||||||
""" | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think nobody will look up the docstring for the unexported There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aaah I was totally unaware! So I removed this, and then I've just added a section to |
||||||||||||
function mcmcsample( | ||||||||||||
rng::Random.AbstractRNG, | ||||||||||||
model::AbstractModel, | ||||||||||||
|
@@ -100,14 +122,21 @@ function mcmcsample( | |||||||||||
progress=PROGRESS[], | ||||||||||||
progressname="Sampling", | ||||||||||||
callback=nothing, | ||||||||||||
discard_initial=0, | ||||||||||||
num_warmup=0, | ||||||||||||
discard_initial=num_warmup, | ||||||||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
thinning=1, | ||||||||||||
chain_type::Type=Any, | ||||||||||||
kwargs..., | ||||||||||||
) | ||||||||||||
# Check the number of requested samples. | ||||||||||||
N > 0 || error("the number of samples must be ≥ 1") | ||||||||||||
Ntotal = thinning * (N - 1) + discard_initial + 1 | ||||||||||||
Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1 | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this correct? Shouldn't it just stay the same, possibly with some additional checks:
Suggested change
I thought we would do the following:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah sorry, yes this was left over from my initial implementation that treated
Agreed, but isn't this what my impl is currently doing? With the exception of this line above of course. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know, I stopped reviewing at this point and didn't check the rest 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Haha, aight. Well, the rest is supposed to implement exactly what you outlined 😅 I'll see if I can also add some tests. |
||||||||||||
|
||||||||||||
# Determine how many samples to drop from `num_warmup` and the | ||||||||||||
# main sampling process before we start saving samples. | ||||||||||||
discard_from_warmup = min(num_warmup, discard_initial) | ||||||||||||
keep_from_warmup = num_warmup - discard_from_warmup | ||||||||||||
discard_from_sample = max(discard_initial - discard_from_warmup, 0) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this seems to match what I wrote above. |
||||||||||||
|
||||||||||||
# Start the timer | ||||||||||||
start = time() | ||||||||||||
|
@@ -124,34 +153,76 @@ function mcmcsample( | |||||||||||
# Obtain the initial sample and state. | ||||||||||||
sample, state = step(rng, model, sampler; kwargs...) | ||||||||||||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
|
||||||||||||
# Discard initial samples. | ||||||||||||
for i in 1:discard_initial | ||||||||||||
# Update the progress bar. | ||||||||||||
if progress && i >= next_update | ||||||||||||
ProgressLogging.@logprogress i / Ntotal | ||||||||||||
next_update = i + threshold | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Warmup sampling. | ||||||||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
for _ in 1:discard_from_warmup | ||||||||||||
# Obtain the next sample and state. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these should be accounted for in the progress logger as well (as done currently). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be good now 👍 |
||||||||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||||||||
sample, state = step_warmup(rng, model, sampler, state; kwargs...) | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Run callback. | ||||||||||||
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) | ||||||||||||
i = 1 | ||||||||||||
if keep_from_warmup > 0 | ||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
|
||||||||||||
# Step through remainder of warmup iterations and save. | ||||||||||||
i += 1 | ||||||||||||
for _ in (discard_from_warmup + 1):num_warmup | ||||||||||||
# Update the progress bar. | ||||||||||||
if progress && i >= next_update | ||||||||||||
ProgressLogging.@logprogress i / Ntotal | ||||||||||||
next_update = i + threshold | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Obtain the next sample and state. | ||||||||||||
sample, state = step_warmup(rng, model, sampler, state; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) | ||||||||||||
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) | ||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
i += 1 | ||||||||||||
end | ||||||||||||
else | ||||||||||||
# Discard additional initial samples, if needed. | ||||||||||||
for _ in 1:discard_from_sample | ||||||||||||
# Update the progress bar. | ||||||||||||
if progress && i >= next_update | ||||||||||||
ProgressLogging.@logprogress i / Ntotal | ||||||||||||
next_update = i + threshold | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Obtain the next sample and state. | ||||||||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
|
||||||||||||
# Increment iteration number. | ||||||||||||
i += 1 | ||||||||||||
end | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be easier to do something along the lines of # Discard initial samples, if needed.
for _ in 1:discard_initial
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
end
# Obtain the next sample and state.
sample, state = if i <= num_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
end
...
# Increment iteration number.
i += 1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well... Yes 🤦 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah actually, I now remember the reason for why I didn't do this: won't this make it type-unstable while the current implementation won't (if indeed Whether we care is another thing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My worry was that too many if-staments involving runtime information could confuse the type-inference, even in the case where the return-types of each branch is the same, but from your comments I'm assuming this was an irrational fear:) Also, I just read a bit more https://juliahub.com/blog/2016/04/inference-convergence/ and my fears are indeed irrational 👍 I guess as long as each of the if-statement returns the same two types, i.e. the |
||||||||||||
|
||||||||||||
# Update the progress bar. | ||||||||||||
itotal = 1 + discard_initial | ||||||||||||
itotal = i | ||||||||||||
if progress && itotal >= next_update | ||||||||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||||||||
next_update = itotal + threshold | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Step through the sampler. | ||||||||||||
for i in 2:N | ||||||||||||
while i ≤ N | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any particular reason to switch to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah no! I'll revert it to for-loop 👍 |
||||||||||||
# Discard thinned samples. | ||||||||||||
for _ in 1:(thinning - 1) | ||||||||||||
# Obtain the next sample and state. | ||||||||||||
|
@@ -174,6 +245,9 @@ function mcmcsample( | |||||||||||
# Save the sample. | ||||||||||||
samples = save!!(samples, sample, i, model, sampler, N; kwargs...) | ||||||||||||
|
||||||||||||
# Increment iteration counter. | ||||||||||||
i += 1 | ||||||||||||
|
||||||||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
# Update the progress bar. | ||||||||||||
if progress && (itotal += 1) >= next_update | ||||||||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||||||||
|
@@ -209,10 +283,16 @@ function mcmcsample( | |||||||||||
progress=PROGRESS[], | ||||||||||||
progressname="Convergence sampling", | ||||||||||||
callback=nothing, | ||||||||||||
discard_initial=0, | ||||||||||||
num_warmup=0, | ||||||||||||
discard_initial=num_warmup, | ||||||||||||
thinning=1, | ||||||||||||
kwargs..., | ||||||||||||
) | ||||||||||||
# Determine how many samples to drop from `num_warmup` and the | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add the same/similar error checks as above? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done 👍 |
||||||||||||
# main sampling process before we start saving samples. | ||||||||||||
discard_from_warmup = min(num_warmup, discard_initial) | ||||||||||||
keep_from_warmup = num_warmup - discard_from_warmup | ||||||||||||
discard_from_sample = max(discard_initial - discard_from_warmup, 0) | ||||||||||||
|
||||||||||||
# Start the timer | ||||||||||||
start = time() | ||||||||||||
|
@@ -222,21 +302,54 @@ function mcmcsample( | |||||||||||
# Obtain the initial sample and state. | ||||||||||||
sample, state = step(rng, model, sampler; kwargs...) | ||||||||||||
|
||||||||||||
# Discard initial samples. | ||||||||||||
for _ in 1:discard_initial | ||||||||||||
# Warmup sampling. | ||||||||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
for _ in 1:discard_from_warmup | ||||||||||||
# Obtain the next sample and state. | ||||||||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||||||||
sample, state = step_warmup(rng, model, sampler, state; kwargs...) | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Run callback. | ||||||||||||
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) | ||||||||||||
i = 1 | ||||||||||||
if keep_from_warmup > 0 | ||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) | ||||||||||||
samples = save!!(samples, sample, 1, model, sampler; kwargs...) | ||||||||||||
# Step through remainder of warmup iterations and save. | ||||||||||||
i += 1 | ||||||||||||
for _ in (discard_from_warmup + 1):num_warmup | ||||||||||||
# Obtain the next sample and state. | ||||||||||||
sample, state = step_warmup(rng, model, sampler, state; kwargs...) | ||||||||||||
|
||||||||||||
# Step through the sampler until stopping. | ||||||||||||
i = 2 | ||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
i += 1 | ||||||||||||
end | ||||||||||||
else | ||||||||||||
# Discard additional initial samples, if needed. | ||||||||||||
for _ in 1:discard_from_sample | ||||||||||||
# Obtain the next sample and state. | ||||||||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||||||||
end | ||||||||||||
|
||||||||||||
# Run callback. | ||||||||||||
callback === nothing || | ||||||||||||
callback(rng, model, sampler, sample, state, i; kwargs...) | ||||||||||||
|
||||||||||||
# Save the sample. | ||||||||||||
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...) | ||||||||||||
samples = save!!(samples, sample, i, model, sampler; kwargs...) | ||||||||||||
|
||||||||||||
# Increment iteration number. | ||||||||||||
i += 1 | ||||||||||||
end | ||||||||||||
|
||||||||||||
while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) | ||||||||||||
# Discard thinned samples. | ||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.