|
29 | 29 | %
|
30 | 30 |
|
31 | 31 | % Set some ACG global tolerances.
|
32 |
| - INEQ_COND_ERR_TOL = 1e-6; |
| 32 | + INEQ_TOL = 1e-6; |
33 | 33 | CURV_TOL = 1e-6;
|
34 |
| - DIV_TOL = 1e-6; |
35 | 34 |
|
36 | 35 | %% PRE-PROCESSING
|
37 | 36 |
|
|
118 | 117 | error('Unknown ACG steptype!');
|
119 | 118 | end
|
120 | 119 |
|
121 |
| - % Safeguard against the case where (L == mu), i.e. nu = Inf. |
| 120 | + % Safeguard against the case where (L == mu), i.e. lamK = Inf. |
122 | 121 | L = max(L, mu + CURV_TOL);
|
123 | 122 |
|
124 | 123 | % Set up the oracle at x0
|
|
131 | 130 | function [local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev)
|
132 | 131 |
|
133 | 132 | % Simple quantities.
|
134 |
| - nu = nu_fn(mu, L); |
135 |
| - nu_prev = (1 + mu * A_prev) * nu; |
136 |
| - a_prev = (nu_prev + sqrt(nu_prev ^ 2 + 4 * nu_prev * A_prev)) / 2; |
| 133 | + lamK = lamK_fn(mu, L); |
| 134 | + tauK = (1 + mu * A_prev) * lamK; |
| 135 | + a_prev = (tauK + sqrt(tauK ^ 2 + 4 * tauK * A_prev)) / 2; |
137 | 136 | A = A_prev + a_prev;
|
138 | 137 | x_tilde_prev = (A_prev / A) * y_prev + a_prev / A * x_prev;
|
139 | 138 |
|
|
143 | 142 | grad_f_s_at_x_tilde_prev = o_x_tilde_prev.grad_f_s();
|
144 | 143 |
|
145 | 144 | % Oracle at y.
|
146 |
| - y_prox_mult = nu / (1 + nu * mu); |
| 145 | + y_prox_mult = lamK / (1 + lamK * mu); |
147 | 146 | y_prox_ctr = x_tilde_prev - y_prox_mult * grad_f_s_at_x_tilde_prev;
|
148 | 147 | [y, o_y] = get_y(y_prox_ctr, y_prox_mult);
|
149 | 148 | f_s_at_y = o_y.f_s();
|
150 | 149 |
|
151 | 150 | % Estimate of L based on y and x_tilde_prev.
|
152 |
| - local_L_est = max(0, 2 * (f_s_at_y - (f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev))) / ... |
153 |
| - norm_fn(y - x_tilde_prev) ^ 2); |
| 151 | + LHS = f_s_at_y - (f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev)); |
| 152 | + dist_xt_y = norm_fn(y - x_tilde_prev); |
| 153 | + RHS = L * dist_xt_y ^ 2 / 2; |
| 154 | + local_L_est = max(0, 2 * LHS / dist_xt_y ^ 2); |
154 | 155 |
|
155 | 156 | % Save auxiliary quantities.
|
| 157 | + aux_struct.LHS = LHS; |
| 158 | + aux_struct.RHS = RHS; |
| 159 | + aux_struct.dist_xt_y = dist_xt_y; |
| 160 | + aux_struct.descent_cond = (LHS <= RHS + INEQ_TOL); |
156 | 161 | aux_struct.y = y;
|
157 | 162 | aux_struct.o_y = o_y;
|
158 |
| - aux_struct.nu = nu; |
159 |
| - aux_struct.nu_prev = nu_prev; |
| 163 | + aux_struct.lamK = lamK; |
| 164 | + aux_struct.tauK = tauK; |
160 | 165 | aux_struct.a_prev = a_prev;
|
161 | 166 | aux_struct.A = A;
|
162 | 167 | aux_struct.x_tilde_prev = x_tilde_prev;
|
|
206 | 211 | [local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev);
|
207 | 212 | iter = iter + 1;
|
208 | 213 |
|
209 |
| - % Adjust if the local estimate is smaller, up to a point. |
210 |
| - if (L / 2 < local_L_est && local_L_est < L) |
211 |
| - L = local_L_est; |
212 |
| - [local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev); |
213 |
| - iter = iter + 1; |
214 |
| - end |
215 |
| - |
216 | 214 | % Update based on the value of the local L compared to the current estimate of L.
|
217 |
| - while (L < min([L_max, local_L_est])) |
218 |
| - if (norm_fn(aux_struct.x_tilde_prev - aux_struct.y) ^ 2 <= DIV_TOL) |
219 |
| - L = min(L_max, L * params.mult_L); |
220 |
| - else |
221 |
| - L = min(L_max, L * params.mult_L); |
222 |
| - end |
223 |
| - [local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev); |
| 215 | + while (~aux_struct.descent_cond) |
| 216 | + L = min(L_max, L * params.mult_L); |
| 217 | + [~, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev); |
224 | 218 | iter = iter + 1;
|
| 219 | + |
| 220 | + % DEBUG ONLY |
| 221 | + if (params.i_debug) |
| 222 | + diff = abs(aux_struct.RHS - aux_struct.LHS); |
| 223 | + dist_xt_y = aux_struct.dist_xt_y; |
| 224 | + disp(table(local_L_est, L, L_max, dist_xt_y, diff, aux_struct.LHS, aux_struct.RHS , aux_struct.descent_cond)); |
| 225 | + end |
| 226 | + % END DEBUG |
| 227 | + |
| 228 | + if (L >= L_max && ~aux_struct.descent_cond) |
| 229 | + error('Theoretical upper bound on the upper curvature L_max does not appear to be correct!'); |
| 230 | + end |
225 | 231 | if (toc(t_start) > time_limit)
|
226 | 232 | break;
|
227 | 233 | end
|
|
230 | 236 | % Load auxiliary quantities
|
231 | 237 | y = aux_struct.y;
|
232 | 238 | o_y = aux_struct.o_y;
|
233 |
| - nu = aux_struct.nu; |
234 |
| - nu_prev = aux_struct.nu_prev; |
| 239 | + lamK = aux_struct.lamK; |
| 240 | + tauK = aux_struct.tauK; |
235 | 241 | a_prev = aux_struct.a_prev;
|
236 | 242 | A = aux_struct.A;
|
237 | 243 | x_tilde_prev = aux_struct.x_tilde_prev;
|
|
242 | 248 | elseif strcmp(params.acg_steptype, "constant")
|
243 | 249 |
|
244 | 250 | % Iteration parameters.
|
245 |
| - nu = nu_fn(mu, L); |
246 |
| - nu_prev = (1 + mu * A_prev) * nu; |
247 |
| - a_prev = (nu_prev + sqrt(nu_prev ^ 2 + 4 * nu_prev * A_prev)) / 2; |
| 251 | + lamK = lamK_fn(mu, L); |
| 252 | + tauK = (1 + mu * A_prev) * lamK; |
| 253 | + a_prev = (tauK + sqrt(tauK ^ 2 + 4 * tauK * A_prev)) / 2; |
248 | 254 | A = A_prev + a_prev;
|
249 | 255 | x_tilde_prev = (A_prev / A) * y_prev + (a_prev / A) * x_prev;
|
250 | 256 |
|
|
254 | 260 | grad_f_s_at_x_tilde_prev = o_x_tilde_prev.grad_f_s();
|
255 | 261 |
|
256 | 262 | % Oracle at y.
|
257 |
| - y_prox_mult = nu / (1 + nu * mu); |
| 263 | + y_prox_mult = lamK / (1 + lamK * mu); |
258 | 264 | y_prox_ctr = x_tilde_prev - y_prox_mult * grad_f_s_at_x_tilde_prev;
|
259 | 265 | [y, o_y] = get_y(y_prox_ctr, y_prox_mult);
|
260 | 266 |
|
|
280 | 286 | %% COMPUTE (u, η), Γ, and x.
|
281 | 287 |
|
282 | 288 | % Compute x and u.
|
283 |
| - x = 1 / (1 + mu * A) * (x_prev - a_prev / nu * (x_tilde_prev - y) + mu * (A_prev * x_prev + a_prev * y)); |
| 289 | + x = 1 / (1 + mu * A) * (x_prev - a_prev / lamK * (x_tilde_prev - y) + mu * (A_prev * x_prev + a_prev * y)); |
284 | 290 | u = (x0 - x) / A;
|
285 | 291 |
|
286 | 292 | % Compute eta.
|
287 | 293 | if strcmp(params.eta_type, 'recursive')
|
288 | 294 | % Recursive
|
289 | 295 | gamma_at_x = f_n_at_y + f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev) + ...
|
290 |
| - (mu / 2) * norm_fn(y - x_tilde_prev) ^ 2 + 1 / nu * prod_fn(x_tilde_prev - y, x - y) + ... |
| 296 | + (mu / 2) * norm_fn(y - x_tilde_prev) ^ 2 + 1 / lamK * prod_fn(x_tilde_prev - y, x - y) + ... |
291 | 297 | (mu / 2) * norm_fn(x - y) ^ 2;
|
292 | 298 | Gamma_at_x = a_prev / A * gamma_at_x + A_prev / A * Gamma_at_x_prev + A_prev / A * prod_fn(grad_Gamma_at_x_prev, x - x_prev) + ...
|
293 | 299 | (mu / 2) * A_prev / A * norm_fn(x - x_prev) ^ 2;
|
|
296 | 302 | elseif strcmp(params.eta_type, 'accumulative')
|
297 | 303 | % Accumulative
|
298 | 304 | p_at_y = f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev) + mu / 2 * norm_fn(y - x_tilde_prev) ^ 2 + f_n_at_y;
|
299 |
| - sci = p_at_y + mu / 2 * norm_fn(y) ^ 2 - (1 / nu) * prod_fn(y, x_tilde_prev - y); |
300 |
| - svi = - mu * y + (1 / nu) * (x_tilde_prev - y); |
| 305 | + sci = p_at_y + mu / 2 * norm_fn(y) ^ 2 - (1 / lamK) * prod_fn(y, x_tilde_prev - y); |
| 306 | + svi = - mu * y + (1 / lamK) * (x_tilde_prev - y); |
301 | 307 | sni = mu / 2;
|
302 | 308 | scSum = scSum + a_prev * sci;
|
303 | 309 | svSum = svSum + a_prev * svi;
|
|
318 | 324 | % Check the negativity of eta in a relative sense.
|
319 | 325 | if strcmp(termination_type, "aipp")
|
320 | 326 | relative_exact_eta = exact_eta / max([norm_fn(u + x0 - y) ^ 2 / 2, 0.01]);
|
321 |
| - if (relative_exact_eta < -INEQ_COND_ERR_TOL) |
| 327 | + if (relative_exact_eta < -INEQ_TOL) |
322 | 328 | error(['eta is negative with a value of ', num2str(exact_eta)]);
|
323 | 329 | end
|
324 | 330 | end
|
|
334 | 340 | small_gd = f_at_y + prod_fn(u, x0 - y) - eta;
|
335 | 341 | del_gd = large_gd - small_gd;
|
336 | 342 | base = max([abs(large_gd), abs(small_gd), 0.01]);
|
337 |
| - if (del_gd / base < -INEQ_COND_ERR_TOL) |
| 343 | + if (del_gd / base < -INEQ_TOL) |
338 | 344 | model.status = -1;
|
339 | 345 | break;
|
340 | 346 | end
|
|
346 | 352 | large_gd = norm_fn(y - x0) ^ 2;
|
347 | 353 | del_gd = large_gd - small_gd;
|
348 | 354 | base = max([abs(large_gd), abs(small_gd), 0.01]);
|
349 |
| - if (del_gd / base < -INEQ_COND_ERR_TOL) |
| 355 | + if (del_gd / base < -INEQ_TOL) |
350 | 356 | model.status = -2;
|
351 | 357 | break;
|
352 | 358 | end
|
|
359 | 365 | large_gd = norm_fn(y - x0) ^ 2;
|
360 | 366 | del_gd = large_gd - small_gd;
|
361 | 367 | base = max([abs(large_gd), abs(small_gd), 0.01]);
|
362 |
| - if (del_gd / base < -INEQ_COND_ERR_TOL) |
| 368 | + if (del_gd / base < -INEQ_TOL) |
363 | 369 | model.status = -2;
|
364 | 370 | break;
|
365 | 371 | end
|
|
369 | 375 |
|
370 | 376 | % Termination for the AIPP method (Phase 1).
|
371 | 377 | if strcmp(termination_type, "aipp")
|
372 |
| - if (norm_fn(u) ^ 2 + 2 * eta <= sigma * norm_fn(x0 - y + u) ^ 2 + INEQ_COND_ERR_TOL) |
| 378 | + if (norm_fn(u) ^ 2 + 2 * eta <= sigma * norm_fn(x0 - y + u) ^ 2 + INEQ_TOL) |
373 | 379 | break;
|
374 | 380 | end
|
375 | 381 |
|
376 | 382 | % Termination for the AIPP method (with sigma square).
|
377 | 383 | elseif strcmp(termination_type, "aipp_sqr")
|
378 |
| - if (norm_fn(u) ^ 2 + 2 * eta <= sigma ^ 2 * norm_fn(x0 - y + u) ^ 2 + INEQ_COND_ERR_TOL) |
| 384 | + if (norm_fn(u) ^ 2 + 2 * eta <= sigma ^ 2 * norm_fn(x0 - y + u) ^ 2 + INEQ_TOL) |
379 | 385 | break;
|
380 | 386 | end
|
381 | 387 |
|
|
426 | 432 |
|
427 | 433 | % Update iterates.
|
428 | 434 | if (strcmp(params.eta_type, 'recursive'))
|
429 |
| - grad_gamma_at_x = 1 / nu * (x_tilde_prev - y) + mu * (x - y); |
| 435 | + grad_gamma_at_x = 1 / lamK * (x_tilde_prev - y) + mu * (x - y); |
430 | 436 | grad_Gamma_at_x = a_prev / A * grad_gamma_at_x + A_prev / A * grad_Gamma_at_x_prev + A_prev / A * mu * (x - x_prev);
|
431 | 437 | Gamma_at_x_prev = Gamma_at_x;
|
432 | 438 | grad_Gamma_at_x_prev = grad_Gamma_at_x;
|
|
466 | 472 |
|
467 | 473 | %% HELPER FUNCTIONS
|
468 | 474 |
|
469 |
| -function out_nu = nu_fn(mu, L) |
470 |
| - out_nu = 1 / (L - mu); |
| 475 | +function out_lamK =lamK_fn(mu, L) |
| 476 | + out_lamK = 1 / (L - mu); |
471 | 477 | end
|
472 | 478 |
|
473 | 479 | % Fills in parameters that were not set as input.
|
|
0 commit comments