
Commit ef43c79

Use vmap throughout opt savings 2 (#155)
* misc * misc * misc
1 parent ce4ef38 commit ef43c79

2 files changed: +112 -79 lines changed


lectures/opt_savings_1.md

Lines changed: 5 additions & 1 deletion
@@ -120,11 +120,15 @@ def create_consumption_model(R=1.01, # Gross interest rate
     A function that takes in parameters and returns parameters and grids
     for the optimal savings problem.
     """
+    # Build grids and transition probabilities
     w_grid = np.linspace(w_min, w_max, w_size)
     mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
     y_grid, Q = np.exp(mc.state_values), mc.P
+    # Pack and return
+    params = β, R, γ
     sizes = w_size, y_size
-    return (β, R, γ), sizes, (w_grid, y_grid, Q)
+    arrays = w_grid, y_grid, Q
+    return params, sizes, arrays
 ```
 
 (The function returns sizes of arrays because we use them later to help
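
After this change the constructor returns three separate tuples: `params`, `sizes`, and `arrays`. A minimal usage sketch, assuming the full `create_consumption_model` definition shown above is in scope:

```python
# Usage sketch -- assumes create_consumption_model (NumPy version above) is defined.
params, sizes, arrays = create_consumption_model()
β, R, γ = params            # discount factor, gross interest rate, utility curvature
w_size, y_size = sizes      # grid sizes, returned separately for later use
w_grid, y_grid, Q = arrays  # wealth grid, income grid, Markov transition matrix
```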

lectures/opt_savings_2.md

Lines changed: 107 additions & 78 deletions
@@ -100,121 +100,155 @@ def create_consumption_model(R=1.01, # Gross interest rate
     A function that takes in parameters and returns parameters and grids
     for the optimal savings problem.
     """
+    # Build grids and transition probabilities
     w_grid = jnp.linspace(w_min, w_max, w_size)
     mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
-    y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P)
+    y_grid, Q = jnp.exp(mc.state_values), mc.P
+    # Pack and return
+    params = β, R, γ
     sizes = w_size, y_size
-    return (β, R, γ), sizes, (w_grid, y_grid, Q)
+    arrays = w_grid, y_grid, jnp.array(Q)
+    return params, sizes, arrays
 ```
 
 Here's the right hand side of the Bellman equation:
 
 ```{code-cell} ipython3
-def B(v, params, sizes, arrays):
+def _B(v, params, arrays, i, j, ip):
     """
-    A vectorized version of the right-hand side of the Bellman equation
-    (before maximization), which is a 3D array representing
+    The right-hand side of the Bellman equation before maximization, which takes
+    the form
 
         B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)
 
-    for all (w, y, w′).
+    The indices are (i, j, ip) -> (w, y, w′).
     """
-
-    # Unpack
     β, R, γ = params
-    w_size, y_size = sizes
     w_grid, y_grid, Q = arrays
-
-    # Compute current rewards r(w, y, wp) as array r[i, j, ip]
-    w = jnp.reshape(w_grid, (w_size, 1, 1))    # w[i]   ->  w[i, j, ip]
-    y = jnp.reshape(y_grid, (1, y_size, 1))    # z[j]   ->  z[i, j, ip]
-    wp = jnp.reshape(w_grid, (1, 1, w_size))   # wp[ip] -> wp[i, j, ip]
+    w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
     c = R * w + y - wp
+    EV = jnp.sum(v[ip, :] * Q[j, :])
+    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```
 
-    # Calculate continuation rewards at all combinations of (w, y, wp)
-    v = jnp.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
-    Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
-    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp
+Now we successively apply `vmap` to vectorize $B$ by simulating nested loops.
 
-    # Compute the right-hand side of the Bellman equation
-    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+```{code-cell} ipython3
+B_1 = jax.vmap(_B, in_axes=(None, None, None, None, None, 0))
+B_2 = jax.vmap(B_1, in_axes=(None, None, None, None, 0, None))
+B_vmap = jax.vmap(B_2, in_axes=(None, None, None, 0, None, None))
+```
+
+Here's a fully vectorized version of $B$.
+
+```{code-cell} ipython3
+def B(v, params, sizes, arrays):
+    w_size, y_size = sizes
+    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+    return B_vmap(v, params, arrays, w_indices, y_indices, w_indices)
+
+B = jax.jit(B, static_argnums=(2,))
 ```
 
 ## Operators
 
+Here's the Bellman operator $T$
+
+```{code-cell} ipython3
+def T(v, params, sizes, arrays):
+    "The Bellman operator."
+    return jnp.max(B(v, params, sizes, arrays), axis=-1)
+
+T = jax.jit(T, static_argnums=(2,))
+```
+
+The next function computes a $v$-greedy policy given $v$
+
+```{code-cell} ipython3
+def get_greedy(v, params, sizes, arrays):
+    "Computes a v-greedy policy, returned as a set of indices."
+    return jnp.argmax(B(v, params, sizes, arrays), axis=-1)
+
+get_greedy = jax.jit(get_greedy, static_argnums=(2,))
+```
+
 We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
 which is defined as the vector
 
 $$
-    r_\sigma(w, y) := r(w, y, \sigma(w, y))
+r_\sigma(w, y) := r(w, y, \sigma(w, y))
 $$
 
 ```{code-cell} ipython3
-def compute_r_σ(σ, params, sizes, arrays):
+def _compute_r_σ(σ, params, arrays, i, j):
     """
-    Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current
-    rewards given policy σ.
+    With indices (i, j) -> (w, y) and wp = σ[i, j], compute
+
+        r_σ[i, j] = u(Rw + y - wp)
+
+    which gives current rewards under policy σ.
     """
 
     # Unpack model
     β, R, γ = params
-    w_size, y_size = sizes
     w_grid, y_grid, Q = arrays
-
     # Compute r_σ[i, j]
-    w = jnp.reshape(w_grid, (w_size, 1))
-    y = jnp.reshape(y_grid, (1, y_size))
-    wp = w_grid[σ]
+    w, y, wp = w_grid[i], y_grid[j], w_grid[σ[i, j]]
     c = R * w + y - wp
     r_σ = c**(1-γ)/(1-γ)
 
     return r_σ
 ```
 
-Now we define the policy operator $T_\sigma$
+Now we successively apply `vmap` to simulate nested loops.
 
 ```{code-cell} ipython3
-def T_σ(v, σ, params, sizes, arrays):
-    "The σ-policy operator."
+r_1 = jax.vmap(_compute_r_σ, in_axes=(None, None, None, None, 0))
+r_σ_vmap = jax.vmap(r_1, in_axes=(None, None, None, 0, None))
+```
 
-    # Unpack model
-    β, R, γ = params
+Here's a fully vectorized version of $r_\sigma$.
+
+```{code-cell} ipython3
+def compute_r_σ(σ, params, sizes, arrays):
     w_size, y_size = sizes
-    w_grid, y_grid, Q = arrays
+    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+    return r_σ_vmap(σ, params, arrays, w_indices, y_indices)
 
-    r_σ = compute_r_σ(σ, params, sizes, arrays)
+compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
+```
+
+Now we define the policy operator $T_\sigma$ going through similar steps
 
-    # Compute the array v[σ[i, j], jp]
-    yp_idx = jnp.arange(y_size)
-    yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
-    σ = jnp.reshape(σ, (w_size, y_size, 1))
-    V = v[σ, yp_idx]
+```{code-cell} ipython3
+def _T_σ(v, σ, params, arrays, i, j):
+    "The σ-policy operator."
 
-    # Convert Q[j, jp] to Q[i, j, jp]
-    Q = jnp.reshape(Q, (1, y_size, y_size))
+    # Unpack model
+    β, R, γ = params
+    w_grid, y_grid, Q = arrays
 
+    r_σ = _compute_r_σ(σ, params, arrays, i, j)
     # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
-    EV = jnp.sum(V * Q, axis=2)
+    EV = jnp.sum(v[σ[i, j], :] * Q[j, :])
 
     return r_σ + β * EV
-```
 
-and the Bellman operator $T$
 
-```{code-cell} ipython3
-def T(v, params, sizes, arrays):
-    "The Bellman operator."
-    return jnp.max(B(v, params, sizes, arrays), axis=2)
-```
+T_1 = jax.vmap(_T_σ, in_axes=(None, None, None, None, None, 0))
+T_σ_vmap = jax.vmap(T_1, in_axes=(None, None, None, None, 0, None))
 
-The next function computes a $v$-greedy policy given $v$
+def T_σ(v, σ, params, sizes, arrays):
+    w_size, y_size = sizes
+    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+    return T_σ_vmap(v, σ, params, arrays, w_indices, y_indices)
 
-```{code-cell} ipython3
-def get_greedy(v, params, sizes, arrays):
-    "Computes a v-greedy policy, returned as a set of indices."
-    return jnp.argmax(B(v, params, sizes, arrays), axis=2)
+T_σ = jax.jit(T_σ, static_argnums=(3,))
 ```
 
+
 The function below computes the value $v_\sigma$ of following policy $\sigma$.
 
 This lifetime value is a function $v_\sigma$ that satisfies
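
The hunk above replaces the reshape-and-broadcast version of `B` with a scalar kernel `_B` plus nested `jax.vmap` calls over index arrays. A minimal, self-contained sketch of that pattern with a toy kernel (the function and data here are illustrative, not from the lecture):

```python
import jax
import jax.numpy as jnp

def f(data, i, j):
    # Scalar kernel: evaluated at one (i, j) index pair
    return data[i] * 10 + j

f_j  = jax.vmap(f,   in_axes=(None, None, 0))   # vectorize over j (last argument)
f_ij = jax.vmap(f_j, in_axes=(None, 0, None))   # then vectorize over i

data = jnp.array([1.0, 2.0, 3.0])
i_indices, j_indices = jnp.arange(3), jnp.arange(2)
out = f_ij(data, i_indices, j_indices)          # out[i, j] = data[i] * 10 + j
print(out.shape)                                # (3, 2)
```

Each `vmap` adds one output axis, so composing them reproduces the nested loops over `(i, j)` without materializing reshaped copies of the inputs.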
@@ -248,29 +282,28 @@ JAX allows us to solve linear systems defined in terms of operators; the first
 step is to define the function $L_{\sigma}$.
 
 ```{code-cell} ipython3
-def L_σ(v, σ, params, sizes, arrays):
+def _L_σ(v, σ, params, arrays, i, j):
     """
     Here we set up the linear map v -> L_σ v, where
 
         (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)
 
     """
-
+    # Unpack
     β, R, γ = params
-    w_size, y_size = sizes
     w_grid, y_grid, Q = arrays
+    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
+    return v[i, j] - β * jnp.sum(v[σ[i, j], :] * Q[j, :])
 
-    # Set up the array v[σ[i, j], jp]
-    zp_idx = jnp.arange(y_size)
-    zp_idx = jnp.reshape(zp_idx, (1, 1, y_size))
-    σ = jnp.reshape(σ, (w_size, y_size, 1))
-    V = v[σ, zp_idx]
+L_1 = jax.vmap(_L_σ, in_axes=(None, None, None, None, None, 0))
+L_σ_vmap = jax.vmap(L_1, in_axes=(None, None, None, None, 0, None))
 
-    # Expand Q[j, jp] to Q[i, j, jp]
-    Q = jnp.reshape(Q, (1, y_size, y_size))
+def L_σ(v, σ, params, sizes, arrays):
+    w_size, y_size = sizes
+    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+    return L_σ_vmap(v, σ, params, arrays, w_indices, y_indices)
 
-    # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
-    return v - β * jnp.sum(V * Q, axis=2)
+L_σ = jax.jit(L_σ, static_argnums=(3,))
 ```
 
 Now we can define a function to compute $v_{\sigma}$
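
The hunk above turns `L_σ` into a scalar kernel plus `vmap` wrappers while keeping it usable as a matrix-free operator for the linear solve inside `get_value`. A toy sketch of that solver pattern, passing a callable rather than a matrix to `jax.scipy.sparse.linalg.bicgstab` (the example matrix is illustrative, not from the lecture):

```python
import jax
import jax.numpy as jnp

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

matvec = lambda x: A @ x                           # the linear map supplied as a function
x, _ = jax.scipy.sparse.linalg.bicgstab(matvec, b)
print(jnp.allclose(A @ x, b, atol=1e-5))           # True
```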
@@ -290,20 +323,16 @@ def get_value(σ, params, sizes, arrays):
     partial_L_σ = lambda v: L_σ(v, σ, params, sizes, arrays)
 
     return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
-```
-
-## JIT compiled versions
 
-```{code-cell} ipython3
-B = jax.jit(B, static_argnums=(2,))
-compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
-T = jax.jit(T, static_argnums=(2,))
-get_greedy = jax.jit(get_greedy, static_argnums=(2,))
 get_value = jax.jit(get_value, static_argnums=(2,))
-T_σ = jax.jit(T_σ, static_argnums=(3,))
-L_σ = jax.jit(L_σ, static_argnums=(3,))
+
 ```
 
+
+## Iteration
+
 We use successive approximation for VFI.
 
 ```{code-cell} ipython3
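Throughout the rewritten code, each vectorized wrapper is jit-compiled with `static_argnums` pointing at the `sizes` argument (position 2 or 3). A small sketch of why those sizes must be static, using a hypothetical function that is not from the lecture: `jnp.arange` needs concrete Python integers at trace time, so `sizes` cannot be passed as a traced array.

```python
import jax
import jax.numpy as jnp

def f(x, sizes):
    n, = sizes
    return x + jnp.arange(n)          # n must be a concrete int when tracing

f = jax.jit(f, static_argnums=(1,))   # treat the sizes tuple as a compile-time constant
print(f(jnp.zeros(3), (3,)))          # [0. 1. 2.]; recompiles for each distinct sizes value
```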