Skip to content

Kesten fixes part 2 #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions lectures/kesten_processes.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ time `T`) corresponds to firm size distribution in (approximate) equilibrium.

```{code-cell} ipython3
def generate_cross_section(
firm, M=1_000_000, T=500, s_init=1.0, seed=123
firm, M=500_000, T=500, s_init=1.0, seed=123
):

μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm
Expand All @@ -231,14 +231,16 @@ def generate_cross_section(
return s
```

Let's try running the code and generating a cross-section.

```{code-cell} ipython3
firm = Firm()
tic()
data = generate_cross_section(firm).block_until_ready()
toc()
```

Running the above function again so we can see the speed with and without compile time.
We run the function again so we can see the speed without compile time.

```{code-cell} ipython3
tic()
Expand All @@ -264,13 +266,14 @@ The plot produces a straight line, consistent with a Pareto tail.

#### Alternative implementation with `lax.fori_loop`

We did not JIT-compile the `for` loop above because
acceleration of outer loops makes relatively little difference terms of
compute time.
Although we JIT-compiled some of the code above,
we did not JIT-compile the `for` loop.

Let's try squeezing out a bit more speed
by

However, to maximize performance, let's try squeezing out a bit more speed
by replacing the `for` loop with
[`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html).
* replacing the `for` loop with [`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and
* JIT-compiling the whole function.

Here a the `lax.fori_loop` version:

Expand Down Expand Up @@ -348,7 +351,7 @@ Try writing an alternative version of `generate_cross_section_lax()` where the e

Does it improve the runtime?

What are the pros and cons of this approach.
What are the pros and cons of this approach?

```{exercise-end}
```
Expand Down Expand Up @@ -401,11 +404,14 @@ data = generate_cross_section_lax(firm).block_until_ready()
toc()
```

This method might be faster in some cases but in general the
relative speed will depend on the size of the cross-section and the length of
This method might or might not be faster.

In general, the relative speed will depend on the size of the cross-section and the length of
the simulation paths.

Also, this method is far more memory intensive.
However, this method is far more memory intensive.

It will fail when $T$ and $M$ become sufficiently large.

```{solution-end}
```
```
Loading