Skip to content

Commit

Permalink
Merge pull request numpy#3698 from cgohlke/patch-1
Browse files Browse the repository at this point in the history
BUG: check axes and window length input for all integer types in fft.helper
  • Loading branch information
charris committed Sep 8, 2013
2 parents 25445cd + fe05eac commit 73fbfb2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion numpy/compat/py3k.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

__all__ = ['bytes', 'asbytes', 'isfileobj', 'getexception', 'strchar',
'unicode', 'asunicode', 'asbytes_nested', 'asunicode_nested',
'asstr', 'open_latin1', 'long', 'basestring', 'sixu']
'asstr', 'open_latin1', 'long', 'basestring', 'sixu',
'integer_types']

import sys

Expand Down
12 changes: 7 additions & 5 deletions numpy/fft/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
from __future__ import division, absolute_import, print_function

import numpy.core.numerictypes as nt
from numpy.compat import integer_types
from numpy.core import (
asarray, concatenate, arange, take, integer, empty
)
Expand All @@ -13,6 +13,8 @@

__all__ = ['fftshift', 'ifftshift', 'fftfreq', 'rfftfreq']

integer_types = integer_types + (integer,)


def fftshift(x, axes=None):
"""
Expand Down Expand Up @@ -62,7 +64,7 @@ def fftshift(x, axes=None):
ndim = len(tmp.shape)
if axes is None:
axes = list(range(ndim))
elif isinstance(axes, (int, nt.integer)):
elif isinstance(axes, integer_types):
axes = (axes,)
y = tmp
for k in axes:
Expand Down Expand Up @@ -111,7 +113,7 @@ def ifftshift(x, axes=None):
ndim = len(tmp.shape)
if axes is None:
axes = list(range(ndim))
elif isinstance(axes, (int, nt.integer)):
elif isinstance(axes, integer_types):
axes = (axes,)
y = tmp
for k in axes:
Expand Down Expand Up @@ -158,7 +160,7 @@ def fftfreq(n, d=1.0):
array([ 0. , 1.25, 2.5 , 3.75, -5. , -3.75, -2.5 , -1.25])
"""
if not (isinstance(n, int) or isinstance(n, integer)):
if not isinstance(n, integer_types):
raise ValueError("n should be an integer")
val = 1.0 / (n * d)
results = empty(n, int)
Expand Down Expand Up @@ -214,7 +216,7 @@ def rfftfreq(n, d=1.0):
array([ 0., 10., 20., 30., 40., 50.])
"""
if not (isinstance(n, int) or isinstance(n, integer)):
if not isinstance(n, integer_types):
raise ValueError("n should be an integer")
val = 1.0/(n*d)
N = n//2 + 1
Expand Down

0 comments on commit 73fbfb2

Please sign in to comment.