Skip to content

Commit 3ffb003

Browse files
authored
Miscellaneous style and docs improvements (#622)
* Fix docstring typo * Add mention of context in the docstring of Model * Add a docstring for DynamicTransformationContext * Tiny style improvements
1 parent 122ecd1 commit 3ffb003

File tree

3 files changed

+30
-17
lines changed

3 files changed

+30
-17
lines changed

src/contexts.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte
188188
"""
189189
struct DefaultContext <: AbstractContext end
190190
191-
The `DefaultContext` is used by default to compute log the joint probability of the data
191+
The `DefaultContext` is used by default to compute the log joint probability of the data
192192
and parameters when running the model.
193193
"""
194194
struct DefaultContext <: AbstractContext end
@@ -199,7 +199,7 @@ NodeTrait(context::DefaultContext) = IsLeaf()
199199
vars::Tvars
200200
end
201201
202-
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
202+
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
203203
running the model.
204204
"""
205205
struct PriorContext{Tvars} <: AbstractContext
@@ -213,8 +213,8 @@ NodeTrait(context::PriorContext) = IsLeaf()
213213
vars::Tvars
214214
end
215215
216-
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
217-
running the model. `vars` can be used to evaluate the log likelihood for specific values
216+
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
217+
running the model. `vars` can be used to evaluate the log likelihood for specific values
218218
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
219219
"""
220220
struct LikelihoodContext{Tvars} <: AbstractContext
@@ -229,10 +229,10 @@ NodeTrait(context::LikelihoodContext) = IsLeaf()
229229
loglike_scalar::T
230230
end
231231
232-
The `MiniBatchContext` enables the computation of
233-
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
234-
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
235-
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
232+
The `MiniBatchContext` enables the computation of
233+
`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the
234+
`loglike_scalar` field, typically equal to `the number of data points / batch size`.
235+
This is useful in batch-based stochastic gradient descent algorithms to be optimizing
236236
`log(prior) + log(likelihood of all the data points)` in the expectation.
237237
"""
238238
struct MiniBatchContext{Tctx,T} <: AbstractContext

src/model.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
"""
2-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
2+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext}
33
f::F
44
args::NamedTuple{argnames,Targs}
55
defaults::NamedTuple{defaultnames,Tdefaults}
6+
context::Ctx=DefaultContext()
67
end
78
89
A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
9-
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing
10-
arguments `missings`.
10+
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
11+
arguments `missings`, and evaluation context of type `Ctx`.
1112
1213
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
14+
`context` is by default `DefaultContext()`.
1315
1416
An argument with a type of `Missing` will be in `missings` by default. However, in
1517
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
@@ -1077,7 +1079,7 @@ end
10771079
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
10781080
10791081
# Examples
1080-
1082+
10811083
```jldoctest
10821084
julia> using MCMCChains, Distributions
10831085
@@ -1093,7 +1095,7 @@ julia> # construct a chain of samples using MCMCChains
10931095
chain = Chains(rand(10, 2, 3), [:s, :m]);
10941096
10951097
julia> logjoint(demo_model([1., 2.]), chain);
1096-
```
1098+
```
10971099
"""
10981100
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
10991101
var_info = VarInfo(model) # extract variables info from the model
@@ -1124,7 +1126,7 @@ end
11241126
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
11251127
11261128
# Examples
1127-
1129+
11281130
```jldoctest
11291131
julia> using MCMCChains, Distributions
11301132
@@ -1140,7 +1142,7 @@ julia> # construct a chain of samples using MCMCChains
11401142
chain = Chains(rand(10, 2, 3), [:s, :m]);
11411143
11421144
julia> logprior(demo_model([1., 2.]), chain);
1143-
```
1145+
```
11441146
"""
11451147
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
11461148
var_info = VarInfo(model) # extract variables info from the model
@@ -1171,7 +1173,7 @@ end
11711173
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
11721174
11731175
# Examples
1174-
1176+
11751177
```jldoctest
11761178
julia> using MCMCChains, Distributions
11771179
@@ -1187,7 +1189,7 @@ julia> # construct a chain of samples using MCMCChains
11871189
chain = Chains(rand(10, 2, 3), [:s, :m]);
11881190
11891191
julia> loglikelihood(demo_model([1., 2.]), chain);
1190-
```
1192+
```
11911193
"""
11921194
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
11931195
var_info = VarInfo(model) # extract variables info from the model

src/transforming.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
"""
2+
struct DynamicTransformationContext{isinverse} <: AbstractContext
3+
4+
When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to
5+
constrained space if `isinverse` or unconstrained if `!isinverse`.
6+
7+
Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the
8+
`DynamicTransformationContext` methods with more efficient implementations.
9+
`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
10+
how to do the transformation, used by e.g. `SimpleVarInfo`.
11+
"""
112
struct DynamicTransformationContext{isinverse} <: AbstractContext end
213
NodeTrait(::DynamicTransformationContext) = IsLeaf()
314

0 commit comments

Comments
 (0)