Skip to content

Commit

Permalink
Merge pull request scipy#8183 from endolith/0d_conj_correlate
Browse files Browse the repository at this point in the history
FIX: 0d conj correlate
  • Loading branch information
ilayn authored Dec 7, 2017
2 parents 70e61de + 42a8885 commit fbb69de
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scipy/signal/signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def correlate(in1, in2, mode='full', method='auto'):
in2 = asarray(in2)

if in1.ndim == in2.ndim == 0:
return in1 * in2
return in1 * in2.conj()
elif in1.ndim != in2.ndim:
raise ValueError("in1 and in2 should have the same dimensionality")

Expand Down Expand Up @@ -1112,7 +1112,7 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):

val = _valfrommode(mode)
bval = _bvalfromboundary(boundary)
out = sigtools._convolve2d(in1, in2, 0, val, bval, fillvalue)
out = sigtools._convolve2d(in1, in2.conj(), 0, val, bval, fillvalue)

if swapped_inputs:
out = out[::-1, ::-1]
Expand Down
23 changes: 23 additions & 0 deletions scipy/signal/tests/test_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,24 @@ def test_rank3(self, dt):
assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
assert_equal(y.dtype, dt)

def test_rank0(self, dt):
a = np.array(np.random.randn()).astype(dt)
a += 1j * np.array(np.random.randn()).astype(dt)
b = np.array(np.random.randn()).astype(dt)
b += 1j * np.array(np.random.randn()).astype(dt)

y_r = (correlate(a.real, b.real)
+ correlate(a.imag, b.imag)).astype(dt)
y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))

y = correlate(a, b, 'full')
assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
assert_equal(y.dtype, dt)

assert_equal(correlate([1], [2j]), correlate(1, 2j))
assert_equal(correlate([2j], [3j]), correlate(2j, 3j))
assert_equal(correlate([3j], [4]), correlate(3j, 4))


class TestCorrelate2d(object):

Expand Down Expand Up @@ -1419,6 +1437,11 @@ def test_invalid_shapes(self):
assert_raises(ValueError, signal.correlate2d, *(a, b), **{'mode': 'valid'})
assert_raises(ValueError, signal.correlate2d, *(b, a), **{'mode': 'valid'})

def test_complex_input(self):
assert_equal(signal.correlate2d([[1]], [[2j]]), -2j)
assert_equal(signal.correlate2d([[2j]], [[3j]]), 6)
assert_equal(signal.correlate2d([[3j]], [[4]]), 12j)


class TestLFilterZI(object):

Expand Down

0 comments on commit fbb69de

Please sign in to comment.