Skip to content

Commit

Permalink
Merge pull request numpy#5302 from idfah/master
Browse files Browse the repository at this point in the history
Fixed meshgrid to return arrays with same dtype as arguments.
  • Loading branch information
rgommers authored Nov 6, 2016
2 parents 9aff656 + ecf11a6 commit e287741
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
18 changes: 7 additions & 11 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4530,18 +4530,14 @@ def meshgrid(*xi, **kwargs):
output[1].shape = (-1, 1) + (1,)*(ndim - 2)
shape[0], shape[1] = shape[1], shape[0]

if sparse:
if copy_:
return [x.copy() for x in output]
else:
return output
else:
if copy_:
output = [x.copy() for x in output]

if not sparse and len(output) > 0:
# Return the full N-D matrix (not only the 1-D vector)
if copy_:
mult_fact = np.ones(shape, dtype=int)
return [x * mult_fact for x in output]
else:
return np.broadcast_arrays(*output)
output = np.broadcast_arrays(*output, subok=True)

return output


def delete(arr, obj, axis=None):
Expand Down
25 changes: 25 additions & 0 deletions numpy/lib/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2211,6 +2211,7 @@ def test_single_input(self):
def test_no_input(self):
args = []
assert_array_equal([], meshgrid(*args))
assert_array_equal([], meshgrid(*args, copy=False))

def test_indexing(self):
x = [1, 2, 3]
Expand Down Expand Up @@ -2244,6 +2245,30 @@ def test_invalid_arguments(self):
assert_raises(TypeError, meshgrid,
[1, 2, 3], [4, 5, 6, 7], indices='ij')

def test_return_type(self):
# Test for appropriate dtype in returned arrays.
# Regression test for issue #5297
# https://github.com/numpy/numpy/issues/5297
x = np.arange(0, 10, dtype=np.float32)
y = np.arange(10, 20, dtype=np.float64)

X, Y = np.meshgrid(x,y)

assert_(X.dtype == x.dtype)
assert_(Y.dtype == y.dtype)

# copy
X, Y = np.meshgrid(x,y, copy=True)

assert_(X.dtype == x.dtype)
assert_(Y.dtype == y.dtype)

# sparse
X, Y = np.meshgrid(x,y, sparse=True)

assert_(X.dtype == x.dtype)
assert_(Y.dtype == y.dtype)


class TestPiecewise(TestCase):

Expand Down

0 comments on commit e287741

Please sign in to comment.