Skip to content

Commit

Permalink
Change DeviceArray.__iter__ and DeviceArray.__reversed__ to forward t…
Browse files Browse the repository at this point in the history
…o the _value.

This has the effect of transferring the entire array to the host and iterating over it in host memory, rather than slicing out individual elements in device memory one by one.

This is much faster for examples like `list(np.arange(10000))`; previously this took several seconds the first time due to compilation and 100ms+ subsequent times. With this change it takes < 1ms.
  • Loading branch information
hawkinsp committed Jul 2, 2019
1 parent 59be9b7 commit 71605f4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,13 +526,13 @@ def __iter__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
return (self[i] for i in xrange(self.shape[0]))
return self._value.__iter__()

def __reversed__(self):
if self.ndim == 0:
raise TypeError("iteration over a 0-d array")
else:
return (self[i] for i in xrange(self.shape[0] - 1, -1, -1))
return reversed(self._value)

def __format__(self, format_spec):
# Simulates behavior of https://github.com/numpy/numpy/pull/9883
Expand Down

0 comments on commit 71605f4

Please sign in to comment.