Skip to content

Commit

Permalink
fixes bug #38; added tests and updated existing ones
Browse files Browse the repository at this point in the history
  • Loading branch information
terhorst committed Jun 1, 2011
1 parent e60a615 commit 00db88c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 42 deletions.
63 changes: 34 additions & 29 deletions datarray/datarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ class AxesManager(object):
An axis can be indexed by integers or ticks:
>>> np.all(A.axes.stocks['aapl':'goog'] == A.axes.stocks[0:2])
DataArray(True, dtype=bool)
('date', 'stocks', 'metric')
DataArray(array(True, dtype=bool),
('date', ('stocks', ('aapl', 'ibm')), 'metric'))
>>> np.all(A.axes.stocks[0:2] == A[:,0:2,:])
DataArray(True, dtype=bool)
('date', 'stocks', 'metric')
DataArray(array(True, dtype=bool),
('date', ('stocks', ('aapl', 'ibm')), 'metric'))
Axes can also be accessed numerically:
Expand All @@ -102,8 +104,8 @@ class AxesManager(object):
>>> Ai = A.axes('stocks', 'date')
>>> np.all(Ai['aapl':'goog', 100] == A[100, 0:2])
DataArray(True, dtype=bool)
('stocks', 'metric')
DataArray(array(True, dtype=bool),
(('stocks', ('aapl', 'ibm')), 'metric'))
You can also mix axis names and integers when calling AxesManager.
(Not yet supported.)
Expand Down Expand Up @@ -132,11 +134,9 @@ def __getattribute__(self, name):
def __len__(self):
return len(object.__getattribute__(self, '_axes'))

def __str__(self):
def __repr__(self):
return str(tuple(self))

__repr__ = __str__

def __getitem__(self, n):
"""Return the `n`th axis object of the array.
Expand Down Expand Up @@ -387,12 +387,10 @@ def __eq__(self, other):
return self.name == other.name and self.index == other.index and \
self.labels == other.labels

def __str__(self):
def __repr__(self):
return 'Axis(name=%r, index=%i, labels=%r)' % \
(self.name, self.index, self.labels)

__repr__ = __str__

def __getitem__(self, key):
"""
Return the item(s) of parent array along this axis as specified by `key`.
Expand All @@ -409,15 +407,19 @@ def __getitem__(self, key):
>>> A = DataArray(np.arange(2*3*2).reshape([2,3,2]), \
('a', ('b', ('b1','b2','b3')), 'c'))
>>> b = A.axes.b
>>> np.all(b['b1'] == A[:,0,:])
DataArray(True, dtype=bool)
('a', 'c')
>>> np.all(b['b1':'b2'] == A[:,0:1,:])
DataArray(True, dtype=bool)
('a', 'b', 'c')
DataArray(array(True, dtype=bool),
('a', 'c'))
>>> np.all(b['b2':] == A[:,1:,:])
DataArray(True, dtype=bool)
('a', 'b', 'c')
DataArray(array(True, dtype=bool),
('a', ('b', ('b2', 'b3')), 'c'))
>>> np.all(b['b1':'b2'] == A[:,0:1,:])
DataArray(array(True, dtype=bool),
('a', ('b', ('b1',)), 'c'))
"""
# XXX We don't handle fancy indexing at the moment
if isinstance(key, (np.ndarray, list)):
Expand Down Expand Up @@ -599,8 +601,8 @@ def drop(self, labels):
>>> arr1 = darr.axes.b.keep(['c','d'])
>>> arr2 = darr.axes.b.drop(['a','b','e'])
>>> np.all(arr1 == arr2)
DataArray(True, dtype=bool)
('a', 'b')
DataArray(array(True, dtype=bool),
('a', ('b', ('c', 'd'))))
"""

if not self.labels:
Expand Down Expand Up @@ -735,7 +737,6 @@ def runs_op(*args, **kwargs):
def is_numpy_scalar(arr):
return arr.ndim == 0


def _apply_accumulation(opname, kwnames):
super_op = getattr(np.ndarray, opname)
if 'axis' not in kwnames:
Expand Down Expand Up @@ -763,7 +764,6 @@ def runs_op(*args, **kwargs):
return runs_op

class DataArray(np.ndarray):

# XXX- we need to figure out where in the numpy C code .T is defined!
@property
def T(self):
Expand Down Expand Up @@ -1003,15 +1003,20 @@ def __getitem__(self, key):

return arr

def __str_repr_helper(self, ary_repr):
"""Helper function for __str__ and __repr__. Produce a text
representation of the axis suitable for eval() as an argument to a
DataArray constructor."""
axis_spec = repr(tuple(ax.name if ax.labels is None
else (ax.name, tuple(ax.labels)) for ax in self.axes))
return "%s(%s,\n%s)" % \
(self.__class__.__name__, ary_repr, axis_spec)

def __str__(self):
s = super(DataArray, self).__str__()
s = '\n'.join([s, str(self.names)])
return s
return self.__str_repr_helper(np.asarray(self).__str__())

def __repr__(self):
s = super(DataArray, self).__repr__()
s = '\n'.join([s, str(self.names)])
return s
return self.__str_repr_helper(np.asarray(self).__repr__())

# Methods from ndarray

Expand Down
36 changes: 23 additions & 13 deletions datarray/tests/test_bugfixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@ def test_bug26():
a.axes[0].name = "a"
nt.assert_equal(a.axes[0].name, "a")

def test_bug35():
"Bug 35"
txt_array = DataArray(['a','b'], axes=['dummy'])
#calling datarray_to_string on string arrays used to fail
print_grid.datarray_to_string(txt_array)
#because get_formatter returned the class not an instance
assert isinstance(print_grid.get_formatter(txt_array),
print_grid.StrFormatter)

def test_bug38():
"Bug 38: DataArray.__repr__ should parse as a single entity"
# Calling repr() on an ndarray prepends array (instead of np.array)
array = np.array
arys = (
DataArray(np.random.randint(0, 10000, size=(1,2,3,4,5)), 'abcde'),
DataArray(np.random.randint(0, 10000, size=(3,3,3))), # Try with missing axes
DataArray(np.random.randint(0, 10000, (2,4,5,6)), # Try with ticks
('a', ('b', ('b1','b2','b3','b4')), 'c', 'd')),
)
for A in arys:
print A
assert_datarray_equal(A, eval(repr(A)))

def test_bug44():
"Bug 44"
# In instances where axis=None, the operation runs
Expand All @@ -43,16 +66,3 @@ def test_bug44():
y = np.std(A)
nt.assert_equal( x.sum(), y.sum() )

def test_bug45():
"Bug 45: Support for np.outer()"
A = DataArray([1,2,3], 'a'); B = DataArray([2,3,4], 'b'); C = np.outer(A,B)
assert_datarray_equal(C,DataArray(C, 'ab'))

def test_bug35():
"Bug 35"
txt_array = DataArray(['a','b'], axes=['dummy'])
#calling datarray_to_string on string arrays used to fail
print_grid.datarray_to_string(txt_array)
#because get_formatter returned the class not an instance
assert isinstance(print_grid.get_formatter(txt_array),
print_grid.StrFormatter)

0 comments on commit 00db88c

Please sign in to comment.