Skip to content

Commit 80fead7

Browse files
Merge pull request #1 from logangrado/feat/args
Feature: Added optional argument `args`
2 parents 5fc7de3 + be173df commit 80fead7

File tree

4 files changed

+17
-15
lines changed

4 files changed

+17
-15
lines changed

docs/userguide.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,14 @@ The :code:`solve` function has several optional arguments which the user may pro
6565

6666
.. code-block:: python
6767
68-
pybobyqa.solve(objfun, x0, bounds=None, npt=None, rhobeg=None,
69-
rhoend=1e-8, maxfun=None, nsamples=None,
68+
pybobyqa.solve(objfun, x0, args=(), bounds=None, npt=None,
69+
rhobeg=None, rhoend=1e-8, maxfun=None, nsamples=None,
7070
user_params=None, objfun_has_noise=False,
7171
scaling_within_bounds=False)
7272
7373
These arguments are:
7474

75+
* :code:`args` - a tuple of extra arguments passed to the objective function
7576
* :code:`bounds` - a tuple :code:`(lower, upper)` with the vectors :math:`a` and :math:`b` of lower and upper bounds on :math:`x` (default is :math:`a_i=-10^{20}` and :math:`b_i=10^{20}`). To set bounds for either :code:`lower` or :code:`upper`, but not both, pass a tuple :code:`(lower, None)` or :code:`(None, upper)`.
7677
* :code:`npt` - the number of interpolation points to use (default is :code:`2*len(x0)+1`). Py-BOBYQA requires :code:`n+1 <= npt <= (n+1)*(n+2)/2` for a problem with :code:`len(x0)=n`. Larger values are particularly useful for noisy problems.
7778
* :code:`rhobeg` - the initial value of the trust region radius (default is :math:`0.1\max(\|x_0\|_{\infty}, 1)`).

pybobyqa/controller.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ def able_to_do_restart(self):
9393

9494

9595
class Controller(object):
96-
def __init__(self, objfun, x0, f0, f0_nsamples, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes):
96+
def __init__(self, objfun, x0, args, f0, f0_nsamples, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes):
9797
self.objfun = objfun
9898
self.maxfun = maxfun
99+
self.args = args
99100
self.model = Model(npt, x0, f0, xl, xu, f0_nsamples, abs_tol=params("model.abs_tol"),
100101
precondition=params("interpolation.precondition"))
101102
self.nf = nf
@@ -341,7 +342,7 @@ def evaluate_objective(self, x, number_of_samples, params):
341342
if not incremented_nx:
342343
self.nx += 1
343344
incremented_nx = True
344-
f_list[i] = eval_objective(self.objfun, remove_scaling(x, self.scaling_changes), eval_num=self.nf, pt_num=self.nx,
345+
f_list[i] = eval_objective(self.objfun, remove_scaling(x, self.scaling_changes), self.args, eval_num=self.nf, pt_num=self.nx,
345346
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
346347
check_for_overflow=params("general.check_objfun_for_overflow"))
347348
num_samples_run += 1

pybobyqa/solver.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def __str__(self):
9090
return output
9191

9292

93-
def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf_so_far, nx_so_far, nsamples, params,
93+
def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf_so_far, nx_so_far, nsamples, params,
9494
diagnostic_info, scaling_changes, f0_avg_old=None, f0_nsamples_old=None):
9595
# Evaluate at x0 (keep nf, nx correct and check for f small)
9696
if f0_avg_old is None:
9797
number_of_samples = max(nsamples(rhobeg, rhobeg, 0, nruns_so_far), 1)
9898
# Evaluate the first time...
9999
nf = nf_so_far + 1
100100
nx = nx_so_far + 1
101-
f0 = eval_objective(objfun, remove_scaling(x0, scaling_changes), eval_num=nf, pt_num=nx,
101+
f0 = eval_objective(objfun, remove_scaling(x0, scaling_changes), args, eval_num=nf, pt_num=nx,
102102
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
103103
check_for_overflow=params("general.check_objfun_for_overflow"))
104104

@@ -116,7 +116,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
116116

117117
nf += 1
118118
# Don't increment nx for x0 - we did this earlier
119-
f_list[i] = eval_objective(objfun, remove_scaling(x0, scaling_changes), eval_num=nf, pt_num=nx,
119+
f_list[i] = eval_objective(objfun, remove_scaling(x0, scaling_changes), args, eval_num=nf, pt_num=nx,
120120
full_x_thresh=params("logging.n_to_print_whole_x_vector"),
121121
check_for_overflow=params("general.check_objfun_for_overflow"))
122122
num_samples_run += 1
@@ -136,7 +136,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
136136
nx = nx_so_far
137137

138138
# Initialise controller
139-
control = Controller(objfun, x0, f0_avg, num_samples_run, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes)
139+
control = Controller(objfun, x0, args, f0_avg, num_samples_run, xl, xu, npt, rhobeg, rhoend, nf, nx, maxfun, params, scaling_changes)
140140

141141
# Initialise interpolation set
142142
number_of_samples = max(nsamples(control.delta, control.rho, 0, nruns_so_far), 1)
@@ -631,7 +631,7 @@ def solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf
631631
return x, f, gradmin, hessmin, nsamples, control.nf, control.nx, nruns_so_far, exit_info, diagnostic_info
632632

633633

634-
def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=None, nsamples=None, user_params=None,
634+
def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=None, nsamples=None, user_params=None,
635635
objfun_has_noise=False, scaling_within_bounds=False):
636636
n = len(x0)
637637

@@ -765,7 +765,7 @@ def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=No
765765
nf = 0
766766
nx = 0
767767
xmin, fmin, gradmin, hessmin, nsamples_min, nf, nx, nruns, exit_info, diagnostic_info = \
768-
solve_main(objfun, x0, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
768+
solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
769769
diagnostic_info, scaling_changes)
770770

771771
# Hard restarts loop
@@ -778,11 +778,11 @@ def solve(objfun, x0, bounds=None, npt=None, rhobeg=None, rhoend=1e-8, maxfun=No
778778
% (fmin, nf, rhobeg, rhoend))
779779
if params("restarts.hard.use_old_fk"):
780780
xmin2, fmin2, gradmin2, hessmin2, nsamples2, nf, nx, nruns, exit_info, diagnostic_info = \
781-
solve_main(objfun, xmin, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
781+
solve_main(objfun, xmin, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
782782
diagnostic_info, scaling_changes, f0_avg_old=fmin, f0_nsamples_old=nsamples_min)
783783
else:
784784
xmin2, fmin2, gradmin2, hessmin2, nsamples2, nf, nx, nruns, exit_info, diagnostic_info = \
785-
solve_main(objfun, xmin, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
785+
solve_main(objfun, xmin, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns, nf, nx, nsamples, params,
786786
diagnostic_info, scaling_changes)
787787

788788
if fmin2 < fmin or np.isnan(fmin):

pybobyqa/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ def sumsq(x):
4444
return np.dot(x, x)
4545

4646

47-
def eval_objective(objfun, x, verbose=True, eval_num=0, pt_num=0, full_x_thresh=6, check_for_overflow=True):
47+
def eval_objective(objfun, x, args, verbose=True, eval_num=0, pt_num=0, full_x_thresh=6, check_for_overflow=True):
4848
# Evaluate objective function
4949
if check_for_overflow:
5050
try:
51-
f = objfun(x)
51+
f = objfun(x, *args)
5252
except OverflowError:
5353
f = sys.float_info.max
5454
else:
55-
f = objfun(x)
55+
f = objfun(x, *args)
5656

5757
if verbose:
5858
if len(x) < full_x_thresh:

0 commit comments

Comments
 (0)