Skip to content

Commit 551a016

Browse files
committed
Bug fixes.
1 parent 77e5ed4 commit 551a016

12 files changed

+366
-87
lines changed

src/frameworks/IAIPAL.m

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,12 @@
4242

4343
% Initialize special constraint functions
4444
norm_fn = params.norm_fn;
45-
pc_fn = params.set_projector; % projection onto the cone K.
4645

4746
% Cone projector function on the point q = p + c * g(x) on the cones:
4847
% -K (primal) and K^* (dual).
4948
function [dual_point, primal_point]= cone_proj(x, p, c)
5049
p_step = p + c * params.constr_fn(x);
51-
primal_point = -pc_fn(-p_step);
50+
primal_point = -params.set_projector(-p_step);
5251
dual_point = p_step - primal_point;
5352
end
5453

@@ -65,7 +64,7 @@
6564
function val = wrap_f_s(x)
6665
[~, primal_proj_point]= cone_proj(x, p, c);
6766
p_step = p + c * params.constr_fn(x);
68-
dist_val = norm_fn((-p_step) - primal_proj_point);
67+
dist_val = norm_fn(p_step - primal_proj_point);
6968
val = 1 / (2 * c) * (dist_val ^ 2 - norm_fn(p) ^ 2);
7069
end
7170
function val = wrap_grad_f_s(x)

src/frameworks/penalty.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@
191191
if (~isfield(params, 'i_logging'))
192192
params.i_logging = false;
193193
end
194+
if (~isfield(params, 'i_debug'))
195+
params.i_debug = false;
196+
end
194197
if (~isfield(params, 'i_reset_prox_center'))
195198
params.i_reset_prox_center = false;
196199
end

src/solvers/ACG.m

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@
2929
%
3030

3131
% Set some ACG global tolerances.
32-
INEQ_COND_ERR_TOL = 1e-6;
32+
INEQ_TOL = 1e-6;
3333
CURV_TOL = 1e-6;
34-
DIV_TOL = 1e-6;
3534

3635
%% PRE-PROCESSING
3736

@@ -118,7 +117,7 @@
118117
error('Unknown ACG steptype!');
119118
end
120119

121-
% Safeguard against the case where (L == mu), i.e. nu = Inf.
120+
% Safeguard against the case where (L == mu), i.e. lamK = Inf.
122121
L = max(L, mu + CURV_TOL);
123122

124123
% Set up the oracle at x0
@@ -131,9 +130,9 @@
131130
function [local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev)
132131

133132
% Simple quantities.
134-
nu = nu_fn(mu, L);
135-
nu_prev = (1 + mu * A_prev) * nu;
136-
a_prev = (nu_prev + sqrt(nu_prev ^ 2 + 4 * nu_prev * A_prev)) / 2;
133+
lamK = lamK_fn(mu, L);
134+
tauK = (1 + mu * A_prev) * lamK;
135+
a_prev = (tauK + sqrt(tauK ^ 2 + 4 * tauK * A_prev)) / 2;
137136
A = A_prev + a_prev;
138137
x_tilde_prev = (A_prev / A) * y_prev + a_prev / A * x_prev;
139138

@@ -143,20 +142,26 @@
143142
grad_f_s_at_x_tilde_prev = o_x_tilde_prev.grad_f_s();
144143

145144
% Oracle at y.
146-
y_prox_mult = nu / (1 + nu * mu);
145+
y_prox_mult = lamK / (1 + lamK * mu);
147146
y_prox_ctr = x_tilde_prev - y_prox_mult * grad_f_s_at_x_tilde_prev;
148147
[y, o_y] = get_y(y_prox_ctr, y_prox_mult);
149148
f_s_at_y = o_y.f_s();
150149

151150
% Estimate of L based on y and x_tilde_prev.
152-
local_L_est = max(0, 2 * (f_s_at_y - (f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev))) / ...
153-
norm_fn(y - x_tilde_prev) ^ 2);
151+
LHS = f_s_at_y - (f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev));
152+
dist_xt_y = norm_fn(y - x_tilde_prev);
153+
RHS = L * dist_xt_y ^ 2 / 2;
154+
local_L_est = max(0, 2 * LHS / dist_xt_y ^ 2);
154155

155156
% Save auxiliary quantities.
157+
aux_struct.LHS = LHS;
158+
aux_struct.RHS = RHS;
159+
aux_struct.dist_xt_y = dist_xt_y;
160+
aux_struct.descent_cond = (LHS <= RHS + INEQ_TOL);
156161
aux_struct.y = y;
157162
aux_struct.o_y = o_y;
158-
aux_struct.nu = nu;
159-
aux_struct.nu_prev = nu_prev;
163+
aux_struct.lamK = lamK;
164+
aux_struct.tauK = tauK;
160165
aux_struct.a_prev = a_prev;
161166
aux_struct.A = A;
162167
aux_struct.x_tilde_prev = x_tilde_prev;
@@ -206,22 +211,23 @@
206211
[local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev);
207212
iter = iter + 1;
208213

209-
% Adjust if the local estimate is smaller, up to a point.
210-
if (L / 2 < local_L_est && local_L_est < L)
211-
L = local_L_est;
212-
[local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev);
213-
iter = iter + 1;
214-
end
215-
216214
% Update based on the value of the local L compared to the current estimate of L.
217-
while (L < min([L_max, local_L_est]))
218-
if (norm_fn(aux_struct.x_tilde_prev - aux_struct.y) ^ 2 <= DIV_TOL)
219-
L = min(L_max, L * params.mult_L);
220-
else
221-
L = min(L_max, L * params.mult_L);
222-
end
223-
[local_L_est, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev);
215+
while (~aux_struct.descent_cond)
216+
L = min(L_max, L * params.mult_L);
217+
[~, aux_struct] = compute_approx_iter(L, mu, A_prev, y_prev, x_prev);
224218
iter = iter + 1;
219+
220+
% DEBUG ONLY
221+
if (params.i_debug)
222+
diff = abs(aux_struct.RHS - aux_struct.LHS);
223+
dist_xt_y = aux_struct.dist_xt_y;
224+
disp(table(local_L_est, L, L_max, dist_xt_y, diff, aux_struct.LHS, aux_struct.RHS , aux_struct.descent_cond));
225+
end
226+
% END DEBUG
227+
228+
if (L >= L_max && ~aux_struct.descent_cond)
229+
error('Theoretical upper bound on the upper curvature L_max does not appear to be correct!');
230+
end
225231
if (toc(t_start) > time_limit)
226232
break;
227233
end
@@ -230,8 +236,8 @@
230236
% Load auxiliary quantities
231237
y = aux_struct.y;
232238
o_y = aux_struct.o_y;
233-
nu = aux_struct.nu;
234-
nu_prev = aux_struct.nu_prev;
239+
lamK = aux_struct.lamK;
240+
tauK = aux_struct.tauK;
235241
a_prev = aux_struct.a_prev;
236242
A = aux_struct.A;
237243
x_tilde_prev = aux_struct.x_tilde_prev;
@@ -242,9 +248,9 @@
242248
elseif strcmp(params.acg_steptype, "constant")
243249

244250
% Iteration parameters.
245-
nu = nu_fn(mu, L);
246-
nu_prev = (1 + mu * A_prev) * nu;
247-
a_prev = (nu_prev + sqrt(nu_prev ^ 2 + 4 * nu_prev * A_prev)) / 2;
251+
lamK = lamK_fn(mu, L);
252+
tauK = (1 + mu * A_prev) * lamK;
253+
a_prev = (tauK + sqrt(tauK ^ 2 + 4 * tauK * A_prev)) / 2;
248254
A = A_prev + a_prev;
249255
x_tilde_prev = (A_prev / A) * y_prev + (a_prev / A) * x_prev;
250256

@@ -254,7 +260,7 @@
254260
grad_f_s_at_x_tilde_prev = o_x_tilde_prev.grad_f_s();
255261

256262
% Oracle at y.
257-
y_prox_mult = nu / (1 + nu * mu);
263+
y_prox_mult = lamK / (1 + lamK * mu);
258264
y_prox_ctr = x_tilde_prev - y_prox_mult * grad_f_s_at_x_tilde_prev;
259265
[y, o_y] = get_y(y_prox_ctr, y_prox_mult);
260266

@@ -280,14 +286,14 @@
280286
%% COMPUTE (u, η), Γ, and x.
281287

282288
% Compute x and u.
283-
x = 1 / (1 + mu * A) * (x_prev - a_prev / nu * (x_tilde_prev - y) + mu * (A_prev * x_prev + a_prev * y));
289+
x = 1 / (1 + mu * A) * (x_prev - a_prev / lamK * (x_tilde_prev - y) + mu * (A_prev * x_prev + a_prev * y));
284290
u = (x0 - x) / A;
285291

286292
% Compute eta.
287293
if strcmp(params.eta_type, 'recursive')
288294
% Recursive
289295
gamma_at_x = f_n_at_y + f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev) + ...
290-
(mu / 2) * norm_fn(y - x_tilde_prev) ^ 2 + 1 / nu * prod_fn(x_tilde_prev - y, x - y) + ...
296+
(mu / 2) * norm_fn(y - x_tilde_prev) ^ 2 + 1 / lamK * prod_fn(x_tilde_prev - y, x - y) + ...
291297
(mu / 2) * norm_fn(x - y) ^ 2;
292298
Gamma_at_x = a_prev / A * gamma_at_x + A_prev / A * Gamma_at_x_prev + A_prev / A * prod_fn(grad_Gamma_at_x_prev, x - x_prev) + ...
293299
(mu / 2) * A_prev / A * norm_fn(x - x_prev) ^ 2;
@@ -296,8 +302,8 @@
296302
elseif strcmp(params.eta_type, 'accumulative')
297303
% Accumulative
298304
p_at_y = f_s_at_x_tilde_prev + prod_fn(grad_f_s_at_x_tilde_prev, y - x_tilde_prev) + mu / 2 * norm_fn(y - x_tilde_prev) ^ 2 + f_n_at_y;
299-
sci = p_at_y + mu / 2 * norm_fn(y) ^ 2 - (1 / nu) * prod_fn(y, x_tilde_prev - y);
300-
svi = - mu * y + (1 / nu) * (x_tilde_prev - y);
305+
sci = p_at_y + mu / 2 * norm_fn(y) ^ 2 - (1 / lamK) * prod_fn(y, x_tilde_prev - y);
306+
svi = - mu * y + (1 / lamK) * (x_tilde_prev - y);
301307
sni = mu / 2;
302308
scSum = scSum + a_prev * sci;
303309
svSum = svSum + a_prev * svi;
@@ -318,7 +324,7 @@
318324
% Check the negativity of eta in a relative sense.
319325
if strcmp(termination_type, "aipp")
320326
relative_exact_eta = exact_eta / max([norm_fn(u + x0 - y) ^ 2 / 2, 0.01]);
321-
if (relative_exact_eta < -INEQ_COND_ERR_TOL)
327+
if (relative_exact_eta < -INEQ_TOL)
322328
error(['eta is negative with a value of ', num2str(exact_eta)]);
323329
end
324330
end
@@ -334,7 +340,7 @@
334340
small_gd = f_at_y + prod_fn(u, x0 - y) - eta;
335341
del_gd = large_gd - small_gd;
336342
base = max([abs(large_gd), abs(small_gd), 0.01]);
337-
if (del_gd / base < -INEQ_COND_ERR_TOL)
343+
if (del_gd / base < -INEQ_TOL)
338344
model.status = -1;
339345
break;
340346
end
@@ -346,7 +352,7 @@
346352
large_gd = norm_fn(y - x0) ^ 2;
347353
del_gd = large_gd - small_gd;
348354
base = max([abs(large_gd), abs(small_gd), 0.01]);
349-
if (del_gd / base < -INEQ_COND_ERR_TOL)
355+
if (del_gd / base < -INEQ_TOL)
350356
model.status = -2;
351357
break;
352358
end
@@ -359,7 +365,7 @@
359365
large_gd = norm_fn(y - x0) ^ 2;
360366
del_gd = large_gd - small_gd;
361367
base = max([abs(large_gd), abs(small_gd), 0.01]);
362-
if (del_gd / base < -INEQ_COND_ERR_TOL)
368+
if (del_gd / base < -INEQ_TOL)
363369
model.status = -2;
364370
break;
365371
end
@@ -369,13 +375,13 @@
369375

370376
% Termination for the AIPP method (Phase 1).
371377
if strcmp(termination_type, "aipp")
372-
if (norm_fn(u) ^ 2 + 2 * eta <= sigma * norm_fn(x0 - y + u) ^ 2 + INEQ_COND_ERR_TOL)
378+
if (norm_fn(u) ^ 2 + 2 * eta <= sigma * norm_fn(x0 - y + u) ^ 2 + INEQ_TOL)
373379
break;
374380
end
375381

376382
% Termination for the AIPP method (with sigma square).
377383
elseif strcmp(termination_type, "aipp_sqr")
378-
if (norm_fn(u) ^ 2 + 2 * eta <= sigma ^ 2 * norm_fn(x0 - y + u) ^ 2 + INEQ_COND_ERR_TOL)
384+
if (norm_fn(u) ^ 2 + 2 * eta <= sigma ^ 2 * norm_fn(x0 - y + u) ^ 2 + INEQ_TOL)
379385
break;
380386
end
381387

@@ -426,7 +432,7 @@
426432

427433
% Update iterates.
428434
if (strcmp(params.eta_type, 'recursive'))
429-
grad_gamma_at_x = 1 / nu * (x_tilde_prev - y) + mu * (x - y);
435+
grad_gamma_at_x = 1 / lamK * (x_tilde_prev - y) + mu * (x - y);
430436
grad_Gamma_at_x = a_prev / A * grad_gamma_at_x + A_prev / A * grad_Gamma_at_x_prev + A_prev / A * mu * (x - x_prev);
431437
Gamma_at_x_prev = Gamma_at_x;
432438
grad_Gamma_at_x_prev = grad_Gamma_at_x;
@@ -466,8 +472,8 @@
466472

467473
%% HELPER FUNCTIONS
468474

469-
function out_nu = nu_fn(mu, L)
470-
out_nu = 1 / (L - mu);
475+
function out_lamK =lamK_fn(mu, L)
476+
out_lamK = 1 / (L - mu);
471477
end
472478

473479
% Fills in parameters that were not set as input.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
% Solve a multivariate nonconvex quadratically constrained quadratic programming
2+
% problem constrained to a box using MULTIPLE SOLVERS.
3+
run('../../../../init.m');
4+
format long
5+
6+
% Comment out later.
7+
dimN = 250;
8+
r = 1;
9+
m = 1;
10+
M = 1e4;
11+
12+
% Run an instance via the command line.
13+
print_tbls(dimN, r, m, M);
14+
15+
%% Utility functions
16+
function print_tbls(dimN, r, m, M)
17+
18+
% Initialize
19+
seed = 77777;
20+
dimM = 10;
21+
global_tol = 1e-5;
22+
disp(table(dimN, r, m, M));
23+
o_tbl = run_experiment(M, m, dimM, dimN, -r, r, seed, global_tol);
24+
disp(o_tbl);
25+
26+
end
27+
function o_tbl = run_experiment(M, m, dimM, dimN, x_l, x_u, seed, global_tol)
28+
29+
[oracle, hparams] = test_fn_quad_box_constr_02(M, m, seed, dimM, dimN, x_l, x_u);
30+
31+
% Set up the termination function.
32+
function proj = proj_dh(a, b)
33+
I1 = (abs(a - x_l) < 1e-12);
34+
I2 = (abs(a - x_u) <= 1e-12);
35+
I3 = (abs(a - x_l) > 1e-12 & abs(a - x_u) > 1e-12);
36+
proj = b;
37+
proj(I1) = min(0, b(I1));
38+
proj(I2) = max(0, b(I2));
39+
proj(I3) = 0;
40+
end
41+
function proj = proj_NKt(~, b)
42+
proj = min(0, b);
43+
end
44+
o_at_x0 = copy(oracle);
45+
o_at_x0.eval(hparams.x0);
46+
g0 = hparams.constr_fn(hparams.x0);
47+
rho = global_tol * (1 + hparams.norm_fn(o_at_x0.grad_f_s()));
48+
eta = global_tol * (1 + hparams.norm_fn(g0 - hparams.set_projector(g0)));
49+
term_wrap = @(x,p) ...
50+
termination_check(x, p, o_at_x0, hparams.constr_fn, hparams.grad_constr_fn, @proj_dh, @proj_NKt, hparams.norm_fn, rho, eta);
51+
52+
% Create the Model object and specify the solver.
53+
ncvx_qc_qp = ConstrCompModel(oracle);
54+
55+
% Set the curvatures and the starting point x0.
56+
ncvx_qc_qp.x0 = hparams.x0;
57+
ncvx_qc_qp.M = hparams.M;
58+
ncvx_qc_qp.m = hparams.m;
59+
ncvx_qc_qp.K_constr = hparams.K_constr;
60+
ncvx_qc_qp.L_constr = hparams.L_constr;
61+
62+
% Set the tolerances
63+
ncvx_qc_qp.opt_tol = global_tol;
64+
ncvx_qc_qp.feas_tol = global_tol;
65+
ncvx_qc_qp.time_limit = 18000;
66+
ncvx_qc_qp.iter_limit = 1000000;
67+
68+
% Add linear constraints
69+
ncvx_qc_qp.constr_fn = hparams.constr_fn;
70+
ncvx_qc_qp.grad_constr_fn = hparams.grad_constr_fn;
71+
ncvx_qc_qp.set_projector = hparams.set_projector;
72+
ncvx_qc_qp.dual_cone_projector = hparams.dual_cone_projector;
73+
74+
% Use a relative termination criterion.
75+
ncvx_qc_qp.feas_type = 'relative';
76+
ncvx_qc_qp.opt_type = 'relative';
77+
78+
% Create some basic hparams.
79+
base_hparam = struct();
80+
base_hparam.i_debug = true;
81+
base_hparam.termination_fn = term_wrap;
82+
base_hparam.check_all_terminations = true;
83+
84+
% Create the IAPIAL hparams.
85+
ipl_hparam = base_hparam;
86+
ipl_hparam.acg_steptype = 'constant';
87+
ipl_hparam.sigma_min = 1/sqrt(2);
88+
ipla_hparam = base_hparam;
89+
ipla_hparam.acg_steptype = 'variable';
90+
ipla_hparam.init_mult_L = 0.5;
91+
ipla_hparam.sigma_min = 1/sqrt(2);
92+
93+
% Run a benchmark test and print the summary.
94+
hparam_arr = {ipla_hparam, ipl_hparam};
95+
name_arr = {'IPL_A', 'IPL'};
96+
framework_arr = {@IAIPAL, @IAIPAL};
97+
solver_arr = {@ECG, @ECG};
98+
99+
% Run the test.
100+
[summary_tables, ~] = run_CCM_benchmark(ncvx_qc_qp, framework_arr, solver_arr, hparam_arr, name_arr);
101+
o_tbl = summary_tables.all;
102+
103+
end
104+

tests/papers/nl_iapial/extra/nl_iapial_qcqp_iALM_only_params.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function print_tbls(dimN, r, m, M)
5656
% Set the tolerances
5757
ncvx_qc_qp.opt_tol = global_tol;
5858
ncvx_qc_qp.feas_tol = global_tol;
59-
ncvx_qc_qp.time_limit = Inf;
59+
ncvx_qc_qp.time_limit = 6000;
6060
ncvx_qc_qp.iter_limit = 1000000;
6161

6262
% Add linear constraints

0 commit comments

Comments
 (0)