Skip to content

Commit 7297b36

Browse files
authored
MInor docstring fix (#612)
1 parent 73a6c0c commit 7297b36

File tree

5 files changed

+71
-55
lines changed

5 files changed

+71
-55
lines changed

blackjax/adaptation/mclmc_adaptation.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L.
15-
16-
"""
14+
"""Algorithms to adapt the MCLMC kernel parameters, namely step size and L."""
1715

1816
from typing import NamedTuple
1917

2018
import jax
2119
import jax.numpy as jnp
2220
from jax.flatten_util import ravel_pytree
2321

24-
from blackjax.diagnostics import effective_sample_size # type: ignore
22+
from blackjax.diagnostics import effective_sample_size
2523
from blackjax.util import pytree_size
2624

2725

2826
class MCLMCAdaptationState(NamedTuple):
2927
"""Represents the tunable parameters for MCLMC adaptation.
3028
31-
Attributes:
32-
L (float): The momentum decoherent rate for the MCLMC algorithm.
33-
step_size (float): The step size used for the MCLMC algorithm.
29+
L
30+
The momentum decoherent rate for the MCLMC algorithm.
31+
step_size
32+
The step size used for the MCLMC algorithm.
3433
"""
3534

3635
L: float
@@ -52,25 +51,39 @@ def mclmc_find_L_and_step_size(
5251
"""
5352
Finds the optimal value of the parameters for the MCLMC algorithm.
5453
55-
Args:
56-
mclmc_kernel (callable): The kernel function used for the MCMC algorithm.
57-
num_steps (int): The number of MCMC steps that will subsequently be run, after tuning.
58-
state (MCMCState): The initial state of the MCMC algorithm.
59-
rng_key (jax.random.PRNGKey): The random number generator key.
60-
frac_tune1 (float): The fraction of tuning for the first step of the adaptation.
61-
frac_tune2 (float): The fraction of tuning for the second step of the adaptation.
62-
frac_tune3 (float): The fraction of tuning for the third step of the adaptation.
63-
desired_energy_var (float): The desired energy variance for the MCMC algorithm.
64-
trust_in_estimate (float): The trust in the estimate of optimal stepsize.
65-
num_effective_samples (int): The number of effective samples for the MCMC algorithm.
66-
67-
Returns:
68-
tuple: A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
69-
70-
Raises:
71-
None
72-
73-
Examples:
54+
Parameters
55+
----------
56+
mclmc_kernel
57+
The kernel function used for the MCMC algorithm.
58+
num_steps
59+
The number of MCMC steps that will subsequently be run, after tuning.
60+
state
61+
The initial state of the MCMC algorithm.
62+
rng_key
63+
The random number generator key.
64+
frac_tune1
65+
The fraction of tuning for the first step of the adaptation.
66+
frac_tune2
67+
The fraction of tuning for the second step of the adaptation.
68+
frac_tune3
69+
The fraction of tuning for the third step of the adaptation.
70+
desired_energy_va
71+
The desired energy variance for the MCMC algorithm.
72+
trust_in_estimate
73+
The trust in the estimate of optimal stepsize.
74+
num_effective_samples
75+
The number of effective samples for the MCMC algorithm.
76+
77+
Returns
78+
-------
79+
A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
80+
81+
82+
Examples
83+
-------
84+
85+
.. code::
86+
7487
# Define the kernel function
7588
def kernel(x):
7689
return x ** 2
@@ -265,7 +278,8 @@ def adaptation_L(state, params, num_steps, key):
265278

266279

267280
def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
268-
"""if there are nans, let's reduce the stepsize, and not update the state. The function returns the old state in this case."""
281+
"""if there are nans, let's reduce the stepsize, and not update the state. The
282+
function returns the old state in this case."""
269283

270284
reduced_step_size = 0.8
271285
p, unravel_fn = ravel_pytree(next_state.position)

blackjax/mcmc/mclmc.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ class MCLMCInfo(NamedTuple):
3131
"""
3232
Additional information on the MCLMC transition.
3333
34-
Attributes
35-
----------
36-
transformed_position :
34+
transformed_position
3735
The value of the samples after a transformation. This is typically a projection onto a lower dimensional subspace.
38-
logdensity :
36+
logdensity
3937
The log-density of the distribution at the current step of the MCLMC chain.
40-
energy_change :
38+
kinetic_change
39+
The difference in kinetic energy between the current and previous step.
40+
energy_change
4141
The difference in energy between the current and previous step.
4242
"""
4343

@@ -68,9 +68,9 @@ def build_kernel(logdensity_fn, integrator, transform):
6868
transform
6969
Value of the difference in energy above which we consider that the transition is divergent.
7070
L
71-
the momentum decoherence rate
71+
the momentum decoherence rate.
7272
step_size
73-
step size of the integrator
73+
step size of the integrator.
7474
7575
Returns
7676
-------
@@ -136,8 +136,8 @@ class mclmc:
136136
137137
.. code::
138138
139-
step = jax.jit(mclmc.step)
140-
new_state, info = step(rng_key, state)
139+
step = jax.jit(mclmc.step)
140+
new_state, info = step(rng_key, state)
141141
142142
Parameters
143143
----------
@@ -146,11 +146,11 @@ class mclmc:
146146
transform
147147
A function to perform on the samples drawn from the target distribution
148148
L
149-
the momentum decoherence rate
149+
the momentum decoherence rate
150150
step_size
151-
step size of the integrator
151+
step size of the integrator
152152
integrator
153-
an integrator. We recommend using the default here.
153+
an integrator. We recommend using the default here.
154154
155155
Returns
156156
-------
@@ -185,13 +185,13 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L):
185185
186186
Parameters
187187
----------
188-
rng_key:
188+
rng_key
189189
The pseudo-random number generator key used to generate random numbers.
190-
momentum:
190+
momentum
191191
PyTree that the structure the output should to match.
192-
step_size:
192+
step_size
193193
Step size
194-
L:
194+
L
195195
controls rate of momentum change
196196
197197
Returns

blackjax/smc/base.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,18 @@ class SMCState(NamedTuple):
2424
2525
Particles must be a ArrayTree, each leave represents a variable from the posterior,
2626
being an array of size `(n_particles, ...)`.
27+
2728
Examples (three particles):
28-
- Single univariate posterior:
29-
[ Array([[1.], [1.2], [3.4]]) ]
30-
- Single bivariate posterior:
31-
[Array([[1,2], [3,4], [5,6]])]
32-
- Two variables, each univariate:
33-
[ Array([[1.], [1.2], [3.4]]),
34-
Array([[50.], [51], [55]]) ]
35-
- Two variables, first one bivariate, second one 4-variate:
36-
[ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]),
37-
Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]])]
29+
- Single univariate posterior:
30+
[ Array([[1.], [1.2], [3.4]]) ]
31+
- Single bivariate posterior:
32+
[ Array([[1,2], [3,4], [5,6]]) ]
33+
- Two variables, each univariate:
34+
[ Array([[1.], [1.2], [3.4]]),
35+
Array([[50.], [51], [55]]) ]
36+
- Two variables, first one bivariate, second one 4-variate:
37+
[ Array([[1., 2.], [1.2, 0.5], [3.4, 50]]),
38+
Array([[50., 51., 52., 51], [51., 52., 52. ,54.], [55., 60, 60, 70]]) ]
3839
"""
3940

4041
particles: ArrayTree

blackjax/smc/tuning/from_particles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def mass_matrix_from_particles(particles) -> Array:
3333
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
3434
Computing a mass matrix to be used in HMC from particles.
3535
Given the particles covariance matrix, set all non-diagonal elements as zero,
36-
take the inverse, and keep the diagonal.
36+
take the inverse, and keep the diagonal.
37+
3738
Returns
3839
-------
3940
A mass Matrix

requirements-doc.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jax>=0.4.16
88
jaxlib>=0.4.16
99
jaxopt
1010
jupytext
11-
myst_nb>=1.0.0rc0
11+
myst_nb>=1.0.0
1212
numba
1313
numpyro
1414
optax

0 commit comments

Comments
 (0)