@@ -100,121 +100,155 @@ def create_consumption_model(R=1.01, # Gross interest rate
    A function that takes in parameters and returns parameters and grids
    for the optimal savings problem.
    """
+     # Build grids and transition probabilities
    w_grid = jnp.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
-     y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P)
+     y_grid, Q = jnp.exp(mc.state_values), mc.P
+     # Pack and return
+     params = β, R, γ
    sizes = w_size, y_size
-     return (β, R, γ), sizes, (w_grid, y_grid, Q)
+     arrays = w_grid, y_grid, jnp.array(Q)
+     return params, sizes, arrays
```

Here's the right hand side of the Bellman equation:

```{code-cell} ipython3
- def B(v, params, sizes, arrays):
+ def _B(v, params, arrays, i, j, ip):
    """
-     A vectorized version of the right-hand side of the Bellman equation
-     (before maximization), which is a 3D array representing
+     The right-hand side of the Bellman equation before maximization, which takes
+     the form

        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

-     for all (w, y, w′).
+     The indices are (i, j, ip) -> (w, y, w′).
    """
-
-     # Unpack
    β, R, γ = params
-     w_size, y_size = sizes
    w_grid, y_grid, Q = arrays
-
-     # Compute current rewards r(w, y, wp) as array r[i, j, ip]
-     w  = jnp.reshape(w_grid, (w_size, 1, 1))    # w[i]   ->  w[i, j, ip]
-     y  = jnp.reshape(y_grid, (1, y_size, 1))    # z[j]   ->  z[i, j, ip]
-     wp = jnp.reshape(w_grid, (1, 1, w_size))    # wp[ip] -> wp[i, j, ip]
+     w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
    c = R * w + y - wp
+     EV = jnp.sum(v[ip, :] * Q[j, :])
+     return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+ ```

-     # Calculate continuation rewards at all combinations of (w, y, wp)
-     v = jnp.reshape(v, (1, 1, w_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
-     Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
-     EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp
+ Now we successively apply `vmap` to vectorize $B$ by simulating nested loops.

-     # Compute the right-hand side of the Bellman equation
-     return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)
+ ```{code-cell} ipython3
+ B_1    = jax.vmap(_B,  in_axes=(None, None, None, None, None, 0))
+ B_2    = jax.vmap(B_1, in_axes=(None, None, None, None, 0,    None))
+ B_vmap = jax.vmap(B_2, in_axes=(None, None, None, 0,    None, None))
+ ```
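
As an editorial aside (not part of the lecture's code), the nesting pattern above is easy to check in isolation: each successive `vmap` adds one leading axis to the output, so the innermost `vmap` plays the role of the innermost loop. A minimal sketch with a made-up scalar function:

```python
# Editorial sketch: three nested vmaps turn a scalar function of indices
# (i, j, ip) into a function returning a 3-D array, one value per index triple.
import jax
import jax.numpy as jnp

def f(i, j, ip):                              # hypothetical scalar function
    return 100 * i + 10 * j + ip

f1 = jax.vmap(f,  in_axes=(None, None, 0))    # map over ip (innermost loop)
f2 = jax.vmap(f1, in_axes=(None, 0, None))    # then over j
f3 = jax.vmap(f2, in_axes=(0, None, None))    # then over i (outermost loop)

out = f3(jnp.arange(2), jnp.arange(3), jnp.arange(4))
print(out.shape)                              # (2, 3, 4), indexed as out[i, j, ip]
```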
+
+ Here's a fully vectorized version of $B$.
+
+ ```{code-cell} ipython3
+ def B(v, params, sizes, arrays):
+     w_size, y_size = sizes
+     w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+     return B_vmap(v, params, arrays, w_indices, y_indices, w_indices)
+
+ B = jax.jit(B, static_argnums=(2,))
```
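
A brief editorial note on `static_argnums=(2,)`: `sizes` is a plain Python tuple that fixes array shapes through `jnp.arange`, so JAX must treat it as a compile-time constant when jitting. A minimal sketch of the same idea (toy function, not from the lecture):

```python
# Editorial sketch: an argument that determines output shape must be static.
import jax
import jax.numpy as jnp

def make_grid(scale, n):
    return scale * jnp.arange(n)       # output shape depends on n

make_grid = jax.jit(make_grid, static_argnums=(1,))
print(make_grid(2.0, 5))               # recompiles only when n changes
```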

## Operators

+
+ Here's the Bellman operator $T$
+
+ ```{code-cell} ipython3
+ def T(v, params, sizes, arrays):
+     "The Bellman operator."
+     return jnp.max(B(v, params, sizes, arrays), axis=-1)
+
+ T = jax.jit(T, static_argnums=(2,))
+ ```
+
+ The next function computes a $v$-greedy policy given $v$
+
+ ```{code-cell} ipython3
+ def get_greedy(v, params, sizes, arrays):
+     "Computes a v-greedy policy, returned as a set of indices."
+     return jnp.argmax(B(v, params, sizes, arrays), axis=-1)
+
+ get_greedy = jax.jit(get_greedy, static_argnums=(2,))
+
+ ```
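
Editorial note: since `get_greedy` returns the policy as an array of indices into `w_grid`, the implied savings levels can be recovered by indexing. A small sketch with made-up numbers:

```python
# Editorial sketch: mapping policy indices back to grid values.
import jax.numpy as jnp

w_grid = jnp.linspace(1.0, 5.0, 5)       # hypothetical wealth grid
σ = jnp.array([[0, 4],
               [2, 3]])                  # hypothetical greedy indices, shape (w, y)
print(w_grid[σ])                         # savings choice at each (w, y) pair
```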
+
We define a function to compute the current rewards $r_\sigma$ given policy $\sigma$,
which is defined as the vector

$$
- r_\sigma(w, y) := r(w, y, \sigma(w, y))
+     r_\sigma(w, y) := r(w, y, \sigma(w, y))
$$

```{code-cell} ipython3
- def compute_r_σ(σ, params, sizes, arrays):
+ def _compute_r_σ(σ, params, arrays, i, j):
    """
-     Compute the array r_σ[i, j] = r[i, j, σ[i, j]], which gives current
-     rewards given policy σ.
+     With indices (i, j) -> (w, y) and wp = σ[i, j], compute
+
+         r_σ[i, j] = u(Rw + y - wp)
+
+     which gives current rewards under policy σ.
    """

    # Unpack model
    β, R, γ = params
-     w_size, y_size = sizes
    w_grid, y_grid, Q = arrays
-
    # Compute r_σ[i, j]
-     w = jnp.reshape(w_grid, (w_size, 1))
-     y = jnp.reshape(y_grid, (1, y_size))
-     wp = w_grid[σ]
+     w, y, wp = w_grid[i], y_grid[j], w_grid[σ[i, j]]
    c = R * w + y - wp
    r_σ = c**(1-γ)/(1-γ)

    return r_σ
```

- Now we define the policy operator $T_\sigma$
+ Now we successively apply `vmap` to simulate nested loops.

```{code-cell} ipython3
- def T_σ(v, σ, params, sizes, arrays):
-     "The σ-policy operator."
+ r_1 = jax.vmap(_compute_r_σ, in_axes=(None, None, None, None, 0))
+ r_σ_vmap = jax.vmap(r_1, in_axes=(None, None, None, 0, None))
+ ```

-     # Unpack model
-     β, R, γ = params
+ Here's a fully vectorized version of $r_\sigma$.
+
+ ```{code-cell} ipython3
+ def compute_r_σ(σ, params, sizes, arrays):
    w_size, y_size = sizes
-     w_grid, y_grid, Q = arrays
+     w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+     return r_σ_vmap(σ, params, arrays, w_indices, y_indices)

-     r_σ = compute_r_σ(σ, params, sizes, arrays)
+ compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
+ ```
+
+ Now we define the policy operator $T_\sigma$ going through similar steps

-     # Compute the array v[σ[i, j], jp]
-     yp_idx = jnp.arange(y_size)
-     yp_idx = jnp.reshape(yp_idx, (1, 1, y_size))
-     σ = jnp.reshape(σ, (w_size, y_size, 1))
-     V = v[σ, yp_idx]
+ ```{code-cell} ipython3
+ def _T_σ(v, σ, params, arrays, i, j):
+     "The σ-policy operator."

-     # Convert Q[j, jp] to Q[i, j, jp]
-     Q = jnp.reshape(Q, (1, y_size, y_size))
+     # Unpack model
+     β, R, γ = params
+     w_grid, y_grid, Q = arrays

+     r_σ = _compute_r_σ(σ, params, arrays, i, j)
    # Calculate the expected sum Σ_jp v[σ[i, j], jp] * Q[i, j, jp]
-     EV = jnp.sum(V * Q, axis=2)
+     EV = jnp.sum(v[σ[i, j], :] * Q[j, :])

    return r_σ + β * EV
- ```

- and the Bellman operator $T$

- ```{code-cell} ipython3
- def T(v, params, sizes, arrays):
-     "The Bellman operator."
-     return jnp.max(B(v, params, sizes, arrays), axis=2)
- ```
+ T_1 = jax.vmap(_T_σ, in_axes=(None, None, None, None, None, 0))
+ T_σ_vmap = jax.vmap(T_1, in_axes=(None, None, None, None, 0, None))

- The next function computes a $v$-greedy policy given $v$
+ def T_σ(v, σ, params, sizes, arrays):
+     w_size, y_size = sizes
+     w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+     return T_σ_vmap(v, σ, params, arrays, w_indices, y_indices)

- ```{code-cell} ipython3
- def get_greedy(v, params, sizes, arrays):
-     "Computes a v-greedy policy, returned as a set of indices."
-     return jnp.argmax(B(v, params, sizes, arrays), axis=2)
+ T_σ = jax.jit(T_σ, static_argnums=(3,))
```

+
The function below computes the value $v_\sigma$ of following policy $\sigma$.

This lifetime value is a function $v_\sigma$ that satisfies
@@ -248,29 +282,28 @@ JAX allows us to solve linear systems defined in terms of operators; the first
step is to define the function $L_{\sigma}$.

```{code-cell} ipython3
- def L_σ(v, σ, params, sizes, arrays):
+ def _L_σ(v, σ, params, arrays, i, j):
    """
    Here we set up the linear map v -> L_σ v, where

        (L_σ v)(w, y) = v(w, y) - β Σ_y′ v(σ(w, y), y′) Q(y, y′)

    """
-
+     # Unpack
    β, R, γ = params
-     w_size, y_size = sizes
    w_grid, y_grid, Q = arrays
+     # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
+     return v[i, j] - β * jnp.sum(v[σ[i, j], :] * Q[j, :])

-     # Set up the array v[σ[i, j], jp]
-     zp_idx = jnp.arange(y_size)
-     zp_idx = jnp.reshape(zp_idx, (1, 1, y_size))
-     σ = jnp.reshape(σ, (w_size, y_size, 1))
-     V = v[σ, zp_idx]
+ L_1 = jax.vmap(_L_σ, in_axes=(None, None, None, None, None, 0))
+ L_σ_vmap = jax.vmap(L_1, in_axes=(None, None, None, None, 0, None))

-     # Expand Q[j, jp] to Q[i, j, jp]
-     Q = jnp.reshape(Q, (1, y_size, y_size))
+ def L_σ(v, σ, params, sizes, arrays):
+     w_size, y_size = sizes
+     w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
+     return L_σ_vmap(v, σ, params, arrays, w_indices, y_indices)

-     # Compute and return v[i, j] - β Σ_jp v[σ[i, j], jp] * Q[j, jp]
-     return v - β * jnp.sum(V * Q, axis=2)
+ L_σ = jax.jit(L_σ, static_argnums=(3,))
```

Now we can define a function to compute $v_{\sigma}$
@@ -290,20 +323,16 @@ def get_value(σ, params, sizes, arrays):
    partial_L_σ = lambda v: L_σ(v, σ, params, sizes, arrays)

    return jax.scipy.sparse.linalg.bicgstab(partial_L_σ, r_σ)[0]
- ```
-
- ## JIT compiled versions

- ```{code-cell} ipython3
- B = jax.jit(B, static_argnums=(2,))
- compute_r_σ = jax.jit(compute_r_σ, static_argnums=(2,))
- T = jax.jit(T, static_argnums=(2,))
- get_greedy = jax.jit(get_greedy, static_argnums=(2,))
get_value = jax.jit(get_value, static_argnums=(2,))
- T_σ = jax.jit(T_σ, static_argnums=(3,))
- L_σ = jax.jit(L_σ, static_argnums=(3,))
+
```
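
An editorial aside on the linear solve above: `jax.scipy.sparse.linalg.bicgstab` accepts a callable in place of an explicit matrix, which is why the closure `partial_L_σ` can stand in for the matrix of $L_\sigma$ and that matrix is never materialized. A toy matrix-free solve, as a sketch only:

```python
# Editorial sketch: solving A x = b with a function that applies A,
# mirroring how partial_L_σ stands in for the matrix of L_σ.
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
matvec = lambda x: A @ x            # the operator, given as a function
b = jnp.array([9.0, 8.0])
x, info = bicgstab(matvec, b)
print(x)                            # approximately solves A x = b
```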

+
+
+ ## Iteration
+
+
We use successive approximation for VFI.
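
The diff is truncated just below, so here is a rough editorial sketch of what successive approximation means in this context: iterate the Bellman operator from an initial guess until the sup-norm change falls below a tolerance. The names and defaults are assumptions, not the lecture's code.

```python
# Editorial sketch of successive approximation (fixed-point iteration on T).
import jax.numpy as jnp

def successive_approx(T, v_init, tol=1e-6, max_iter=10_000):
    v = v_init
    for _ in range(max_iter):
        v_new = T(v)                           # apply the Bellman operator
        if jnp.max(jnp.abs(v_new - v)) < tol:  # sup-norm stopping rule
            break
        v = v_new
    return v_new
```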

```{code-cell} ipython3