Skip to content

Commit

Permalink
Merge pull request scipy#8693 from pv/pr-8430
Browse files Browse the repository at this point in the history
ENH: optimize: call callback with a copy for scipy.optimize
  • Loading branch information
dlax authored Apr 9, 2018
2 parents 150e7f7 + 7f43e34 commit ba177cb
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 6 deletions.
4 changes: 2 additions & 2 deletions scipy/optimize/_trustregion.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ def _minimize_trust_region(fun, x0, args=(), jac=None, hess=None, hessp=None,

# append the best guess, call back, increment the iteration count
if return_all:
allvecs.append(x)
allvecs.append(np.copy(x))
if callback is not None:
callback(x)
callback(np.copy(x))
k += 1

# check if the gradient is small enough to stop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def stop_criteria(state, x, last_iteration_failed,
state.constr_penalty,
state.cg_stop_cond)
state.status = None
if callback is not None and callback(state.x, state):
if callback is not None and callback(np.copy(state.x), state):
state.status = 3
elif state.optimality < gtol and state.constr_violation < gtol:
state.status = 1
Expand Down Expand Up @@ -452,7 +452,7 @@ def stop_criteria(state, x, last_iteration_failed, tr_radius,
state.barrier_parameter,
state.cg_stop_cond)
state.status = None
if callback is not None and callback(state.x, state):
if callback is not None and callback(np.copy(state.x), state):
state.status = 3
elif state.optimality < gtol and state.constr_violation < gtol:
state.status = 1
Expand Down
2 changes: 1 addition & 1 deletion scipy/optimize/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def func_and_grad(x):
# new iteration
n_iterations += 1
if callback is not None:
callback(x)
callback(np.copy(x))

if n_iterations >= maxiter:
task[:] = 'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
Expand Down
3 changes: 2 additions & 1 deletion scipy/optimize/slsqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def fmin_slsqp(func, x0, eqcons=(), f_eqcons=None, ieqcons=(), f_ieqcons=None,
"""
if disp is not None:
iprint = disp

opts = {'maxiter': iter,
'ftol': acc,
'iprint': iprint,
Expand Down Expand Up @@ -426,7 +427,7 @@ def cjac(x, *args):

# call callback if major iteration has incremented
if callback is not None and majiter > majiter_prev:
callback(x)
callback(np.copy(x))

# Print the status of the current iterate if iprint > 2 and the
# major iteration has incremented
Expand Down
61 changes: 61 additions & 0 deletions scipy/optimize/tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,67 @@ def dfunc(z):
assert_(func(sol1.x) < func(sol2.x),
"%s: %s vs. %s" % (method, func(sol1.x), func(sol2.x)))

@pytest.mark.parametrize('method', ['fmin', 'fmin_powell', 'fmin_cg', 'fmin_bfgs',
'fmin_ncg', 'fmin_l_bfgs_b', 'fmin_tnc',
'fmin_slsqp',
'Nelder-Mead', 'Powell', 'CG', 'BFGS', 'Newton-CG', 'L-BFGS-B',
'TNC', 'SLSQP', 'trust-constr', 'dogleg', 'trust-ncg',
'trust-exact', 'trust-krylov'])
def test_minimize_callback_copies_array(self, method):
# Check that arrays passed to callbacks are not modified
# inplace by the optimizer afterward

if method in ('fmin_tnc', 'fmin_l_bfgs_b'):
func = lambda x: (optimize.rosen(x), optimize.rosen_der(x))
else:
func = optimize.rosen
jac = optimize.rosen_der
hess = optimize.rosen_hess

x0 = np.zeros(10)

# Set options
kwargs = {}
if method.startswith('fmin'):
routine = getattr(optimize, method)
if method == 'fmin_slsqp':
kwargs['iter'] = 5
elif method == 'fmin_tnc':
kwargs['maxfun'] = 100
else:
kwargs['maxiter'] = 5
else:
def routine(*a, **kw):
kw['method'] = method
return optimize.minimize(*a, **kw)

if method == 'TNC':
kwargs['options'] = dict(maxiter=100)
else:
kwargs['options'] = dict(maxiter=5)

if method in ('fmin_ncg',):
kwargs['fprime'] = jac
elif method in ('Newton-CG',):
kwargs['jac'] = jac
elif method in ('trust-krylov', 'trust-exact', 'trust-ncg', 'dogleg',
'trust-constr'):
kwargs['jac'] = jac
kwargs['hess'] = hess

# Run with callback
results = []

def callback(x, *args, **kwargs):
results.append((x, np.copy(x)))

sol = routine(func, x0, callback=callback, **kwargs)

# Check returned arrays coincide with their copies and have no memory overlap
assert_(len(results) > 2)
assert_(all(np.all(x == y) for x, y in results))
assert_(not any(np.may_share_memory(x[0], y[0]) for x, y in itertools.combinations(results, 2)))

@pytest.mark.parametrize('method', ['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
'l-bfgs-b', 'tnc', 'cobyla', 'slsqp'])
def test_no_increase(self, method):
Expand Down

0 comments on commit ba177cb

Please sign in to comment.