Skip to content

Kesten edits #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Upgrade CUDANN
shell: bash -l {0}
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get -y install cudnn-cuda-12
- name: Setup Anaconda
uses: conda-incubator/setup-miniconda@v3
with:
Expand Down
61 changes: 37 additions & 24 deletions lectures/kesten_processes.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ In addition to JAX and Anaconda, this lecture will need the following libraries:
This lecture describes Kesten processes, which are an important class of
stochastic processes, and an application of firm dynamics.

The lecture draws on [an earlier QuantEcon
lecture](https://python.quantecon.org/kesten_processes.html), which uses Numba
to accelerate the computations.
The lecture draws on [an earlier QuantEcon lecture](https://python.quantecon.org/kesten_processes.html),
which uses Numba to accelerate the computations.

In that earlier lecture you can find a more detailed discussion of the concepts involved.

Expand Down Expand Up @@ -137,10 +136,8 @@ We now study the implications of this specification.

#### Heavy tails

If the conditions of the [Kesten--Goldie
Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)
are satisfied, then {eq}`firm_dynam` implies that the firm size distribution
will have Pareto tails.
If the conditions of the [Kesten--Goldie Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem)
are satisfied, then {eq}`firm_dynam` implies that the firm size distribution will have Pareto tails.

This matches empirical findings across many data sets.

Expand Down Expand Up @@ -190,12 +187,11 @@ class Firm(NamedTuple):
μ_e: float = 0.0
σ_e: float = 0.5
s_bar: float = 1.0

#
# Here's code to update a cross-section of firms according to the dynamics in
# [](firm_dynam_ee).
```

Here's code to update a cross-section of firms according to the dynamics in
[](firm_dynam_ee).

```{code-cell} ipython3
@jax.jit
def update_cross_section(s, a, b, e, firm):
Expand Down Expand Up @@ -250,7 +246,6 @@ data = generate_cross_section(firm).block_until_ready()
toc()
```


Let's produce the rank-size plot and check the distribution:

```{code-cell} ipython3
Expand All @@ -271,7 +266,7 @@ The plot produces a straight line, consistent with a Pareto tail.

We did not JIT-compile the `for` loop above because
acceleration of outer loops makes relatively little difference terms of
compute time.
compute time.

However, to maximize performance, let's try squeezing out a bit more speed
by replacing the `for` loop with
Expand Down Expand Up @@ -311,10 +306,10 @@ def generate_cross_section_lax(
0, T, update_cross_section, initial_state
)
return final_s

# Let's see if we got any speed gain
```

Let's see if we get any speed gain

```{code-cell} ipython3
tic()
data = generate_cross_section_lax(firm).block_until_ready()
Expand All @@ -339,14 +334,27 @@ ax.set_ylabel("log size")

plt.show()

#
# If the time horizon is not too large, we can also try generating all shocks at
# once.
#
# Note, however, that this approach consumes more memory, as we need to have to
# store large matrices of random draws
#
# Hence the code below will fail due to out-of-memory errors when `T` and `M` are large.
```

## Exercises
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jstac I might update this to be in an exercise and solution directive

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah -- just read the comments thread. Will do.


```{exercise-start}
:label: kp_ex1
```

Try writing an alternative version of `generate_cross_section_lax()` where the entire sequence of random draws is generated at once, so that all of `a`, `b`, and `e` are of shape `(T, M)`.

(The `update_cross_section()` function should not generate any random numbers.)

Does it improve the runtime?

What are the pros and cons of this approach.

```{exercise-end}
```

```{solution-start} kp_ex1
:class: dropdown
```

```{code-cell} ipython3
Expand Down Expand Up @@ -393,6 +401,11 @@ data = generate_cross_section_lax(firm).block_until_ready()
toc()
```

This second method might be slightly faster in some cases but in general the
This method might be faster in some cases but in general the
relative speed will depend on the size of the cross-section and the length of
the simulation paths.

Also, this method is far more memory intensive.

```{solution-end}
```
Loading