|
| 1 | +--- |
| 2 | +jupytext: |
| 3 | + text_representation: |
| 4 | + extension: .md |
| 5 | + format_name: myst |
| 6 | + format_version: 0.13 |
| 7 | + jupytext_version: 1.16.1 |
| 8 | +kernelspec: |
| 9 | + display_name: Python 3 (ipykernel) |
| 10 | + language: python |
| 11 | + name: python3 |
| 12 | +--- |
| 13 | + |
| 14 | +# Job Search |
| 15 | + |
| 16 | +```{include} _admonition/gpu.md |
| 17 | +``` |
| 18 | + |
| 19 | + |
| 20 | +In this lecture we study a basic infinite-horizon job search with Markov wage |
| 21 | +draws |
| 22 | + |
| 23 | +The exercise at the end asks you to add recursive preferences and compare |
| 24 | +the result. |
| 25 | + |
| 26 | +In addition to what’s in Anaconda, this lecture will need the following libraries: |
| 27 | + |
| 28 | +```{code-cell} ipython3 |
| 29 | +:tags: [hide-output] |
| 30 | +
|
| 31 | +!pip install quantecon |
| 32 | +``` |
| 33 | + |
| 34 | +We use the following imports. |
| 35 | + |
| 36 | +```{code-cell} ipython3 |
| 37 | +import matplotlib.pyplot as plt |
| 38 | +import quantecon as qe |
| 39 | +import jax |
| 40 | +import jax.numpy as jnp |
| 41 | +from collections import namedtuple |
| 42 | +
|
| 43 | +jax.config.update("jax_enable_x64", True) |
| 44 | +``` |
| 45 | + |
| 46 | +## Model |
| 47 | + |
| 48 | +We study an elementary model where |
| 49 | + |
| 50 | +* jobs are permanent |
| 51 | +* unemployed workers receive current compensation $c$ |
| 52 | +* the wage offer distribution $\{W_t\}$ is Markovian |
| 53 | +* the horizon is infinite |
| 54 | +* an unemployment agent discounts the future via discount factor $\beta \in (0,1)$ |
| 55 | + |
| 56 | +The wage process obeys |
| 57 | + |
| 58 | +$$ |
| 59 | + W_{t+1} = \rho W_t + \nu Z_{t+1}, |
| 60 | + \qquad \{Z_t\} \text{ is IID and } N(0, 1) |
| 61 | +$$ |
| 62 | + |
| 63 | +We discretize this using Tauchen's method to produce a stochastic matrix $P$ |
| 64 | + |
| 65 | +Since jobs are permanent, the return to accepting wage offer $w$ today is |
| 66 | + |
| 67 | +$$ |
| 68 | + w + \beta w + \beta^2 w + \cdots = \frac{w}{1-\beta} |
| 69 | +$$ |
| 70 | + |
| 71 | +The Bellman equation is |
| 72 | + |
| 73 | +$$ |
| 74 | + v(w) = \max |
| 75 | + \left\{ |
| 76 | + \frac{w}{1-\beta}, c + \beta \sum_{w'} v(w') P(w, w') |
| 77 | + \right\} |
| 78 | +$$ |
| 79 | + |
| 80 | +We solve this model using value function iteration. |
| 81 | + |
| 82 | + |
| 83 | +Let's set up a `namedtuple` to store information needed to solve the model. |
| 84 | + |
| 85 | +```{code-cell} ipython3 |
| 86 | +Model = namedtuple('Model', ('n', 'w_vals', 'P', 'β', 'c', 'θ')) |
| 87 | +``` |
| 88 | + |
| 89 | +The function below holds default values and populates the namedtuple. |
| 90 | + |
| 91 | +```{code-cell} ipython3 |
| 92 | +def create_js_model( |
| 93 | + n=500, # wage grid size |
| 94 | + ρ=0.9, # wage persistence |
| 95 | + ν=0.2, # wage volatility |
| 96 | + β=0.99, # discount factor |
| 97 | + c=1.0, # unemployment compensation |
| 98 | + θ=-0.1 # risk parameter |
| 99 | + ): |
| 100 | + "Creates an instance of the job search model with Markov wages." |
| 101 | + mc = qe.tauchen(n, ρ, ν) |
| 102 | + w_vals, P = jnp.exp(mc.state_values), mc.P |
| 103 | + P = jnp.array(P) |
| 104 | + return Model(n, w_vals, P, β, c, θ) |
| 105 | +``` |
| 106 | + |
| 107 | +Here's the Bellman operator. |
| 108 | + |
| 109 | +```{code-cell} ipython3 |
| 110 | +@jax.jit |
| 111 | +def T(v, model): |
| 112 | + """ |
| 113 | + The Bellman operator Tv = max{e, c + β E v} with |
| 114 | +
|
| 115 | + e(w) = w / (1-β) and (Ev)(w) = E_w[ v(W')] |
| 116 | +
|
| 117 | + """ |
| 118 | + n, w_vals, P, β, c, θ = model |
| 119 | + h = c + β * P @ v |
| 120 | + e = w_vals / (1 - β) |
| 121 | +
|
| 122 | + return jnp.maximum(e, h) |
| 123 | +``` |
| 124 | + |
| 125 | +The next function computes the optimal policy under the assumption that $v$ is |
| 126 | + the value function. |
| 127 | + |
| 128 | +The policy takes the form |
| 129 | + |
| 130 | +$$ |
| 131 | + \sigma(w) = \mathbf 1 |
| 132 | + \left\{ |
| 133 | + \frac{w}{1-\beta} \geq c + \beta \sum_{w'} v(w') P(w, w') |
| 134 | + \right\} |
| 135 | +$$ |
| 136 | + |
| 137 | +Here $\mathbf 1$ is an indicator function. |
| 138 | + |
| 139 | +The statement above means that the worker accepts ($\sigma(w) = 1$) when the value of stopping |
| 140 | +is higher than the value of continuing. |
| 141 | + |
| 142 | +```{code-cell} ipython3 |
| 143 | +@jax.jit |
| 144 | +def get_greedy(v, model): |
| 145 | + """Get a v-greedy policy.""" |
| 146 | + n, w_vals, P, β, c, θ = model |
| 147 | + e = w_vals / (1 - β) |
| 148 | + h = c + β * P @ v |
| 149 | + σ = jnp.where(e >= h, 1, 0) |
| 150 | + return σ |
| 151 | +``` |
| 152 | + |
| 153 | +Here's a routine for value function iteration. |
| 154 | + |
| 155 | +```{code-cell} ipython3 |
| 156 | +def vfi(model, max_iter=10_000, tol=1e-4): |
| 157 | + """Solve the infinite-horizon Markov job search model by VFI.""" |
| 158 | +
|
| 159 | + print("Starting VFI iteration.") |
| 160 | + v = jnp.zeros_like(model.w_vals) # Initial guess |
| 161 | + i = 0 |
| 162 | + error = tol + 1 |
| 163 | +
|
| 164 | + while error > tol and i < max_iter: |
| 165 | + new_v = T(v, model) |
| 166 | + error = jnp.max(jnp.abs(new_v - v)) |
| 167 | + i += 1 |
| 168 | + v = new_v |
| 169 | +
|
| 170 | + v_star = v |
| 171 | + σ_star = get_greedy(v_star, model) |
| 172 | + return v_star, σ_star |
| 173 | +``` |
| 174 | + |
| 175 | +### Computing the solution |
| 176 | + |
| 177 | +Let's set up and solve the model. |
| 178 | + |
| 179 | +```{code-cell} ipython3 |
| 180 | +model = create_js_model() |
| 181 | +n, w_vals, P, β, c, θ = model |
| 182 | +
|
| 183 | +qe.tic() |
| 184 | +v_star, σ_star = vfi(model) |
| 185 | +vfi_time = qe.toc() |
| 186 | +``` |
| 187 | + |
| 188 | +We compute the reservation wage as the first $w$ such that $\sigma(w)=1$. |
| 189 | + |
| 190 | +```{code-cell} ipython3 |
| 191 | +res_wage = w_vals[jnp.searchsorted(σ_star, 1.0)] |
| 192 | +``` |
| 193 | + |
| 194 | +```{code-cell} ipython3 |
| 195 | +fig, ax = plt.subplots() |
| 196 | +ax.plot(w_vals, v_star, alpha=0.8, label="value function") |
| 197 | +ax.vlines((res_wage,), 150, 400, 'k', ls='--', label="reservation wage") |
| 198 | +ax.legend(frameon=False, fontsize=12, loc="lower right") |
| 199 | +ax.set_xlabel("$w$", fontsize=12) |
| 200 | +plt.show() |
| 201 | +``` |
| 202 | + |
| 203 | +## Exercise |
| 204 | + |
| 205 | +```{exercise-start} |
| 206 | +:label: job_search_1 |
| 207 | +``` |
| 208 | + |
| 209 | +In the setting above, the agent is risk-neutral vis-a-vis future utility risk. |
| 210 | + |
| 211 | +Now solve the same problem but this time assuming that the agent has risk-sensitive |
| 212 | +preferences, which are a type of nonlinear recursive preferences. |
| 213 | + |
| 214 | +The Bellman equation becomes |
| 215 | + |
| 216 | +$$ |
| 217 | + v(w) = \max |
| 218 | + \left\{ |
| 219 | + \frac{w}{1-\beta}, |
| 220 | + c + \frac{\beta}{\theta} |
| 221 | + \ln \left[ |
| 222 | + \sum_{w'} \exp(\theta v(w')) P(w, w') |
| 223 | + \right] |
| 224 | + \right\} |
| 225 | +$$ |
| 226 | + |
| 227 | + |
| 228 | +When $\theta < 0$ the agent is risk sensitive. |
| 229 | + |
| 230 | +Solve the model when $\theta = -0.1$ and compare your result to the risk neutral |
| 231 | +case. |
| 232 | + |
| 233 | +Try to interpret your result. |
| 234 | + |
| 235 | +```{exercise-end} |
| 236 | +``` |
| 237 | + |
| 238 | +```{solution-start} job_search_1 |
| 239 | +:class: dropdown |
| 240 | +``` |
| 241 | + |
| 242 | +```{code-cell} ipython3 |
| 243 | +def create_risk_sensitive_js_model( |
| 244 | + n=500, # wage grid size |
| 245 | + ρ=0.9, # wage persistence |
| 246 | + ν=0.2, # wage volatility |
| 247 | + β=0.99, # discount factor |
| 248 | + c=1.0, # unemployment compensation |
| 249 | + θ=-0.1 # risk parameter |
| 250 | + ): |
| 251 | + "Creates an instance of the job search model with Markov wages." |
| 252 | + mc = qe.tauchen(n, ρ, ν) |
| 253 | + w_vals, P = jnp.exp(mc.state_values), mc.P |
| 254 | + P = jnp.array(P) |
| 255 | + return Model(n, w_vals, P, β, c, θ) |
| 256 | +
|
| 257 | +
|
| 258 | +@jax.jit |
| 259 | +def T_rs(v, model): |
| 260 | + """ |
| 261 | + The Bellman operator Tv = max{e, c + β R v} with |
| 262 | +
|
| 263 | + e(w) = w / (1-β) and |
| 264 | +
|
| 265 | + (Rv)(w) = (1/θ) ln{E_w[ exp(θ v(W'))]} |
| 266 | +
|
| 267 | + """ |
| 268 | + n, w_vals, P, β, c, θ = model |
| 269 | + h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v))) |
| 270 | + e = w_vals / (1 - β) |
| 271 | +
|
| 272 | + return jnp.maximum(e, h) |
| 273 | +
|
| 274 | +
|
| 275 | +@jax.jit |
| 276 | +def get_greedy_rs(v, model): |
| 277 | + " Get a v-greedy policy." |
| 278 | + n, w_vals, P, β, c, θ = model |
| 279 | + e = w_vals / (1 - β) |
| 280 | + h = c + (β / θ) * jnp.log(P @ (jnp.exp(θ * v))) |
| 281 | + σ = jnp.where(e >= h, 1, 0) |
| 282 | + return σ |
| 283 | +
|
| 284 | +
|
| 285 | +
|
| 286 | +def vfi(model, max_iter=10_000, tol=1e-4): |
| 287 | + "Solve the infinite-horizon Markov job search model by VFI." |
| 288 | + print("Starting VFI iteration.") |
| 289 | + v = jnp.zeros_like(model.w_vals) # Initial guess |
| 290 | + i = 0 |
| 291 | + error = tol + 1 |
| 292 | +
|
| 293 | + while error > tol and i < max_iter: |
| 294 | + new_v = T_rs(v, model) |
| 295 | + error = jnp.max(jnp.abs(new_v - v)) |
| 296 | + i += 1 |
| 297 | + v = new_v |
| 298 | +
|
| 299 | + v_star = v |
| 300 | + σ_star = get_greedy_rs(v_star, model) |
| 301 | + return v_star, σ_star |
| 302 | +
|
| 303 | +
|
| 304 | +
|
| 305 | +model_rs = create_risk_sensitive_js_model() |
| 306 | +
|
| 307 | +n, w_vals, P, β, c, θ = model_rs |
| 308 | +
|
| 309 | +qe.tic() |
| 310 | +v_star_rs, σ_star_rs = vfi(model_rs) |
| 311 | +vfi_time = qe.toc() |
| 312 | +``` |
| 313 | + |
| 314 | +```{code-cell} ipython3 |
| 315 | +res_wage_rs = w_vals[jnp.searchsorted(σ_star_rs, 1.0)] |
| 316 | +``` |
| 317 | + |
| 318 | +```{code-cell} ipython3 |
| 319 | +fig, ax = plt.subplots() |
| 320 | +ax.plot(w_vals, v_star, alpha=0.8, label="RN $v$") |
| 321 | +ax.plot(w_vals, v_star_rs, alpha=0.8, label="RS $v$") |
| 322 | +ax.vlines((res_wage,), 150, 400, ls='--', color='darkblue', alpha=0.5, label=r"RV $\bar w$") |
| 323 | +ax.vlines((res_wage_rs,), 150, 400, ls='--', color='orange', alpha=0.5, label=r"RS $\bar w$") |
| 324 | +ax.legend(frameon=False, fontsize=12, loc="lower right") |
| 325 | +ax.set_xlabel("$w$", fontsize=12) |
| 326 | +plt.show() |
| 327 | +``` |
| 328 | + |
| 329 | +The figure shows that the reservation wage under risk sensitive preferences (RS $\bar w$) shifts down. |
| 330 | + |
| 331 | +This makes sense -- the agent does not like risk and hence is more inclined to |
| 332 | +accept the current offer, even when it's lower. |
| 333 | + |
| 334 | +```{solution-end} |
| 335 | +``` |
0 commit comments