From e5f1108cc1cc3d348fc87a2dcc550dc145dd5527 Mon Sep 17 00:00:00 2001 From: Stefan van der Walt Date: Sat, 29 Nov 2008 12:41:23 +0000 Subject: [PATCH] In `fminbound`, raise an error if non-scalar bounds are specified [patch by Neil Muller]. Closes #544. --- scipy/optimize/optimize.py | 8 +++++++- scipy/optimize/tests/test_optimize.py | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/scipy/optimize/optimize.py b/scipy/optimize/optimize.py index 95a2fc7e173f..c6fafcd1c28d 100644 --- a/scipy/optimize/optimize.py +++ b/scipy/optimize/optimize.py @@ -1136,7 +1136,7 @@ def fminbound(func, x1, x2, args=(), xtol=1e-5, maxfun=500, func : callable f(x,*args) Objective function to be minimized (must accept and return scalars). - x1, x2 : ndarray + x1, x2 : float or array scalar The optimization bounds. args : tuple Extra arguments passed to function. @@ -1176,6 +1176,12 @@ def fminbound(func, x1, x2, args=(), xtol=1e-5, maxfun=500, """ + # Test bounds are of correct form + x1 = atleast_1d(x1) + x2 = atleast_1d(x2) + if len(x1) != 1 or len(x2) != 1: + raise ValueError, "Optimisation bounds must be scalars" \ + " or length 1 arrays" if x1 > x2: raise ValueError, "The lower bound exceeds the upper bound." diff --git a/scipy/optimize/tests/test_optimize.py b/scipy/optimize/tests/test_optimize.py index 02383a4de078..f0e1bff045c8 100644 --- a/scipy/optimize/tests/test_optimize.py +++ b/scipy/optimize/tests/test_optimize.py @@ -147,7 +147,21 @@ def test_brent(self): assert max((err1,err2,err3,err4)) < 1e-6 - + def test_fminbound(self): + """Test fminbound + """ + x = optimize.fminbound(lambda x: (x - 1.5)**2 - 0.8, 0, 1) + assert abs(x - 1) < 1e-5 + x = optimize.fminbound(lambda x: (x - 1.5)**2 - 0.8, 1, 5) + assert abs(x - 1.5) < 1e-6 + x = optimize.fminbound(lambda x: (x - 1.5)**2 - 0.8, + numpy.array([1]), numpy.array([5])) + assert abs(x - 1.5) < 1e-6 + assert_raises(ValueError, + optimize.fminbound, lambda x: (x - 1.5)**2 - 0.8, 5, 1) + assert_raises(ValueError, + optimize.fminbound, lambda x: (x - 1.5)**2 - 0.8, + np.zeros(2), 1) class TestTnc(TestCase): """TNC non-linear optimization.